Source code for torch_geometric.explain.algorithm.gnn_explainer

import logging
from math import sqrt
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (
    ExplainerConfig,
    MaskType,
    ModelConfig,
    ModelMode,
    ModelTaskLevel,
)

from .base import ExplainerAlgorithm


[docs]class GNNExplainer(ExplainerAlgorithm): r"""The GNN-Explainer model from the `"GNNExplainer: Generating Explanations for Graph Neural Networks" <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN. The following configurations are currently supported: - :class:`torch_geometric.explain.config.ModelConfig` - :attr:`task_level`: :obj:`"node"`, :obj:`"edge"`, or :obj:`"graph"` - :class:`torch_geometric.explain.config.ExplainerConfig` - :attr:`node_mask_type`: :obj:`"object"`, :obj:`"common_attributes"` or :obj:`"attributes"` - :attr:`edge_mask_type`: :obj:`"object"` or :obj:`None` .. note:: For an example of using :class:`GNNExplainer`, see `examples/gnn_explainer.py <https://github.com/pyg-team/ pytorch_geometric/blob/master/examples/gnn_explainer.py>`_ and `examples/gnn_explainer_ba_shapes.py <https://github.com/pyg-team/ pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py>`_. Args: epochs (int, optional): The number of epochs to train. (default: :obj:`100`) lr (float, optional): The learning rate to apply. (default: :obj:`0.01`) **kwargs (optional): Additional hyper-parameters to override default settings in :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`. """ coeffs = { 'edge_size': 0.005, 'edge_reduction': 'sum', 'node_feat_size': 1.0, 'node_feat_reduction': 'mean', 'edge_ent': 1.0, 'node_feat_ent': 0.1, 'EPS': 1e-15, } def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs): super().__init__() self.epochs = epochs self.lr = lr self.coeffs.update(kwargs) self.node_mask = self.edge_mask = None def supports(self) -> bool: task_level = self.model_config.task_level if task_level not in [ ModelTaskLevel.node, ModelTaskLevel.edge, ModelTaskLevel.graph ]: logging.error(f"Task level '{task_level.value}' not supported") return False edge_mask_type = self.explainer_config.edge_mask_type if edge_mask_type not in [MaskType.object, None]: logging.error(f"Edge mask type '{edge_mask_type.value}' not " f"supported") return False node_mask_type = self.explainer_config.node_mask_type if node_mask_type not in [ MaskType.common_attributes, MaskType.object, MaskType.attributes ]: logging.error(f"Node mask type '{node_mask_type.value}' not " f"supported.") return False return True def forward( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, target_index: Optional[int] = None, **kwargs, ) -> Explanation: hard_node_mask = hard_edge_mask = None if self.model_config.task_level == ModelTaskLevel.node: # We need to compute hard masks to properly clean up edges and # nodes attributions not involved during message passing: hard_node_mask, hard_edge_mask = self._get_hard_masks( model, index, edge_index, num_nodes=x.size(0)) self._train(model, x, edge_index, target=target, index=index, target_index=target_index, **kwargs) node_mask = self._post_process_mask(self.node_mask, x.size(0), hard_node_mask, apply_sigmoid=True) edge_mask = self._post_process_mask(self.edge_mask, edge_index.size(1), hard_edge_mask, apply_sigmoid=True) self._clean_model(model) # TODO Consider dropping differentiation between `mask` and `feat_mask` node_feat_mask = None if self.explainer_config.node_mask_type in { MaskType.attributes, MaskType.common_attributes }: node_feat_mask, node_mask = node_mask, None return Explanation(x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask, node_feat_mask=node_feat_mask) def _train( self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, target_index: Optional[int] = None, **kwargs, ): self._initialize_masks(x, edge_index) parameters = [self.node_mask] # We always learn a node mask. if self.explainer_config.edge_mask_type is not None: set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True) parameters.append(self.edge_mask) optimizer = torch.optim.Adam(parameters, lr=self.lr) for _ in range(self.epochs): optimizer.zero_grad() h = x * self.node_mask.sigmoid() y_hat, y = model(h, edge_index, **kwargs), target if target_index is not None: y_hat, y = y_hat[target_index], y[target_index] if index is not None: y_hat, y = y_hat[index], y[index] loss = self._loss(y_hat, y) loss.backward() optimizer.step() def _initialize_masks(self, x: Tensor, edge_index: Tensor): node_mask_type = self.explainer_config.node_mask_type edge_mask_type = self.explainer_config.edge_mask_type device = x.device (N, F), E = x.size(), edge_index.size(1) std = 0.1 if node_mask_type == MaskType.object: self.node_mask = Parameter(torch.randn(N, 1, device=device) * std) elif node_mask_type == MaskType.attributes: self.node_mask = Parameter(torch.randn(N, F, device=device) * std) elif node_mask_type == MaskType.common_attributes: self.node_mask = Parameter(torch.randn(1, F, device=device) * std) else: raise NotImplementedError if edge_mask_type == MaskType.object: std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N)) self.edge_mask = Parameter(torch.randn(E, device=device) * std) elif edge_mask_type is not None: raise NotImplementedError def _loss_regression(self, y_hat: Tensor, y: Tensor) -> Tensor: return F.mse_loss(y_hat, y) def _loss_classification(self, y_hat: Tensor, y: Tensor) -> Tensor: if y.dim() == 0: # `index` was given as an integer. y_hat, y = y_hat.unsqueeze(0), y.unsqueeze(0) y_hat = self._to_log_prob(y_hat) return (-y_hat).gather(1, y.view(-1, 1)).mean() def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor: if self.model_config.mode == ModelMode.regression: loss = self._loss_regression(y_hat, y) elif self.model_config.mode == ModelMode.classification: loss = self._loss_classification(y_hat, y) else: raise NotImplementedError if self.explainer_config.edge_mask_type is not None: m = self.edge_mask.sigmoid() edge_reduce = getattr(torch, self.coeffs['edge_reduction']) loss = loss + self.coeffs['edge_size'] * edge_reduce(m) ent = -m * torch.log(m + self.coeffs['EPS']) - ( 1 - m) * torch.log(1 - m + self.coeffs['EPS']) loss = loss + self.coeffs['edge_ent'] * ent.mean() m = self.node_mask.sigmoid() node_feat_reduce = getattr(torch, self.coeffs['node_feat_reduction']) loss = loss + self.coeffs['node_feat_size'] * node_feat_reduce(m) ent = -m * torch.log(m + self.coeffs['EPS']) - ( 1 - m) * torch.log(1 - m + self.coeffs['EPS']) loss = loss + self.coeffs['node_feat_ent'] * ent.mean() return loss def _clean_model(self, model): clear_masks(model) self.node_mask = None self.edge_mask = None
class GNNExplainer_: r"""Deprecated version for :class:`GNNExplainer`.""" coeffs = GNNExplainer.coeffs conversion_node_mask_type = { 'feature': 'common_attributes', 'individual_feature': 'attributes', 'scalar': 'object', } conversion_return_type = { 'log_prob': 'log_probs', 'prob': 'probs', 'raw': 'raw', 'regression': 'raw', } def __init__( self, model: torch.nn.Module, epochs: int = 100, lr: float = 0.01, return_type: str = 'log_prob', feat_mask_type: str = 'feature', allow_edge_mask: bool = True, **kwargs, ): assert feat_mask_type in ['feature', 'individual_feature', 'scalar'] explainer_config = ExplainerConfig( explanation_type='model', node_mask_type=self.conversion_node_mask_type[feat_mask_type], edge_mask_type=MaskType.object if allow_edge_mask else None, ) model_config = ModelConfig( mode='regression' if return_type == 'regression' else 'classification', task_level=ModelTaskLevel.node, return_type=self.conversion_return_type[return_type], ) self.model = model self._explainer = GNNExplainer(epochs=epochs, lr=lr, **kwargs) self._explainer.connect(explainer_config, model_config) @torch.no_grad() def get_initial_prediction(self, *args, **kwargs) -> Tensor: training = self.model.training self.model.eval() out = self.model(*args, **kwargs) if self._explainer.model_config.mode == ModelMode.classification: out = out.argmax(dim=-1) self.model.train(training) return out def explain_graph( self, x: Tensor, edge_index: Tensor, **kwargs, ) -> Tuple[Tensor, Tensor]: self._explainer.model_config.task_level = ModelTaskLevel.graph explanation = self._explainer( self.model, x, edge_index, target=self.get_initial_prediction(x, edge_index, **kwargs), **kwargs, ) return self._convert_output(explanation, edge_index) def explain_node( self, node_idx: int, x: Tensor, edge_index: Tensor, **kwargs, ) -> Tuple[Tensor, Tensor]: self._explainer.model_config.task_level = ModelTaskLevel.node explanation = self._explainer( self.model, x, edge_index, target=self.get_initial_prediction(x, edge_index, **kwargs), index=node_idx, **kwargs, ) return self._convert_output(explanation, edge_index, index=node_idx, x=x) def _convert_output(self, explanation, edge_index, index=None, x=None): if 'node_mask' in explanation.available_explanations: node_mask = explanation.node_mask else: if (self._explainer.explainer_config.node_mask_type == MaskType.common_attributes): node_mask = explanation.node_feat_mask[0] else: node_mask = explanation.node_feat_mask edge_mask = None if 'edge_mask' in explanation.available_explanations: edge_mask = explanation.edge_mask else: if index is not None: _, edge_mask = self._explainer._get_hard_masks( self.model, index, edge_index, num_nodes=x.size(0)) edge_mask = edge_mask.to(x.dtype) else: edge_mask = torch.ones(edge_index.shape[1], device=edge_index.device) return node_mask, edge_mask