Source code for torch_geometric.explain.metric.fidelity

from typing import Tuple

import torch
from torch import Tensor

from torch_geometric.explain import Explainer, Explanation
from torch_geometric.explain.config import ExplanationType, ModelMode


def fidelity(
    explainer: Explainer,
    explanation: Explanation,
) -> Tuple[float, float]:
    r"""Evaluates the fidelity of an
    :class:`~torch_geometric.explain.Explainer` given an
    :class:`~torch_geometric.explain.Explanation`, as described in the
    `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods
    for Graph Neural Networks" <https://arxiv.org/abs/2206.09677>`_ paper.

    Fidelity evaluates the contribution of the produced explanatory subgraph
    to the initial prediction, either by giving only the subgraph to the model
    (fidelity-) or by removing it from the entire graph (fidelity+).
    The fidelity scores capture how well an explainable model reproduces the
    natural phenomenon or the GNN model logic.

    For **phenomenon** explanations, the fidelity scores are given by:

    .. math::
        \textrm{fid}_{+} &= \frac{1}{N} \sum_{i = 1}^N
        \| \mathbb{1}(\hat{y}_i = y_i) -
        \mathbb{1}(\hat{y}_i^{G_{C \setminus S}} = y_i) \|

        \textrm{fid}_{-} &= \frac{1}{N} \sum_{i = 1}^N
        \| \mathbb{1}(\hat{y}_i = y_i) -
        \mathbb{1}(\hat{y}_i^{G_S} = y_i) \|

    For **model** explanations, the fidelity scores are given by:

    .. math::
        \textrm{fid}_{+} &= 1 - \frac{1}{N} \sum_{i = 1}^N
        \mathbb{1}(\hat{y}_i^{G_{C \setminus S}} = \hat{y}_i)

        \textrm{fid}_{-} &= 1 - \frac{1}{N} \sum_{i = 1}^N
        \mathbb{1}(\hat{y}_i^{G_S} = \hat{y}_i)

    Args:
        explainer (Explainer): The explainer to evaluate.
        explanation (Explanation): The explanation to evaluate.
    """
    if explainer.model_config.mode == ModelMode.regression:
        raise ValueError("Fidelity not defined for 'regression' models")

    node_mask = explanation.get('node_mask')
    edge_mask = explanation.get('edge_mask')
    kwargs = {key: explanation[key] for key in explanation._model_args}

    y = explanation.target
    if explainer.explanation_type == ExplanationType.phenomenon:
        y_hat = explainer.get_prediction(
            explanation.x,
            explanation.edge_index,
            **kwargs,
        )
        y_hat = explainer.get_target(y_hat)

    explain_y_hat = explainer.get_masked_prediction(
        explanation.x,
        explanation.edge_index,
        node_mask,
        edge_mask,
        **kwargs,
    )
    explain_y_hat = explainer.get_target(explain_y_hat)

    complement_y_hat = explainer.get_masked_prediction(
        explanation.x,
        explanation.edge_index,
        1. - node_mask if node_mask is not None else None,
        1. - edge_mask if edge_mask is not None else None,
        **kwargs,
    )
    complement_y_hat = explainer.get_target(complement_y_hat)

    if explanation.get('index') is not None:
        y = y[explanation.index]
        if explainer.explanation_type == ExplanationType.phenomenon:
            y_hat = y_hat[explanation.index]
        explain_y_hat = explain_y_hat[explanation.index]
        complement_y_hat = complement_y_hat[explanation.index]

    if explainer.explanation_type == ExplanationType.model:
        pos_fidelity = 1. - (complement_y_hat == y).float().mean()
        neg_fidelity = 1. - (explain_y_hat == y).float().mean()
    else:
        pos_fidelity = ((y_hat == y).float() -
                        (complement_y_hat == y).float()).abs().mean()
        neg_fidelity = ((y_hat == y).float() -
                        (explain_y_hat == y).float()).abs().mean()

    return float(pos_fidelity), float(neg_fidelity)
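
# A minimal usage sketch for `fidelity`. The two-layer GCN and the random
# 8-node graph below are illustrative stand-ins (not part of this module);
# the `Explainer`/`GNNExplainer` setup follows the public explainability API:
#
#     import torch
#     from torch_geometric.explain import Explainer, GNNExplainer
#     from torch_geometric.explain.metric import fidelity
#     from torch_geometric.nn import GCNConv
#
#     class GCN(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.conv1 = GCNConv(16, 32)
#             self.conv2 = GCNConv(32, 4)
#
#         def forward(self, x, edge_index):
#             return self.conv2(self.conv1(x, edge_index).relu(), edge_index)
#
#     x = torch.randn(8, 16)  # 8 nodes with 16 features each
#     edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],
#                                [1, 2, 3, 4, 5, 6, 7, 0]])  # a directed cycle
#
#     explainer = Explainer(
#         model=GCN(),
#         algorithm=GNNExplainer(epochs=50),
#         explanation_type='model',
#         node_mask_type='attributes',
#         edge_mask_type='object',
#         model_config=dict(mode='multiclass_classification',
#                           task_level='node', return_type='raw'),
#     )
#     explanation = explainer(x, edge_index, index=0)  # explain node 0
#
#     pos_fid, neg_fid = fidelity(explainer, explanation)
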
def characterization_score(
    pos_fidelity: Tensor,
    neg_fidelity: Tensor,
    pos_weight: float = 0.5,
    neg_weight: float = 0.5,
) -> Tensor:
    r"""Returns the componentwise characterization score as described in the
    `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods
    for Graph Neural Networks" <https://arxiv.org/abs/2206.09677>`_ paper.

    .. math::
        \textrm{charact} = \frac{w_{+} + w_{-}}{\frac{w_{+}}
        {\textrm{fid}_{+}} + \frac{w_{-}}{1 - \textrm{fid}_{-}}}

    Args:
        pos_fidelity (torch.Tensor): The positive fidelity
            :math:`\textrm{fid}_{+}`.
        neg_fidelity (torch.Tensor): The negative fidelity
            :math:`\textrm{fid}_{-}`.
        pos_weight (float, optional): The weight :math:`w_{+}` for
            :math:`\textrm{fid}_{+}`. (default: :obj:`0.5`)
        neg_weight (float, optional): The weight :math:`w_{-}` for
            :math:`\textrm{fid}_{-}`. (default: :obj:`0.5`)
    """
    if (pos_weight + neg_weight) != 1.0:
        raise ValueError(f"The weights need to sum up to 1 "
                         f"(got {pos_weight} and {neg_weight})")

    denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))
    return 1. / denom
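
# A quick worked example for `characterization_score`. The fidelity values
# below are made up for illustration; with the default equal weights the
# score is the harmonic mean of fid+ and (1 - fid-):
#
#     import torch
#     from torch_geometric.explain.metric import characterization_score
#
#     pos_fid = torch.tensor(0.8)
#     neg_fid = torch.tensor(0.1)
#     score = characterization_score(pos_fid, neg_fid)
#     # score = 1 / (0.5 / 0.8 + 0.5 / 0.9) = 72 / 85 ≈ 0.8471
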
def fidelity_curve_auc(
    pos_fidelity: Tensor,
    neg_fidelity: Tensor,
    x: Tensor,
) -> Tensor:
    r"""Returns the AUC for the fidelity curve as described in the
    `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods
    for Graph Neural Networks" <https://arxiv.org/abs/2206.09677>`_ paper.

    More precisely, returns the AUC of

    .. math::
        f(x) = \frac{\textrm{fid}_{+}}{1 - \textrm{fid}_{-}}

    Args:
        pos_fidelity (torch.Tensor): The positive fidelity
            :math:`\textrm{fid}_{+}`.
        neg_fidelity (torch.Tensor): The negative fidelity
            :math:`\textrm{fid}_{-}`.
        x (torch.Tensor): Tensor containing the points on the :math:`x`-axis.
            Needs to be sorted in ascending order.
    """
    if torch.any(neg_fidelity == 1):
        raise ValueError("There exist negative fidelity values of 1, "
                         "leading to a division by zero")

    y = pos_fidelity / (1. - neg_fidelity)
    return auc(x, y)
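
# A small sketch for `fidelity_curve_auc`. The three x-axis points and the
# fidelity values at each point are illustrative; `x` must be ascending and
# every `neg_fidelity` entry must differ from 1:
#
#     import torch
#     from torch_geometric.explain.metric import fidelity_curve_auc
#
#     x = torch.tensor([0.2, 0.5, 0.8])
#     pos_fid = torch.tensor([0.9, 0.7, 0.4])
#     neg_fid = torch.tensor([0.1, 0.3, 0.5])
#
#     # Trapezoidal AUC of f(x) = fid+ / (1 - fid-):
#     area = fidelity_curve_auc(pos_fid, neg_fid, x)
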
def auc(x: Tensor, y: Tensor) -> Tensor:
    # Trapezoidal-rule area under the curve defined by the points (x, y):
    if torch.any(x.diff() < 0):
        raise ValueError("'x' must be given in ascending order")
    return torch.trapezoid(y, x)