# Source code for torch_geometric.explain.algorithm.graphmask_explainer

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LayerNorm, Linear, Parameter, ReLU
from tqdm import tqdm

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
from torch_geometric.nn import MessagePassing

def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
    r"""Message hook that :class:`GraphMaskExplainer` binds onto every
    :class:`MessagePassing` module of the explained model.

    The raw messages are layer-normalized over their last dimension and passed
    through a ReLU. If a gate vector ``self.message_scale`` has been injected,
    each message is scaled by its gate value. If a replacement
    ``self.message_replacement`` is also present, the remaining
    ``1 - message_scale`` share is filled with that replacement, so a gated-out
    message is substituted rather than simply dropped.

    The final messages, together with the source (``x_j``) and target
    (``x_i``) embeddings, are cached on ``self`` as ``latest_messages``,
    ``latest_source_embeddings`` and ``latest_target_embeddings``.

    Args:
        self: The module this function is bound to.
        out (torch.Tensor): The messages; normalization runs over the last
            dimension.
        x_i (torch.Tensor): Target node embeddings of each message.
        x_j (torch.Tensor): Source node embeddings of each message.

    Returns:
        torch.Tensor: The (possibly gated) messages.
    """
    basis_messages = F.layer_norm(out, (out.size(-1), )).relu()

    message_scale = getattr(self, 'message_scale', None)
    if message_scale is not None:
        basis_messages = basis_messages * message_scale.unsqueeze(-1)

        message_replacement = getattr(self, 'message_replacement', None)
        if message_replacement is not None:
            if basis_messages.shape == message_replacement.shape:
                replacement = message_replacement
            else:
                # Not one replacement per message: add a leading axis so it
                # broadcasts over all messages (presumably a per-channel
                # baseline vector).
                replacement = message_replacement.unsqueeze(0)
            basis_messages = (basis_messages +
                              (1 - message_scale).unsqueeze(-1) * replacement)

    self.latest_messages = basis_messages
    self.latest_source_embeddings = x_j
    self.latest_target_embeddings = x_i

    return basis_messages

class GraphMaskExplainer(ExplainerAlgorithm):
    r"""The GraphMask-Explainer model from the `"Interpreting Graph Neural
    Networks for NLP With Differentiable Edge Masking"
    <https://arxiv.org/abs/2010.00577>`_ paper for identifying layer-wise
    compact subgraph structures and node features that play a crucial role in
    the predictions made by a GNN.

    .. note::
        For an example of using :class:`GraphMaskExplainer`, see
        `examples/explain/graphmask_explainer.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/graphmask_explainer.py>`_.

        A working real-time example of :class:`GraphMaskExplainer` is also
        available in the form of a deployed web app.

    Args:
        num_layers (int): The number of layers to use, *i.e.* the number of
            :class:`~torch_geometric.nn.MessagePassing` layers whose messages
            get gated.
        epochs (int, optional): The number of epochs to train per layer.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        penalty_scaling (int, optional): Scaling value of penalty term. Value
            must lie between 0 and 10. (default: :obj:`5`)
        lambda_optimizer_lr (float, optional): The learning rate to optimize
            the Lagrange multiplier. (default: :obj:`1e-2`)
        init_lambda (float, optional): The initial Lagrange multiplier. Value
            must lie between :obj:`0` and :obj:`1`. (default: :obj:`0.55`)
        allowance (float, optional): A float value between :obj:`0` and
            :obj:`1` denoting the tolerance level. (default: :obj:`0.03`)
        allow_multiple_explanations (bool, optional): Stored on the instance;
            not read by the algorithm itself. (default: :obj:`False`)
        log (bool, optional): If set to :obj:`False`, will not log any
            learning progress. (default: :obj:`True`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.nn.models.GraphMaskExplainer.coeffs`.
    """

    coeffs = {
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }

    def __init__(
        self,
        num_layers: int,
        epochs: int = 100,
        lr: float = 0.01,
        penalty_scaling: int = 5,
        lambda_optimizer_lr: float = 1e-2,
        init_lambda: float = 0.55,
        allowance: float = 0.03,
        allow_multiple_explanations: bool = False,
        log: bool = True,
        **kwargs,
    ):
        super().__init__()
        assert 0 <= penalty_scaling <= 10
        assert 0 <= init_lambda <= 1
        assert 0 <= allowance <= 1

        self.num_layers = num_layers
        self.init_lambda = init_lambda
        self.lambda_optimizer_lr = lambda_optimizer_lr
        self.penalty_scaling = penalty_scaling
        self.allowance = allowance
        self.allow_multiple_explanations = allow_multiple_explanations
        self.epochs = epochs
        self.lr = lr
        self.log = log
        # `coeffs` is a class attribute: build a per-instance copy instead of
        # updating it in place, which would leak these overrides into every
        # other (and every future) instance.
        self.coeffs = {**self.coeffs, **kwargs}

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        r"""Trains the gates on ``model`` and returns the node and edge masks.

        NOTE(review): nothing here reverts the side effects on ``model``.
        Its parameters stay frozen, and its message-passing layers keep
        ``explain=True`` and the patched ``explain_message`` after this
        returns. Confirm whether callers expect that.
        """
        hard_node_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            hard_node_mask, _ = self._get_hard_masks(
                model, index, edge_index, num_nodes=x.size(0))

        self._train_explainer(model, x, edge_index, target=target,
                              index=index, **kwargs)

        node_mask = self._post_process_mask(self.node_feat_mask,
                                            hard_node_mask, apply_sigmoid=True)
        edge_mask = self._explain(model, index=index)
        # Layers may emit more messages than there are input edges (e.g. when
        # they add self-loops); presumably those extras come last, so only the
        # first `edge_index.size(1)` entries are kept.
        edge_mask = edge_mask[:edge_index.size(1)]

        return Explanation(node_mask=node_mask, edge_mask=edge_mask)

    def supports(self) -> bool:
        return True

    def _hard_concrete(
        self,
        input_element: Tensor,
        summarize_penalty: bool = True,
        beta: float = 1 / 3,
        gamma: float = -0.2,
        zeta: float = 1.2,
        loc_bias: int = 2,
        min_val: int = 0,
        max_val: int = 1,
        training: bool = True,
    ) -> Tuple[Tensor, Tensor]:
        r"""Helps to set the edge mask while sampling its values from the
        hard-concrete distribution.

        The values are sampled (``training=True``) or taken deterministically
        as ``sigmoid(input)``. They are then stretched to ``(gamma, zeta)``
        and clamped to ``[min_val, max_val]``.

        The forward value is binarized by thresholding at the midpoint of the
        smallest and largest clamped value. The gradient is that of the
        continuous clamped sample (straight-through).

        Returns:
            (torch.Tensor, torch.Tensor): The gate values, and the
            expected-L0 penalty. The penalty is averaged to a scalar when
            ``summarize_penalty`` is set, and is all zeros when not training.
        """
        input_element = input_element + loc_bias

        if training:
            u = torch.empty_like(input_element).uniform_(1e-6, 1.0 - 1e-6)
            s = torch.sigmoid(
                (torch.log(u) - torch.log(1 - u) + input_element) / beta)
            penalty = torch.sigmoid(input_element -
                                    beta * math.log(-gamma / zeta))
        else:
            s = torch.sigmoid(input_element)
            penalty = torch.zeros_like(input_element)

        if summarize_penalty:
            penalty = penalty.mean()

        s = s * (zeta - gamma) + gamma
        clipped_s = s.clamp(min_val, max_val)

        clip_value = (torch.min(clipped_s) + torch.max(clipped_s)) / 2
        hard_concrete = (clipped_s > clip_value).float()
        # Straight-through estimator: hard values forward, soft gradients.
        clipped_s = clipped_s + (hard_concrete - clipped_s).detach()

        return clipped_s, penalty

    def _set_masks(
        self,
        i_dim: List[int],
        j_dim: List[int],
        h_dim: List[int],
        x: Tensor,
    ):
        r"""Sets the node masks and edge masks.

        Builds one gate network per layer. Each gate has three bias-free
        linear transforms, each followed by a :class:`LayerNorm`. They are
        applied to the source embeddings (``i_dim``), the messages
        (``j_dim``) and the target embeddings (``i_dim``). The gate also has
        a shared bias vector, a ReLU and a linear output layer with one
        logit.

        Each layer additionally gets a learned baseline vector of size
        ``j_dim``. This baseline replaces gated-out messages.

        All gate parameters start frozen; :meth:`_enable_layer` unfreezes one
        layer at a time. The node feature mask stays trainable.
        """
        (num_nodes, num_feat), std, device = x.size(), 0.1, x.device
        self.feat_mask_type = self.explainer_config.node_mask_type

        if self.feat_mask_type == MaskType.attributes:
            self.node_feat_mask = Parameter(
                torch.randn(num_nodes, num_feat, device=device) * std)
        elif self.feat_mask_type == MaskType.object:
            self.node_feat_mask = Parameter(
                torch.randn(num_nodes, 1, device=device) * std)
        else:
            self.node_feat_mask = Parameter(
                torch.randn(1, num_feat, device=device) * std)

        baselines, self.gates, full_biases = [], torch.nn.ModuleList(), []

        for v_dim, m_dim, hidden_dim in zip(i_dim, j_dim, h_dim):
            input_dims = [v_dim, m_dim, v_dim]
            self.transform = [
                Linear(input_dim, hidden_dim, bias=False).to(device)
                for input_dim in input_dims
            ]
            self.layer_norm = [
                LayerNorm(hidden_dim).to(device) for _ in input_dims
            ]
            self.transforms = torch.nn.ModuleList(self.transform)
            self.layer_norms = torch.nn.ModuleList(self.layer_norm)

            # A bias *vector* of size `hidden_dim`. `torch.tensor(hidden_dim)`
            # would build a 0-d scalar holding the value `hidden_dim`.
            self.full_bias = Parameter(torch.empty(hidden_dim, device=device))
            full_biases.append(self.full_bias)

            self.reset_parameters(input_dims, hidden_dim)

            self.non_linear = ReLU()
            self.output_layer = Linear(hidden_dim, 1).to(device)

            self.gates.extend([
                self.transforms, self.layer_norms, self.non_linear,
                self.output_layer
            ])

            # Learned per-channel replacement for gated-out messages.
            stdv = 1. / math.sqrt(m_dim)
            baseline = torch.empty(m_dim, device=device).uniform_(-stdv, stdv)
            baselines.append(Parameter(baseline))

        self.full_biases = torch.nn.ParameterList(full_biases)
        self.baselines = torch.nn.ParameterList(baselines)

        for parameter in self.parameters():
            parameter.requires_grad = False
        # The node feature mask is regularized in `_loss` and applied to the
        # input in every training step, so it must remain trainable; the gates
        # are enabled layer by layer via `_enable_layer`.
        self.node_feat_mask.requires_grad = True

    def _enable_layer(self, layer: int):
        r"""Enables the input layer's edge mask."""
        for d in range(layer * 4, (layer * 4) + 4):
            for parameter in self.gates[d].parameters():
                parameter.requires_grad = True
        self.full_biases[layer].requires_grad = True
        self.baselines[layer].requires_grad = True

    def reset_parameters(self, input_dims: List[int], h_dim: int):
        r"""Resets all learnable parameters of the module.

        The gate transforms are drawn from a Glorot-style uniform
        distribution over the summed fan-in. The bias is zeroed, and the
        layer norms are reset.
        """
        fan_in = sum(input_dims)
        std = math.sqrt(2.0 / float(fan_in + h_dim))
        a = math.sqrt(3.0) * std
        for transform in self.transforms:
            torch.nn.init.uniform_(transform.weight, -a, a)

        torch.nn.init.zeros_(self.full_bias)

        for layer_norm in self.layer_norms:
            layer_norm.reset_parameters()

    def _loss(self, y_hat: Tensor, y: Tensor,
              penalty: Union[float, Tensor]) -> Tensor:
        r"""Computes the constrained GraphMask objective.

        The objective is ``penalty * penalty_scaling``, plus
        ``softplus(lambda)`` times the prediction loss in excess of
        ``allowance``. Size and entropy regularizers on the node feature mask
        are added on top.
        """
        if self.model_config.mode == ModelMode.binary_classification:
            loss = self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            loss = self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        else:
            raise ValueError(
                f"Unsupported model mode '{self.model_config.mode}'")

        g = torch.relu(loss - self.allowance).mean()
        f = penalty * self.penalty_scaling

        loss = f + F.softplus(self.lambda_op) * g

        eps = self.coeffs['EPS']
        m = self.node_feat_mask.sigmoid()
        node_feat_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
        loss = loss + self.coeffs['node_feat_size'] * node_feat_reduce(m)
        ent = -m * torch.log(m + eps) - (1 - m) * torch.log(1 - m + eps)
        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def _freeze_model(self, module: torch.nn.Module):
        r"""Freezes the parameters of the original GNN model by disabling
        their gradients.
        """
        for param in module.parameters():
            param.requires_grad = False

    def _set_flags(self, model: torch.nn.Module):
        r"""Initializes the underlying explainer model's parameters for each
        layer of the original GNN model.

        Every :class:`MessagePassing` module gets :func:`explain_message`
        bound as its ``explain_message`` and ``explain`` switched on.
        """
        for module in model.modules():
            if isinstance(module, MessagePassing):
                module.explain_message = explain_message.__get__(
                    module, MessagePassing)
                module.explain = True

    @staticmethod
    def _message_passing_layers(
            model: torch.nn.Module) -> List[MessagePassing]:
        r"""Returns the :class:`MessagePassing` modules of ``model`` in
        :meth:`~torch.nn.Module.modules` order.
        """
        return [m for m in model.modules() if isinstance(m, MessagePassing)]

    def _inject_messages(
        self,
        model: torch.nn.Module,
        message_scale: List[Tensor],
        message_replacement: torch.nn.ParameterList,
        set: bool = False,  # name kept for backward compatibility
    ):
        r"""Injects the computed messages into each layer of the original GNN
        model.

        The ``i``-th message-passing layer receives ``message_scale[i]`` and
        ``message_replacement[i]``. If ``set`` is :obj:`True`, both are
        cleared to :obj:`None` instead.
        """
        for i, module in enumerate(self._message_passing_layers(model)):
            if set:
                module.message_scale = None
                module.message_replacement = None
            else:
                module.message_scale = message_scale[i]
                module.message_replacement = message_replacement[i]

    @staticmethod
    def _check_index(index: Optional[Union[int, Tensor]]):
        r"""Raises :class:`ValueError` unless ``index`` is a tensor, an int or
        :obj:`None`.
        """
        if index is not None and not isinstance(index, (Tensor, int)):
            raise ValueError("'index' parameter can only be a 'Tensor', "
                             "'integer' or set to 'None' instead.")

    def _collect_gate_inputs(
            self, model: torch.nn.Module) -> List[List[Tensor]]:
        r"""Returns ``[sources, messages, targets]``. Each entry is a per-layer
        list of the tensors cached by :func:`explain_message` during the most
        recent forward pass of ``model``.
        """
        layers = self._message_passing_layers(model)
        return [
            [layer.latest_source_embeddings for layer in layers],
            [layer.latest_messages for layer in layers],
            [layer.latest_target_embeddings for layer in layers],
        ]

    def _gate_logits(self, layer: int,
                     gate_input: List[List[Tensor]]) -> Tensor:
        r"""Runs the gate network of ``layer`` and returns one logit per
        message.
        """
        offset = 4 * layer
        transforms = self.gates[offset]
        layer_norms = self.gates[offset + 1]
        non_linear = self.gates[offset + 2]
        output_layer = self.gates[offset + 3]

        output = self.full_biases[layer]
        for j, inputs in enumerate(gate_input):
            output = output + layer_norms[j](transforms[j](inputs[layer]))

        hidden = non_linear(output / len(gate_input))
        return output_layer(hidden).squeeze(dim=-1)

    def _train_explainer(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        r"""Trains the underlying explainer model.

        Args:
            model (torch.nn.Module): The model to explain.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The input edge indices.
            target (torch.Tensor): The target of the model.
            index (int or torch.Tensor, optional): The index of the model
                output to explain. Needs to be a single index.
                (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to
                :obj:`model`.
        """
        self._check_index(index)

        self._freeze_model(model)
        self._set_flags(model)

        # Probe forward pass: size every gate from the tensors each layer
        # actually produces. This happens once, *before* the optimizer is
        # built, so the parameters it optimizes are never replaced later.
        with torch.no_grad():
            model(x, edge_index, **kwargs)
        layers = self._message_passing_layers(model)
        self._set_masks(
            [layer.latest_source_embeddings.size(-1) for layer in layers],
            [layer.latest_messages.size(-1) for layer in layers],
            [layer.out_channels for layer in layers],
            x,
        )

        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        # The Lagrange multiplier is a dual variable that must persist across
        # steps; re-creating it every epoch would discard every update.
        self.lambda_op = torch.tensor(self.init_lambda, requires_grad=True)
        optimizer_lambda = torch.optim.RMSprop(
            [self.lambda_op], lr=self.lambda_optimizer_lr, centered=True)

        for layer in reversed(range(self.num_layers)):
            pbar = tqdm(total=self.epochs) if self.log else None
            if pbar is not None:
                if self.model_config.task_level == ModelTaskLevel.node:
                    pbar.set_description(
                        f'Train explainer for node(s) {index} with layer '
                        f'{layer}')
                elif self.model_config.task_level == ModelTaskLevel.edge:
                    pbar.set_description(
                        f"Train explainer for edge-level task with layer "
                        f"{layer}")
                else:
                    pbar.set_description(
                        f'Train explainer for graph {index} with layer '
                        f'{layer}')

            self._enable_layer(layer)

            for _ in range(self.epochs):
                # Un-gated pass to capture fresh embeddings and messages.
                with torch.no_grad():
                    model(x, edge_index, **kwargs)
                gate_input = self._collect_gate_inputs(model)

                gates, total_penalty = [], 0
                for i in range(self.num_layers):
                    sampling_weights, penalty = self._hard_concrete(
                        self._gate_logits(i, gate_input))
                    gates.append(sampling_weights)
                    total_penalty += penalty

                self._inject_messages(model, gates, self.baselines)

                optimizer.zero_grad()
                optimizer_lambda.zero_grad()

                h = x * self.node_feat_mask.sigmoid()
                y_hat, y = model(x=h, edge_index=edge_index, **kwargs), target

                if (self.model_config.task_level
                        in (ModelTaskLevel.node, ModelTaskLevel.edge)
                        and index is not None):
                    y_hat, y = y_hat[index], y[index]

                self._inject_messages(model, gates, self.baselines, True)

                loss = self._loss(y_hat, y, total_penalty)
                loss.backward()
                optimizer.step()

                # Gradient *ascent* on the multiplier (the dual problem).
                self.lambda_op.grad *= -1
                optimizer_lambda.step()
                with torch.no_grad():
                    self.lambda_op.clamp_(-2, 30)

                if pbar is not None:
                    pbar.update(1)

            if pbar is not None:
                pbar.close()

    def _explain(
        self,
        model: torch.nn.Module,
        *,
        index: Optional[Union[int, Tensor]] = None,
    ) -> Tensor:
        r"""Generates explanations for the original GNN model.

        The gates run deterministically (``training=False``) on the
        embeddings and messages cached by the most recent forward pass of
        ``model``. The per-layer edge gates are then averaged over layers.

        Args:
            model (torch.nn.Module): The model to explain.
            index (int or torch.Tensor, optional): The index of the model
                output to explain. Needs to be a single index.
                (default: :obj:`None`).
        """
        self._check_index(index)

        self._freeze_model(model)
        self._set_flags(model)

        with torch.no_grad():
            gate_input = self._collect_gate_inputs(model)

            pbar = tqdm(total=self.num_layers) if self.log else None
            layer_weights = []
            for i in range(self.num_layers):
                if pbar is not None:
                    pbar.set_description("Explain")
                sampling_weights, _ = self._hard_concrete(
                    self._gate_logits(i, gate_input), training=False)
                layer_weights.append(sampling_weights)
                if pbar is not None:
                    pbar.update(1)
            if pbar is not None:
                pbar.close()

        # One row per layer (fails loudly if layers emit different numbers of
        # messages instead of silently mixing them); average over layers.
        edge_mask = torch.stack(layer_weights, dim=0)
        return edge_mask.mean(dim=0)