Source code for torch_geometric.explain.algorithm.pg_explainer

import logging
from typing import Optional, Union

import torch
from torch import Tensor
from torch.nn import ReLU, Sequential

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (
    ExplanationType,
    ModelMode,
    ModelTaskLevel,
)
from torch_geometric.nn import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.utils import get_embeddings


class PGExplainer(ExplainerAlgorithm):
    r"""The PGExplainer model from the `"Parameterized Explainer for Graph
    Neural Network" <https://arxiv.org/abs/2011.04573>`_ paper.

    Internally, it utilizes a neural network to identify subgraph structures
    that play a crucial role in the predictions made by a GNN.
    Importantly, the :class:`PGExplainer` needs to be trained via
    :meth:`~PGExplainer.train` before being able to generate explanations:

    .. code-block:: python

        explainer = Explainer(
            model=model,
            algorithm=PGExplainer(epochs=30, lr=0.003),
            explanation_type='phenomenon',
            edge_mask_type='object',
            model_config=ModelConfig(...),
        )

        # Train against a variety of node-level or graph-level predictions:
        for epoch in range(30):
            for index in [...]:  # Indices to train against.
                loss = explainer.algorithm.train(epoch, model, x, edge_index,
                                                 target=target, index=index)

        # Get the final explanations:
        explanation = explainer(x, edge_index, target=target, index=0)

    Args:
        epochs (int): The number of epochs to train.
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.003`).
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`.
            Overrides only affect the instance they are passed to.
    """

    # Default hyper-parameters (class-level; each instance works on a copy):
    coeffs = {
        # Weight of the mask-size regularizer (sum of sigmoid(edge_mask)).
        'edge_size': 0.05,
        # Weight of the mean element-wise binary entropy regularizer.
        'edge_ent': 1.0,
        # [start, end] temperatures of the exponential annealing schedule
        # used by `_get_temperature`.
        'temp': [5.0, 2.0],
        # Shrinks the uniform noise of `_concrete_sample` to [bias, 1 - bias).
        'bias': 0.0,
    }

    def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        # Build a per-instance copy before applying overrides. Calling
        # `self.coeffs.update(kwargs)` directly would mutate the shared
        # class attribute and leak overrides into every other instance.
        self.coeffs = {**self.coeffs, **kwargs}

        # Scores each edge from concatenated node embeddings; the input size
        # is inferred lazily (`-1`) on the first call.
        self.mlp = Sequential(
            Linear(-1, 64),
            ReLU(),
            Linear(64, 1),
        )
        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
        # Last completed training epoch; -1 means "never trained".
        self._curr_epoch = -1

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.mlp)

    # NOTE(review): if `ExplainerAlgorithm` derives from `torch.nn.Module`,
    # this method shadows `Module.train(mode)` (and thus `.eval()`) — confirm
    # before calling those on the algorithm object.
    def train(
        self,
        epoch: int,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        r"""Trains the underlying explainer model.
        Needs to be called before being able to make predictions.

        Args:
            epoch (int): The current epoch of the training phase.
            model (torch.nn.Module): The model to explain.
            x (torch.Tensor): The input node features of a
                homogeneous graph.
            edge_index (torch.Tensor): The input edge indices of a homogeneous
                graph.
            target (torch.Tensor): The target of the model.
            index (int or torch.Tensor, optional): The index of the model
                output to explain. Needs to be a single index.
                (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to
                :obj:`model`.

        Returns:
            float: The training loss of this step.

        Raises:
            ValueError: For heterogeneous (dict) inputs, or for node-level
                tasks when :obj:`index` is missing or not a scalar.
        """
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        if self.model_config.task_level == ModelTaskLevel.node:
            if index is None:
                raise ValueError(f"The 'index' argument needs to be provided "
                                 f"in '{self.__class__.__name__}' for "
                                 f"node-level explanations")
            if isinstance(index, Tensor) and index.numel() > 1:
                raise ValueError(f"Only scalars are supported for the 'index' "
                                 f"argument in '{self.__class__.__name__}'")

        # Embeddings produced by the last layer reported by `get_embeddings`:
        z = get_embeddings(model, x, edge_index, **kwargs)[-1]

        self.optimizer.zero_grad()
        temperature = self._get_temperature(epoch)

        inputs = self._get_inputs(z, edge_index, index)
        logits = self.mlp(inputs).view(-1)

        edge_mask = self._concrete_sample(logits, temperature)
        set_masks(model, edge_mask, edge_index, apply_sigmoid=True)

        # Always remove the masks from the user's model, even if the forward
        # pass, loss or backward raises.
        try:
            if self.model_config.task_level == ModelTaskLevel.node:
                # Only regularize edges involved in message passing towards
                # the node being explained.
                _, hard_edge_mask = self._get_hard_masks(
                    model, index, edge_index, num_nodes=x.size(0))
                edge_mask = edge_mask[hard_edge_mask]

            y_hat, y = model(x, edge_index, **kwargs), target

            if index is not None:
                y_hat, y = y_hat[index], y[index]

            loss = self._loss(y_hat, y, edge_mask)
            loss.backward()
            self.optimizer.step()
        finally:
            clear_masks(model)

        self._curr_epoch = epoch

        return float(loss)

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        r"""Computes an edge-mask explanation with the trained explainer MLP.

        Args:
            model (torch.nn.Module): The model to explain.
            x (torch.Tensor): The input node features of a homogeneous graph.
            edge_index (torch.Tensor): The input edge indices of a homogeneous
                graph.
            target (torch.Tensor): The target of the model (not used by the
                mask computation itself).
            index (int or torch.Tensor, optional): The index of the model
                output to explain. Required (and must be a scalar) for
                node-level tasks. (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to
                :obj:`model`.

        Returns:
            Explanation: An explanation holding only :obj:`edge_mask`.

        Raises:
            ValueError: For heterogeneous inputs, if :meth:`train` has not
                been run for all :obj:`epochs`, or for node-level tasks when
                :obj:`index` is missing or not a scalar.
        """
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        if self._curr_epoch < self.epochs - 1:  # Safety check:
            raise ValueError(f"'{self.__class__.__name__}' is not yet fully "
                             f"trained (got {self._curr_epoch + 1} epochs "
                             f"from {self.epochs} epochs). Please first train "
                             f"the underlying explainer model by running "
                             f"`explainer.algorithm.train(...)`.")

        hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            if index is None:
                raise ValueError(f"The 'index' argument needs to be provided "
                                 f"in '{self.__class__.__name__}' for "
                                 f"node-level explanations")
            if isinstance(index, Tensor) and index.numel() > 1:
                raise ValueError(f"Only scalars are supported for the 'index' "
                                 f"argument in '{self.__class__.__name__}'")

            # We need to compute hard masks to properly clean up edges and
            # nodes attributions not involved during message passing:
            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
                                                     num_nodes=x.size(0))

        z = get_embeddings(model, x, edge_index, **kwargs)[-1]

        inputs = self._get_inputs(z, edge_index, index)
        logits = self.mlp(inputs).view(-1)

        # No concrete sampling at inference time: the raw logits are handed
        # to the base-class post-processing with `apply_sigmoid=True`.
        edge_mask = self._post_process_mask(logits, hard_edge_mask,
                                            apply_sigmoid=True)

        return Explanation(edge_mask=edge_mask)

    def supports(self) -> bool:
        r"""Returns whether the current explainer/model configuration is
        supported, logging an error describing the first mismatch found.

        Supported: phenomenon explanations, node- or graph-level tasks, and
        no node-feature mask.
        """
        explanation_type = self.explainer_config.explanation_type
        if explanation_type != ExplanationType.phenomenon:
            logging.error(f"'{self.__class__.__name__}' only supports "
                          f"phenomenon explanations "
                          f"got (`explanation_type={explanation_type.value}`)")
            return False

        task_level = self.model_config.task_level
        if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}:
            logging.error(f"'{self.__class__.__name__}' only supports "
                          f"node-level or graph-level explanations "
                          f"got (`task_level={task_level.value}`)")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type is not None:
            logging.error(f"'{self.__class__.__name__}' does not support "
                          f"explaining input node features "
                          f"got (`node_mask_type={node_mask_type.value}`)")
            return False

        return True

    ###########################################################################

    def _get_inputs(self, embedding: Tensor, edge_index: Tensor,
                    index: Optional[Union[int, Tensor]] = None) -> Tensor:
        r"""Builds one MLP input row per edge.

        Each row concatenates the embeddings of the edge's two endpoints
        (rows of :obj:`edge_index`). For node-level tasks, the embedding of
        the explained node :obj:`index` is appended to every row.
        """
        zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
        if self.model_config.task_level == ModelTaskLevel.node:
            assert index is not None
            zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
        return torch.cat(zs, dim=-1)

    def _get_temperature(self, epoch: int) -> float:
        r"""Exponentially anneals the temperature from ``temp[0]`` (epoch 0)
        towards ``temp[1]`` (reached at ``epoch == self.epochs``)."""
        temp = self.coeffs['temp']
        return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)

    def _concrete_sample(self, logits: Tensor,
                         temperature: float = 1.0) -> Tensor:
        r"""Draws a relaxed (concrete) sample in logit space:
        ``(log(eps) - log(1 - eps) + logits) / temperature``, with
        ``eps ~ U[bias, 1 - bias)``.

        The sigmoid is *not* applied here; callers apply it later.
        """
        bias = self.coeffs['bias']
        # `torch.rand_like` samples from [0, 1); with the default bias of 0.0
        # `eps` may be exactly 0, which maps to a logit of -inf.
        eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
        return (eps.log() - (1 - eps).log() + logits) / temperature

    def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor:
        r"""Task loss plus mask size and mask entropy regularization.

        Args:
            y_hat (torch.Tensor): Model prediction(s).
            y (torch.Tensor): Target(s).
            edge_mask (torch.Tensor): Edge-mask logits (pre-sigmoid).

        Raises:
            ValueError: If the configured model mode is not supported.
        """
        if self.model_config.mode == ModelMode.binary_classification:
            loss = self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            loss = self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        else:
            # Previously this fell through to an opaque UnboundLocalError.
            raise ValueError(f"Unsupported model mode "
                             f"'{self.model_config.mode}' in "
                             f"'{self.__class__.__name__}'")

        # Regularization loss:
        mask = edge_mask.sigmoid()
        size_loss = mask.sum() * self.coeffs['edge_size']
        # Squash into [0.005, 0.995] so both log terms below stay finite:
        mask = 0.99 * mask + 0.005
        mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log()
        mask_ent_loss = mask_ent.mean() * self.coeffs['edge_ent']

        return loss + size_loss + mask_ent_loss