import warnings
from typing import Any, Dict, Optional, Union
import torch
from torch import Tensor
from torch_geometric.explain import (
ExplainerAlgorithm,
Explanation,
HeteroExplanation,
)
from torch_geometric.explain.algorithm.utils import (
clear_masks,
set_hetero_masks,
set_masks,
)
from torch_geometric.explain.config import (
ExplainerConfig,
ExplanationType,
MaskType,
ModelConfig,
ModelMode,
ModelReturnType,
ThresholdConfig,
)
from torch_geometric.typing import EdgeType, NodeType
class Explainer:
    r"""An explainer class for instance-level explanations of Graph Neural
    Networks.

    Args:
        model (torch.nn.Module): The model to explain.
        algorithm (ExplainerAlgorithm): The explanation algorithm.
        explanation_type (ExplanationType or str): The type of explanation to
            compute. The possible values are:

                - :obj:`"model"`: Explains the model prediction.

                - :obj:`"phenomenon"`: Explains the phenomenon that the model
                  is trying to predict.

            In practice, this means that the explanation algorithm will either
            compute their losses with respect to the model output
            (:obj:`"model"`) or the target output (:obj:`"phenomenon"`).
        model_config (ModelConfig): The model configuration.
            See :class:`~torch_geometric.explain.config.ModelConfig` for
            available options.
        node_mask_type (MaskType or str, optional): The type of mask to apply
            on nodes. The possible values are (default: :obj:`None`):

                - :obj:`None`: Will not apply any mask on nodes.

                - :obj:`"object"`: Will mask each node.

                - :obj:`"common_attributes"`: Will mask each feature.

                - :obj:`"attributes"`: Will mask each feature across all
                  nodes.

        edge_mask_type (MaskType or str, optional): The type of mask to apply
            on edges. Has the same possible values as :obj:`node_mask_type`.
            (default: :obj:`None`)
        threshold_config (ThresholdConfig, optional): The threshold
            configuration.
            See :class:`~torch_geometric.explain.config.ThresholdConfig` for
            available options. (default: :obj:`None`)
    """
    def __init__(
        self,
        model: torch.nn.Module,
        algorithm: ExplainerAlgorithm,
        explanation_type: Union[ExplanationType, str],
        model_config: Union[ModelConfig, Dict[str, Any]],
        node_mask_type: Optional[Union[MaskType, str]] = None,
        edge_mask_type: Optional[Union[MaskType, str]] = None,
        threshold_config: Optional[ThresholdConfig] = None,
    ):
        # Route the user-facing options through `ExplainerConfig` and read the
        # stored values back from it rather than from the raw arguments:
        explainer_config = ExplainerConfig(
            explanation_type=explanation_type,
            node_mask_type=node_mask_type,
            edge_mask_type=edge_mask_type,
        )

        self.model = model
        self.algorithm = algorithm

        self.explanation_type = explainer_config.explanation_type
        self.model_config = ModelConfig.cast(model_config)
        self.node_mask_type = explainer_config.node_mask_type
        self.edge_mask_type = explainer_config.edge_mask_type
        self.threshold_config = ThresholdConfig.cast(threshold_config)

        # Hand both configurations to the algorithm before it is first used:
        self.algorithm.connect(explainer_config, self.model_config)

    @torch.no_grad()
    def get_prediction(self, *args, **kwargs) -> Tensor:
        r"""Returns the output of the model on the given inputs.

        The model is switched to evaluation mode for the forward pass, which
        runs with gradient tracking disabled. The output of the model is
        returned as it is; use :meth:`get_target` to convert it into a target.
        The previous training mode of the model is restored afterwards, even
        if the forward pass raises.

        Args:
            *args: Arguments passed to the model.
            **kwargs (optional): Additional keyword arguments passed to the
                model.
        """
        training = self.model.training
        self.model.eval()
        try:
            return self.model(*args, **kwargs)
        finally:
            # Never leave the caller's model stuck in evaluation mode:
            self.model.train(training)

    def get_masked_prediction(
        self,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        node_mask: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,
        edge_mask: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None,
        **kwargs,
    ) -> Tensor:
        r"""Returns the prediction of the model on the input graph with node
        and edge masks applied.

        Args:
            x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input
                node features of a homogeneous or heterogeneous graph.
            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]):
                The input edge indices of a homogeneous or heterogeneous
                graph.
            node_mask (Union[torch.Tensor, Dict[NodeType, torch.Tensor]],
                optional): The mask that is multiplied with the node features.
                For dictionary inputs, it is looked up by the keys of
                :obj:`x`. (default: :obj:`None`)
            edge_mask (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]],
                optional): The mask that is set on the model for the duration
                of the forward pass. (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to the
                model.
        """
        if isinstance(x, Tensor) and node_mask is not None:
            x = node_mask * x
        elif isinstance(x, dict) and node_mask is not None:
            x = {key: value * node_mask[key] for key, value in x.items()}

        try:
            if isinstance(edge_mask, Tensor):
                set_masks(self.model, edge_mask, edge_index,
                          apply_sigmoid=False)
            elif isinstance(edge_mask, dict):
                set_hetero_masks(self.model, edge_mask, edge_index,
                                 apply_sigmoid=False)

            out = self.get_prediction(x, edge_index, **kwargs)
        finally:
            # Always remove the masks again, so that a failing forward pass
            # does not leave the model in a masked state:
            clear_masks(self.model)

        return out

    def __call__(
        self,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        *,
        target: Optional[Tensor] = None,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation]:
        r"""Computes the explanation of the GNN for the given inputs and
        target.

        .. note::

            If you get an error message like "Trying to backward through the
            graph a second time", make sure that the target you provided
            was computed with :meth:`torch.no_grad`.

        Args:
            x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input
                node features of a homogeneous or heterogeneous graph.
            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]):
                The input edge indices of a homogeneous or heterogeneous
                graph.
            target (torch.Tensor): The target of the model.
                If the explanation type is :obj:`"phenomenon"`, the target has
                to be provided.
                If the explanation type is :obj:`"model"`, the target should be
                set to :obj:`None` and will get automatically inferred. For
                classification tasks, the target needs to contain the class
                labels. (default: :obj:`None`)
            index (Union[int, Tensor], optional): The indices in the
                first-dimension of the model output to explain.
                Can be a single index or a tensor of indices.
                If set to :obj:`None`, all model outputs will be explained.
                (default: :obj:`None`)
            **kwargs: additional arguments to pass to the GNN.

        Raises:
            ValueError: If the explanation type is :obj:`"phenomenon"` and no
                :obj:`target` is given.
        """
        # Choose the `target` depending on the explanation type:
        prediction: Optional[Tensor] = None
        if self.explanation_type == ExplanationType.phenomenon:
            if target is None:
                raise ValueError(
                    f"The 'target' has to be provided for the explanation "
                    f"type '{self.explanation_type.value}'")
        elif self.explanation_type == ExplanationType.model:
            if target is not None:
                # `stacklevel=2` attributes the warning to the caller:
                warnings.warn(
                    f"The 'target' should not be provided for the explanation "
                    f"type '{self.explanation_type.value}'",
                    stacklevel=2,
                )
            # A user-provided target is overridden by the model prediction:
            prediction = self.get_prediction(x, edge_index, **kwargs)
            target = self.get_target(prediction)

        # Algorithms receive `index` as a tensor (or `None`):
        if isinstance(index, int):
            index = torch.tensor([index])

        training = self.model.training
        self.model.eval()
        try:
            explanation = self.algorithm(
                self.model,
                x,
                edge_index,
                target=target,
                index=index,
                **kwargs,
            )
        finally:
            # Restore the training mode even if the algorithm fails:
            self.model.train(training)

        # Add explainer objectives to the `Explanation` object:
        explanation._model_config = self.model_config
        explanation.prediction = prediction
        explanation.target = target
        explanation.index = index

        # Add model inputs to the `Explanation` object:
        if isinstance(explanation, Explanation):
            explanation._model_args = list(kwargs.keys())

            explanation.x = x
            explanation.edge_index = edge_index

            for key, arg in kwargs.items():  # Add remaining `kwargs`:
                explanation[key] = arg

        elif isinstance(explanation, HeteroExplanation):
            # TODO Add `explanation._model_args`

            assert isinstance(x, dict)
            explanation.set_value_dict('x', x)

            assert isinstance(edge_index, dict)
            explanation.set_value_dict('edge_index', edge_index)

            for key, arg in kwargs.items():  # Add remaining `kwargs`:
                if isinstance(arg, dict):
                    # Keyword arguments are likely named `{attr_name}_dict`
                    # while we only want to assign the `{attr_name}` to the
                    # `HeteroExplanation` object:
                    key = key[:-5] if key.endswith('_dict') else key
                    explanation.set_value_dict(key, arg)
                else:
                    explanation[key] = arg

        explanation.validate_masks()
        return explanation.threshold(self.threshold_config)

    def get_target(self, prediction: Tensor) -> Tensor:
        r"""Returns the target of the model from a given prediction.

        If the model mode is of type :obj:`"binary_classification"`, the
        prediction is thresholded (at :obj:`0` for raw outputs and at
        :obj:`0.5` for probabilities) and returned as a flattened tensor of
        class labels.
        If the model mode is of type :obj:`"multiclass_classification"`, the
        index of the largest entry along the last dimension is returned.
        For any other model mode, the prediction is returned as it is.

        Args:
            prediction (torch.Tensor): The output of the model.

        Raises:
            AssertionError: If the model mode is
                :obj:`"binary_classification"` and the return type of the
                model is neither raw nor probabilities.
        """
        if self.model_config.mode == ModelMode.binary_classification:
            # TODO: Allow customization of the thresholds used below.
            if self.model_config.return_type == ModelReturnType.raw:
                return (prediction > 0).long().view(-1)
            if self.model_config.return_type == ModelReturnType.probs:
                return (prediction > 0.5).long().view(-1)
            # Raised explicitly since `assert False` is removed under
            # `python -O`, which would silently return the raw prediction:
            raise AssertionError(
                f"Unsupported return type '{self.model_config.return_type}' "
                f"for binary classification")

        if self.model_config.mode == ModelMode.multiclass_classification:
            return prediction.argmax(dim=-1)

        return prediction