# Source code for torch_geometric.explain.metric.fidelity

from typing import Tuple

import torch
from torch import Tensor

from torch_geometric.explain import Explainer, Explanation
from torch_geometric.explain.config import ExplanationType, ModelMode

def fidelity(
    explainer: Explainer,
    explanation: Explanation,
) -> Tuple[float, float]:
    r"""Evaluates the fidelity of an
    :class:`~torch_geometric.explain.Explainer` given an
    :class:`~torch_geometric.explain.Explanation`, as described in the
    `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for
    Graph Neural Networks" <https://arxiv.org/abs/2206.09677>`_ paper.

    Fidelity evaluates the contribution of the produced explanatory subgraph
    to the initial prediction, either by giving only the subgraph to the model
    (fidelity-) or by removing it from the entire graph (fidelity+).
    The fidelity scores capture how good an explainable model reproduces the
    natural phenomenon or the GNN model logic.

    For **phenomenon** explanations, the fidelity scores are given by:

    .. math::
        \textrm{fid}_{+} &= \frac{1}{N} \sum_{i = 1}^N
        \| \mathbb{1}(\hat{y}_i = y_i) -
        \mathbb{1}( \hat{y}_i^{G_{C \setminus S}} = y_i) \|

        \textrm{fid}_{-} &= \frac{1}{N} \sum_{i = 1}^N
        \| \mathbb{1}(\hat{y}_i = y_i) -
        \mathbb{1}( \hat{y}_i^{G_S} = y_i) \|

    For **model** explanations, the fidelity scores are given by:

    .. math::
        \textrm{fid}_{+} &= 1 - \frac{1}{N} \sum_{i = 1}^N
        \mathbb{1}( \hat{y}_i^{G_{C \setminus S}} = \hat{y}_i)

        \textrm{fid}_{-} &= 1 - \frac{1}{N} \sum_{i = 1}^N
        \mathbb{1}( \hat{y}_i^{G_S} = \hat{y}_i)

    Args:
        explainer (Explainer): The explainer to evaluate.
        explanation (Explanation): The explanation to evaluate.

    Raises:
        ValueError: If the explainer operates in regression mode, for which
            fidelity is not defined.
    """
    if explainer.model_config.mode == ModelMode.regression:
        raise ValueError("Fidelity not defined for 'regression' models")

    # Masks selecting the explanatory subgraph :math:`G_S` (may be `None`
    # when the explanation only masks one of nodes/edges):
    node_mask = explanation.get('node_mask')
    edge_mask = explanation.get('edge_mask')
    kwargs = {key: explanation[key] for key in explanation._model_args}

    y = explanation.target
    if explainer.explanation_type == ExplanationType.phenomenon:
        # For phenomenon explanations, we additionally need the model's
        # prediction on the full (unmasked) graph:
        y_hat = explainer.get_prediction(
            explanation.x,
            explanation.edge_index,
            **kwargs,
        )
        y_hat = explainer.get_target(y_hat)

    # Prediction when keeping only the explanatory subgraph (fidelity-):
    explain_y_hat = explainer.get_masked_prediction(
        explanation.x,
        explanation.edge_index,
        node_mask,
        edge_mask,
        **kwargs,
    )
    explain_y_hat = explainer.get_target(explain_y_hat)

    # Prediction on the complement graph, i.e. with the explanatory
    # subgraph removed (fidelity+). Masks are inverted where present:
    complement_y_hat = explainer.get_masked_prediction(
        explanation.x,
        explanation.edge_index,
        1. - node_mask if node_mask is not None else None,
        1. - edge_mask if edge_mask is not None else None,
        **kwargs,
    )
    complement_y_hat = explainer.get_target(complement_y_hat)

    # Restrict evaluation to the explained entities, if given:
    if explanation.get('index') is not None:
        y = y[explanation.index]
        if explainer.explanation_type == ExplanationType.phenomenon:
            y_hat = y_hat[explanation.index]
        explain_y_hat = explain_y_hat[explanation.index]
        complement_y_hat = complement_y_hat[explanation.index]

    if explainer.explanation_type == ExplanationType.model:
        pos_fidelity = 1. - (complement_y_hat == y).float().mean()
        neg_fidelity = 1. - (explain_y_hat == y).float().mean()
    else:
        pos_fidelity = ((y_hat == y).float() -
                        (complement_y_hat == y).float()).abs().mean()
        neg_fidelity = ((y_hat == y).float() -
                        (explain_y_hat == y).float()).abs().mean()

    return float(pos_fidelity), float(neg_fidelity)

def characterization_score(
    pos_fidelity: Tensor,
    neg_fidelity: Tensor,
    pos_weight: float = 0.5,
    neg_weight: float = 0.5,
) -> Tensor:
    r"""Returns the componentwise characterization score as described in the
    `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for
    Graph Neural Networks" <https://arxiv.org/abs/2206.09677>`_ paper.

    .. math::
        \textrm{charact} = \frac{w_{+} + w_{-}}{\frac{w_{+}}{\textrm{fid}_{+}}
        + \frac{w_{-}}{1 - \textrm{fid}_{-}}}

    Args:
        pos_fidelity (torch.Tensor): The positive fidelity
            :math:`\textrm{fid}_{+}`.
        neg_fidelity (torch.Tensor): The negative fidelity
            :math:`\textrm{fid}_{-}`.
        pos_weight (float, optional): The weight :math:`w_{+}` for
            :math:`\textrm{fid}_{+}`. (default: :obj:`0.5`)
        neg_weight (float, optional): The weight :math:`w_{-}` for
            :math:`\textrm{fid}_{-}`. (default: :obj:`0.5`)

    Raises:
        ValueError: If :obj:`pos_weight` and :obj:`neg_weight` do not sum
            up to :obj:`1.0`.
    """
    # Compare with a tolerance rather than exact float equality: an exact
    # `!= 1.0` check spuriously rejects valid pairs such as (0.3, 0.7),
    # whose floating-point sum is not exactly 1.0.
    if abs((pos_weight + neg_weight) - 1.0) > 1e-9:
        raise ValueError(f"The weights need to sum up to 1 "
                         f"(got {pos_weight} and {neg_weight})")

    denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))
    return 1. / denom

def fidelity_curve_auc(
    pos_fidelity: Tensor,
    neg_fidelity: Tensor,
    x: Tensor,
) -> Tensor:
    r"""Returns the AUC for the fidelity curve as described in the
    `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for
    Graph Neural Networks" <https://arxiv.org/abs/2206.09677>`_ paper.

    More precisely, returns the AUC of

    .. math::
        f(x) = \frac{\textrm{fid}_{+}}{1 - \textrm{fid}_{-}}

    Args:
        pos_fidelity (torch.Tensor): The positive fidelity
            :math:`\textrm{fid}_{+}`.
        neg_fidelity (torch.Tensor): The negative fidelity
            :math:`\textrm{fid}_{-}`.
        x (torch.Tensor): Tensor containing the points on the :math:`x`-axis.
            Needs to be sorted in ascending order.

    Raises:
        ValueError: If any negative fidelity value equals one, since the
            curve :math:`f(x)` would then divide by zero.
    """
    # Guard against division by zero in the curve below:
    if bool((neg_fidelity == 1).any()):
        raise ValueError("There exists negative fidelity values containing 1, "
                         "leading to a division by zero")

    curve = pos_fidelity / (1. - neg_fidelity)
    return auc(x, curve)

def auc(x: Tensor, y: Tensor) -> Tensor:
if torch.any(x.diff() < 0):
raise ValueError("'x' must be given in ascending order")