Source code for torch_geometric.nn.models.gnn_explainer

from math import sqrt
from typing import Optional

import torch
from tqdm import tqdm

from torch_geometric.nn.models.explainer import (Explainer, clear_masks,
                                                 set_masks)

# Small constant added to the argument of ``torch.log`` in the mask-entropy
# terms of ``GNNExplainer._loss`` so that a mask value of exactly 0 or 1 does
# not evaluate ``log(0)``.
EPS = 1e-15


class GNNExplainer(Explainer):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
    structures and small subsets of node features that play a crucial role in
    a GNN's node-predictions.

    .. note::

        For an example of using GNN-Explainer, see `examples/gnn_explainer.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        gnn_explainer.py>`_.

    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        num_hops (int, optional): The number of hops the :obj:`model` is
            aggregating information from.
            If set to :obj:`None`, will automatically try to detect this
            information based on the number of
            :class:`~torch_geometric.nn.conv.message_passing.MessagePassing`
            layers inside :obj:`model`. (default: :obj:`None`)
        return_type (str, optional): Denotes the type of output from
            :obj:`model`. Valid inputs are :obj:`"log_prob"` (the model
            returns the logarithm of probabilities), :obj:`"prob"` (the
            model returns probabilities), :obj:`"raw"` (the model returns raw
            scores) and :obj:`"regression"` (the model returns scalars).
            (default: :obj:`"log_prob"`)
        feat_mask_type (str, optional): Denotes the type of feature mask
            that will be learned. Valid inputs are :obj:`"feature"` (a single
            feature-level mask for all nodes), :obj:`"individual_feature"`
            (individual feature-level masks for each node), and
            :obj:`"scalar"` (scalar mask for each node).
            (default: :obj:`"feature"`)
        allow_edge_mask (boolean, optional): If set to :obj:`False`, the edge
            mask will not be optimized. (default: :obj:`True`)
        log (bool, optional): If set to :obj:`False`, will not log any
            learning progress. (default: :obj:`True`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.nn.models.GNNExplainer.coeffs`.
    """

    # Default regularisation weights / reductions used by `_loss`. Instances
    # receive their own copy in `__init__`, so these stay the defaults.
    coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
    }

    def __init__(self, model, epochs: int = 100, lr: float = 0.01,
                 num_hops: Optional[int] = None, return_type: str = 'log_prob',
                 feat_mask_type: str = 'feature', allow_edge_mask: bool = True,
                 log: bool = True, **kwargs):
        super().__init__(model, lr, epochs, num_hops, return_type, log)
        assert feat_mask_type in ['feature', 'individual_feature', 'scalar']
        self.allow_edge_mask = allow_edge_mask
        self.feat_mask_type = feat_mask_type
        # Build a per-instance dict instead of updating in place: `coeffs` is
        # a class attribute, so `self.coeffs.update(kwargs)` would leak the
        # overrides of one explainer into every other instance.
        self.coeffs = {**self.coeffs, **kwargs}

    def _initialize_masks(self, x, edge_index, init="normal"):
        """Create the learnable node-feature mask and (optionally) edge mask.

        The node-feature mask has shape ``(N, F)``, ``(N, 1)`` or ``(1, F)``
        for ``feat_mask_type`` ``'individual_feature'``, ``'scalar'`` and
        ``'feature'`` respectively, where ``(N, F) = x.size()``. The edge
        mask, created only when ``allow_edge_mask`` is set, has one entry per
        column of ``edge_index``.

        ``init`` is accepted for interface compatibility but is not used.
        """
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        if self.feat_mask_type == 'individual_feature':
            self.node_feat_mask = torch.nn.Parameter(torch.randn(N, F) * std)
        elif self.feat_mask_type == 'scalar':
            self.node_feat_mask = torch.nn.Parameter(torch.randn(N, 1) * std)
        else:
            self.node_feat_mask = torch.nn.Parameter(torch.randn(1, F) * std)

        if self.allow_edge_mask:
            # Only computed when needed: the formula divides by `N`.
            std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
            self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)

    def _clear_masks(self):
        """Remove the masks from the wrapped model and reset both masks."""
        clear_masks(self.model)
        self.node_feat_mask = None
        self.edge_mask = None

    @staticmethod
    def _entropy(mask):
        """Return the element-wise binary entropy of ``mask``."""
        return (-mask * torch.log(mask + EPS) -
                (1 - mask) * torch.log(1 - mask + EPS))

    def _loss(self, log_logits, prediction, node_idx: Optional[int] = None):
        """Return the prediction loss plus mask size and entropy penalties.

        Args:
            log_logits (Tensor): Model output for the masked input.
            prediction (Tensor): Reference prediction to stay close to.
            node_idx (int, optional): Row to evaluate. If :obj:`None` or
                negative, the whole output (regression) or row ``0``
                (otherwise) is used. (default: :obj:`None`)
        """
        if self.return_type == 'regression':
            if node_idx is not None and node_idx >= 0:
                loss = torch.cdist(log_logits[node_idx], prediction[node_idx])
            else:
                loss = torch.cdist(log_logits, prediction)
        else:
            if node_idx is not None and node_idx >= 0:
                loss = -log_logits[node_idx, prediction[node_idx]]
            else:
                loss = -log_logits[0, prediction[0]]

        if self.allow_edge_mask:
            m = self.edge_mask.sigmoid()
            edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
            loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
            loss = loss + self.coeffs['edge_ent'] * self._entropy(m).mean()

        m = self.node_feat_mask.sigmoid()
        node_feat_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
        loss = loss + self.coeffs['node_feat_size'] * node_feat_reduce(m)
        loss = loss + self.coeffs['node_feat_ent'] * self._entropy(m).mean()

        return loss

    def _optimize_masks(self, x, edge_index, prediction, node_idx,
                        description, model_kwargs):
        """Run the Adam loop that fits the masks; shared by both explainers.

        Args:
            x (Tensor): The node feature matrix fed to the model.
            edge_index (LongTensor): The edge indices fed to the model.
            prediction (Tensor): Reference prediction passed to the loss.
            node_idx: Index forwarded to ``get_loss`` (:obj:`None` for
                graph-level explanations).
            description (str): Progress-bar caption, used when logging.
            model_kwargs (dict): Extra keyword arguments for the model call.
        """
        if self.allow_edge_mask:
            set_masks(self.model, self.edge_mask, edge_index,
                      apply_sigmoid=True)
            parameters = [self.node_feat_mask, self.edge_mask]
        else:
            parameters = [self.node_feat_mask]
        optimizer = torch.optim.Adam(parameters, lr=self.lr)

        pbar = None
        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.epochs)
            pbar.set_description(description)

        for _ in range(self.epochs):
            optimizer.zero_grad()
            h = x * self.node_feat_mask.sigmoid()
            out = self.model(x=h, edge_index=edge_index, **model_kwargs)
            loss = self.get_loss(out, prediction, node_idx)
            loss.backward()
            optimizer.step()

            if pbar is not None:  # pragma: no cover
                pbar.update(1)

        if pbar is not None:  # pragma: no cover
            pbar.close()

    def explain_graph(self, x, edge_index, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role to explain the prediction made by the GNN for a graph.

        Args:
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        self._clear_masks()

        # all nodes belong to same graph
        batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        # Get the initial prediction.
        prediction = self.get_initial_prediction(x, edge_index, batch=batch,
                                                 **kwargs)

        self._initialize_masks(x, edge_index)
        self.to(x.device)

        try:
            self._optimize_masks(x, edge_index, prediction, None,
                                 'Explain graph', dict(batch=batch, **kwargs))

            node_feat_mask = self.node_feat_mask.detach().sigmoid().squeeze()
            if self.allow_edge_mask:
                edge_mask = self.edge_mask.detach().sigmoid()
            else:
                # Keep the fallback on the same device as the inputs, like
                # the learned mask would be.
                edge_mask = torch.ones(edge_index.size(1),
                                       device=edge_index.device)
        finally:
            # Always detach the masks from the model, even if training fails.
            self._clear_masks()

        return node_feat_mask, edge_mask

    def explain_node(self, node_idx, x, edge_index, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role to explain the prediction made by the GNN for node
        :attr:`node_idx`.

        Masks are learned on the subgraph returned by :meth:`subgraph` and
        scattered back to the size of the full input; per-node feature masks
        and edge-mask entries outside that subgraph are zero.

        Args:
            node_idx (int): The node to explain.
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        self._clear_masks()

        num_nodes = x.size(0)
        num_edges = edge_index.size(1)

        # Only operate on a k-hop subgraph around `node_idx`.
        x, edge_index, mapping, hard_edge_mask, subset, kwargs = \
            self.subgraph(node_idx, x, edge_index, **kwargs)

        # Get the initial prediction.
        prediction = self.get_initial_prediction(x, edge_index, **kwargs)

        self._initialize_masks(x, edge_index)
        self.to(x.device)

        try:
            self._optimize_masks(x, edge_index, prediction, mapping,
                                 f'Explain node {node_idx}', kwargs)

            node_feat_mask = self.node_feat_mask.detach().sigmoid()
            if self.feat_mask_type == 'individual_feature':
                new_mask = x.new_zeros(num_nodes, x.size(-1))
                new_mask[subset] = node_feat_mask
                node_feat_mask = new_mask
            elif self.feat_mask_type == 'scalar':
                new_mask = x.new_zeros(num_nodes, 1)
                new_mask[subset] = node_feat_mask
                node_feat_mask = new_mask
            node_feat_mask = node_feat_mask.squeeze()

            if self.allow_edge_mask:
                edge_mask = self.edge_mask.new_zeros(num_edges)
                edge_mask[hard_edge_mask] = self.edge_mask.detach().sigmoid()
            else:
                # Allocate on the inputs' device so that indexing with
                # `hard_edge_mask` does not mix devices.
                edge_mask = torch.zeros(num_edges, device=edge_index.device)
                edge_mask[hard_edge_mask] = 1
        finally:
            # Always detach the masks from the model, even if training fails.
            self._clear_masks()

        return node_feat_mask, edge_mask

    def __repr__(self):
        return f'{self.__class__.__name__}()'