import inspect
import logging
import warnings
from typing import Any, Dict, Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.captum import (
    CaptumHeteroModel,
    CaptumModel,
    MaskLevelType,
    convert_captum_output,
    to_captum_input,
)
from torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType
from torch_geometric.typing import EdgeType, NodeType

class CaptumExplainer(ExplainerAlgorithm):
    """A `Captum <https://captum.ai>`__-based explainer for identifying compact
    subgraph structures and node features that play a crucial role in the
    predictions made by a GNN.

    This explainer algorithm uses `Captum <https://captum.ai/>`_ to compute
    attributions.

    Currently, the following attribution methods are supported:

    * :class:`captum.attr.IntegratedGradients`
    * :class:`captum.attr.Saliency`
    * :class:`captum.attr.InputXGradient`
    * :class:`captum.attr.Deconvolution`
    * :class:`captum.attr.ShapleyValueSampling`
    * :class:`captum.attr.GuidedBackprop`

    Args:
        attribution_method (Attribution or str): The Captum attribution method
            to use. Can be a string or a :class:`captum.attr` method.
        **kwargs: Additional arguments for the Captum attribution method.
    """
    SUPPORTED_METHODS = [  # TODO: Add support for more methods.
        'IntegratedGradients',
        'Saliency',
        'InputXGradient',
        'Deconvolution',
        'ShapleyValueSampling',
        'GuidedBackprop',
    ]

    def __init__(
        self,
        attribution_method: Union[str, Any],
        **kwargs,
    ):
        super().__init__()

        # Imported here rather than at module level (presumably so Captum is
        # only required when this explainer is actually used).
        import captum.attr

        if isinstance(attribution_method, str):
            # An unknown name raises `AttributeError` from `getattr`.
            self.attribution_method_class = getattr(
                captum.attr,
                attribution_method,
            )
        else:
            self.attribution_method_class = attribution_method

        if not self._is_supported_attribution_method():
            raise ValueError(f"{self.__class__.__name__} does not support "
                             f"attribution method "
                             f"{self.attribution_method_class.__name__}")

        # `attribute()` is always run with an internal batch size of 1 when
        # the method accepts that argument; any other user value is replaced.
        if kwargs.get('internal_batch_size', 1) != 1:
            warnings.warn("Overriding 'internal_batch_size' to 1")

        if 'internal_batch_size' in self._get_attribute_parameters():
            kwargs['internal_batch_size'] = 1

        # Extra keyword arguments forwarded verbatim to `attribute()`.
        self.kwargs = kwargs

    def _get_mask_type(self) -> MaskLevelType:
        r"""Based on the explainer config, return the mask type.

        Raises:
            ValueError: If neither a node nor an edge mask type is configured.
        """
        node_mask_type = self.explainer_config.node_mask_type
        edge_mask_type = self.explainer_config.edge_mask_type

        if node_mask_type is not None and edge_mask_type is not None:
            mask_type = MaskLevelType.node_and_edge
        elif node_mask_type is not None:
            mask_type = MaskLevelType.node
        elif edge_mask_type is not None:
            mask_type = MaskLevelType.edge
        else:
            raise ValueError("Neither node mask type nor "
                             "edge mask type is specified.")
        return mask_type

    def _get_attribute_parameters(self) -> Dict[str, Any]:
        r"""Returns the parameters of the attribution method's
        :meth:`attribute` signature, keyed by parameter name."""
        signature = inspect.signature(self.attribution_method_class.attribute)
        return signature.parameters

    def _needs_baseline(self) -> bool:
        r"""Checks if the method needs a baseline, i.e. whether its
        :meth:`attribute` takes a ``baselines`` argument without a default."""
        parameters = self._get_attribute_parameters()
        if 'baselines' in parameters:
            param = parameters['baselines']
            if param.default is inspect.Parameter.empty:
                return True
        return False

    def _is_supported_attribution_method(self) -> bool:
        r"""Returns :obj:`True` if `self.attribution_method` is supported."""
        # Methods that *require* an explicit baseline are rejected, since
        # `forward` never constructs one. None of the names listed in
        # `SUPPORTED_METHODS` currently hit this branch, so it only guards
        # against future additions to the list.
        if self._needs_baseline():
            return False
        elif self.attribution_method_class.__name__ in self.SUPPORTED_METHODS:
            return True
        return False

    def forward(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation]:
        r"""Computes Captum attributions and wraps them in an explanation.

        Args:
            model (torch.nn.Module): The model to explain.
            x (Tensor or Dict[NodeType, Tensor]): Node features; a dictionary
                selects the heterogeneous code path.
            edge_index (Tensor or Dict[EdgeType, Tensor]): Graph connectivity.
            target (Tensor): Model targets. Ignored for regression; indexed
                by :obj:`index` when it is given.
            index (int or Tensor, optional): Index of the output(s) to
                explain. (default: :obj:`None`)
            **kwargs: Additional model inputs. Their *values* are passed on
                positionally, in keyword order.

        Returns:
            :class:`Explanation` for homogeneous input, otherwise a
            :class:`HeteroExplanation` holding per-type masks.
        """
        mask_type = self._get_mask_type()

        inputs, add_forward_args = to_captum_input(
            x,
            edge_index,
            mask_type,
            *kwargs.values(),
        )

        if isinstance(x, dict):  # Heterogeneous GNN:
            metadata = (list(x.keys()), list(edge_index.keys()))
            captum_model = CaptumHeteroModel(
                model,
                mask_type,
                index,
                metadata,
                self.model_config,
            )
        else:  # Homogeneous GNN:
            metadata = None
            captum_model = CaptumModel(
                model,
                mask_type,
                index,
                self.model_config,
            )

        self.attribution_method_instance = self.attribution_method_class(
            captum_model)

        # In Captum, the target is the class index for which the attribution
        # is computed. Within CaptumModel, we transform the binary
        # classification into a multi-class classification task.
        if self.model_config.mode == ModelMode.regression:
            target = None
        elif index is not None:
            target = target[index]

        attributions = self.attribution_method_instance.attribute(
            inputs=inputs,
            target=target,
            additional_forward_args=add_forward_args,
            **self.kwargs,
        )

        node_mask, edge_mask = convert_captum_output(
            attributions,
            mask_type,
            metadata,
        )

        if not isinstance(x, dict):
            return Explanation(node_mask=node_mask, edge_mask=edge_mask)

        explanation = HeteroExplanation()
        explanation.set_value_dict('node_mask', node_mask)
        explanation.set_value_dict('edge_mask', edge_mask)
        return explanation

    def supports(self) -> bool:
        r"""Returns :obj:`True` if the explainer/model configuration is
        compatible with this algorithm; logs the reason and returns
        :obj:`False` otherwise."""
        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [None, MaskType.attributes]:
            logging.error(f"'{self.__class__.__name__}' expects "
                          f"'node_mask_type' to be 'None' or 'attributes' "
                          f"(got '{node_mask_type.value}')")
            return False

        return_type = self.model_config.return_type
        if (self.model_config.mode == ModelMode.binary_classification
                and return_type != ModelReturnType.probs):
            logging.error(f"'{self.__class__.__name__}' expects "
                          f"'return_type' to be 'probs' for binary "
                          f"classification tasks (got '{return_type.value}')")
            return False

        # TODO (ramona) Confirm that output type is valid.
        return True