Source code for torch_geometric.explain.explainer

import copy
import warnings
from typing import Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain import ExplainerAlgorithm, Explanation
from torch_geometric.explain.config import (
    ExplainerConfig,
    ExplanationType,
    ModelConfig,
    ModelMode,
    ThresholdConfig,
    ThresholdType,
)


class Explainer:
    r"""An explainer class for instance-level explanations of Graph Neural
    Networks.

    Args:
        model (torch.nn.Module): The model to explain.
        algorithm (ExplainerAlgorithm): The explanation algorithm.
        explainer_config (ExplainerConfig): The explainer configuration.
        model_config (ModelConfig): The model configuration.
        threshold_config (ThresholdConfig, optional): The threshold
            configuration. (default: :obj:`None`)
    """
    def __init__(
        self,
        model: torch.nn.Module,
        algorithm: ExplainerAlgorithm,
        explainer_config: ExplainerConfig,
        model_config: ModelConfig,
        threshold_config: Optional[ThresholdConfig] = None,
    ):
        self.model = model
        self.algorithm = algorithm

        # Normalize the user-supplied configurations via each config's `cast`
        # (NOTE(review): presumably accepts config objects or equivalent
        # plain values — confirm against the config classes).
        self.explainer_config = ExplainerConfig.cast(explainer_config)
        self.model_config = ModelConfig.cast(model_config)
        self.threshold_config = ThresholdConfig.cast(threshold_config)

        # Pass the resolved configurations to the explanation algorithm.
        self.algorithm.connect(self.explainer_config, self.model_config)

    @torch.no_grad()
    def get_prediction(self, *args, **kwargs) -> torch.Tensor:
        r"""Returns the prediction of the model on the input graph.

        The model is temporarily switched to evaluation mode and run without
        gradient tracking; its previous train/eval mode is restored
        afterwards, even if the forward pass raises.

        If the model mode is :obj:`"classification"`, the raw output is
        reduced to the predicted class labels via :obj:`argmax` over the last
        dimension. Otherwise, the raw model output is returned unchanged.

        Args:
            *args: Arguments passed to the model.
            **kwargs (optional): Additional keyword arguments passed to the
                model.
        """
        training = self.model.training
        self.model.eval()
        try:
            # Gradient tracking is already disabled by the decorator.
            out = self.model(*args, **kwargs)
        finally:
            # Restore the caller's train/eval mode on every exit path.
            self.model.train(training)

        if self.model_config.mode == ModelMode.classification:
            out = out.argmax(dim=-1)

        return out

    def __call__(
        self,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Optional[Tensor] = None,
        index: Optional[Union[int, Tensor]] = None,
        target_index: Optional[int] = None,
        **kwargs,
    ) -> Explanation:
        r"""Computes the explanation of the GNN for the given inputs and
        target.

        .. note::

            If you get an error message like "Trying to backward through the
            graph a second time", make sure that the target you provided
            was computed with :meth:`torch.no_grad`.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The input edge indices.
            target (torch.Tensor): The target of the model.
                If the explanation type is :obj:`"phenomenon"`, the target has
                to be provided.
                If the explanation type is :obj:`"model"`, the target should be
                set to :obj:`None` and will get automatically inferred.
                (default: :obj:`None`)
            index (Union[int, Tensor], optional): The index of the model
                output to explain. Can be a single index or a tensor of
                indices. (default: :obj:`None`)
            target_index (int, optional): The index of the model outputs to
                reference in case the model returns a list of tensors, *e.g.*,
                in a multi-task learning scenario. Should be kept to
                :obj:`None` in case the model only returns a single output
                tensor. (default: :obj:`None`)
            **kwargs: additional arguments to pass to the GNN.

        Raises:
            ValueError: If the explanation type is :obj:`"phenomenon"` and no
                :obj:`target` is given.
        """
        # Choose the `target` depending on the explanation type:
        explanation_type = self.explainer_config.explanation_type
        if explanation_type == ExplanationType.phenomenon:
            if target is None:
                raise ValueError(
                    f"The 'target' has to be provided for the explanation "
                    f"type '{explanation_type.value}'")
        elif explanation_type == ExplanationType.model:
            if target is not None:
                warnings.warn(
                    f"The 'target' should not be provided for the explanation "
                    f"type '{explanation_type.value}'")
            # Any user-provided target is replaced by the model's own
            # prediction for "model" explanations.
            target = self.get_prediction(x=x, edge_index=edge_index, **kwargs)

        training = self.model.training
        self.model.eval()
        try:
            explanation = self.algorithm(
                self.model,
                x,
                edge_index,
                target=target,
                index=index,
                target_index=target_index,
                **kwargs,
            )
        finally:
            # Restore the caller's train/eval mode even if the algorithm
            # raises, so the user's model is never left in eval mode.
            self.model.train(training)

        return self._post_process(explanation)

    def _post_process(self, explanation: Explanation) -> Explanation:
        r"""Post-processes the explanation mask according to the thresholding
        method and the user configuration.

        Args:
            explanation (Explanation): The explanation mask to post-process.
        """
        explanation = self._threshold(explanation)
        return explanation

    @staticmethod
    def _topk_mask(mask: Tensor, k: int, hard: bool) -> Tensor:
        r"""Keeps the :obj:`k` largest entries of :obj:`mask` (over all
        elements) and zeroes out the rest.

        If :obj:`hard` is set, kept entries are set to :obj:`1.0`; otherwise
        they keep their original value. If :obj:`k` covers every entry, the
        mask is returned unchanged (soft) or as all ones (hard).
        """
        if k >= mask.numel():
            # Every entry is within the top-k; nothing needs to be zeroed.
            return torch.ones_like(mask) if hard else mask

        flat = mask.flatten()
        top_values, top_index = torch.topk(flat, k=k)
        out = torch.zeros_like(flat)
        out[top_index] = 1.0 if hard else top_values
        return out.reshape(mask.size())

    def _threshold(self, explanation: Explanation) -> Explanation:
        r"""Thresholds the explanation masks according to the thresholding
        method.

        Returns the input unchanged when no threshold configuration is set;
        otherwise returns a shallow copy whose available masks are replaced
        by their thresholded versions.

        Args:
            explanation (Explanation): The explanation to threshold.

        Raises:
            NotImplementedError: If the threshold type is not supported.
        """
        if self.threshold_config is None:
            return explanation

        # Avoid modification of the original explanation:
        explanation = copy.copy(explanation)

        threshold_type = self.threshold_config.type
        value = self.threshold_config.value

        # Get the available masks:
        mask_dict = {
            key: explanation[key]
            for key in explanation.available_explanations
        }

        if threshold_type == ThresholdType.hard:
            mask_dict = {
                key: (mask > value).float()
                for key, mask in mask_dict.items()
            }
        elif threshold_type in (ThresholdType.topk, ThresholdType.topk_hard):
            hard = threshold_type == ThresholdType.topk_hard
            mask_dict = {
                key: self._topk_mask(mask, value, hard)
                for key, mask in mask_dict.items()
            }
        else:
            raise NotImplementedError(
                f"Unsupported threshold type '{threshold_type}'")

        # Update the explanation with the thresholded masks:
        for key, mask in mask_dict.items():
            explanation[key] = mask

        return explanation