import logging
from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
from torch_geometric.nn.conv.message_passing import MessagePassing

class AttentionExplainer(ExplainerAlgorithm):
    r"""An explainer that uses the attention coefficients produced by an
    attention-based GNN (*e.g.*,
    :class:`~torch_geometric.nn.conv.GATConv`,
    :class:`~torch_geometric.nn.conv.GATv2Conv`, or
    :class:`~torch_geometric.nn.conv.TransformerConv`) as edge explanation.
    Attention scores across layers and heads will be aggregated according to
    the :obj:`reduce` argument.

    Args:
        reduce (str, optional): The method to reduce the attention scores
            across layers and heads. The name is looked up on the
            :obj:`torch` namespace and called with :obj:`dim=-1`
            (*e.g.*, :obj:`"max"`, :obj:`"min"`, :obj:`"mean"` or
            :obj:`"sum"`). (default: :obj:`"max"`)
    """
    def __init__(self, reduce: str = 'max'):
        super().__init__()
        self.reduce = reduce

    def _reduce(self, src: Tensor) -> Tensor:
        r"""Reduces :obj:`src` over its last dimension with
        :obj:`torch.<reduce>`, returning only the reduced values."""
        out = getattr(torch, self.reduce)(src, dim=-1)
        # `torch.max` / `torch.min` return a `(values, indices)` named tuple;
        # keep only the values:
        if isinstance(out, tuple):
            out = out[0]
        return out

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        r"""Runs :obj:`model` once and turns the attention coefficients
        recorded by its message passing layers into an edge mask.

        Args:
            model (torch.nn.Module): The attention-based GNN to explain.
            x (torch.Tensor): The node features.
            edge_index (torch.Tensor): The edge indices.
            target (torch.Tensor): The target of the model. Not read by this
                explainer; accepted to match the algorithm interface.
            index (int or torch.Tensor, optional): The index(es) of the
                output to explain. Only used for node-level tasks to compute
                the hard edge mask. (default: :obj:`None`)
            **kwargs: Additional arguments forwarded to :obj:`model`.

        Returns:
            Explanation: An explanation holding only :obj:`edge_mask`.

        Raises:
            ValueError: If :obj:`x` or :obj:`edge_index` is a dictionary
                (heterogeneous input), if no attention coefficients were
                recorded, or if a recorded coefficient tensor has more than
                two dimensions.
        """
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            # We need to compute the hard edge mask to properly clean up edge
            # attributions not involved during message passing:
            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
                                                     num_nodes=x.size(0))

        alphas: List[Tensor] = []

        def hook(module, msg_kwargs, out):
            # Prefer the `alpha` passed to `message()`; otherwise fall back
            # to an `_alpha` attribute cached on the layer.
            if 'alpha' in msg_kwargs[0]:
                alphas.append(msg_kwargs[0]['alpha'].detach())
            elif getattr(module, '_alpha', None) is not None:
                alphas.append(module._alpha.detach())

        hook_handles = []
        try:
            for module in model.modules():  # Register message forward hooks:
                if (isinstance(module, MessagePassing)
                        and module.explain is not False):
                    hook_handles.append(
                        module.register_message_forward_hook(hook))

            model(x, edge_index, **kwargs)
        finally:
            # Always detach the hooks, even if registration or the forward
            # pass raised, so the user's model is left unmodified:
            for handle in hook_handles:
                handle.remove()

        if not alphas:
            raise ValueError("Could not collect any attention coefficients. "
                             "Please ensure that your model is using "
                             "attention-based GNN layers.")

        num_edges = edge_index.size(1)
        for i, alpha in enumerate(alphas):
            # Respect potential self-loops: keep only the first `num_edges`
            # coefficients (presumably layers append self-loop coefficients
            # after the input edges -- verify per layer type).
            alpha = alpha[:num_edges]
            if alpha.dim() == 2:  # Reduce across attention heads.
                alpha = self._reduce(alpha)
            elif alpha.dim() > 2:
                raise ValueError(f"Can not reduce attention coefficients of "
                                 f"shape {list(alpha.size())}")
            alphas[i] = alpha

        if len(alphas) > 1:  # Reduce across layers.
            alpha = self._reduce(torch.stack(alphas, dim=-1))
        else:
            alpha = alphas[0]

        alpha = self._post_process_mask(alpha, hard_edge_mask,
                                        apply_sigmoid=False)
        return Explanation(edge_mask=alpha)

    def supports(self) -> bool:
        r"""Returns :obj:`True` if the explainer configuration requests a
        model explanation without a node mask; otherwise logs an error and
        returns :obj:`False`."""
        explanation_type = self.explainer_config.explanation_type
        if explanation_type != ExplanationType.model:
            logging.error(f"'{self.__class__.__name__}' only supports "
                          f"model explanations "
                          f"got (`explanation_type={explanation_type.value}`)")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type is not None:
            logging.error(f"'{self.__class__.__name__}' does not support "
                          f"explaining input node features "
                          f"got (`node_mask_type={node_mask_type.value}`)")
            return False

        return True