Source code for torch_geometric.contrib.explain.graphmask_explainer

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, sigmoid
from torch.nn import LayerNorm, Linear, Parameter, ReLU, Sequential, init
from tqdm import tqdm

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import (
    MaskType,
    ModelMode,
    ModelReturnType,
    ModelTaskLevel,
)
from torch_geometric.nn import MessagePassing


def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
    # This function is monkey-patched onto every `MessagePassing` module
    # (see `_set_flags` below) so that messages can be scaled and partially
    # replaced by a learned baseline during explanation.
    norm = Sequential(LayerNorm(out.size(-1)).to(out.device), ReLU())
    basis_messages = norm(out)

    if getattr(self, 'message_scale', None) is not None:
        basis_messages = basis_messages * self.message_scale.unsqueeze(-1)

        if self.message_replacement is not None:
            if basis_messages.shape == self.message_replacement.shape:
                basis_messages = (basis_messages +
                                  (1 - self.message_scale).unsqueeze(-1) *
                                  self.message_replacement)
            else:
                basis_messages = (basis_messages +
                                  ((1 - self.message_scale).unsqueeze(-1) *
                                   self.message_replacement.unsqueeze(0)))

    # Cache the latest embeddings and messages so that the explainer can
    # read them back after a forward pass of the model.
    self.latest_messages = basis_messages
    self.latest_source_embeddings = x_j
    self.latest_target_embeddings = x_i

    return basis_messages
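
# `explain_message` implements the core GraphMask substitution: a layer's
# message is interpolated with a learned baseline, out = s * m + (1 - s) * b,
# where `s` is a per-edge gate and `b` the layer's baseline vector. A
# standalone sketch of that interpolation (shapes and values below are
# illustrative assumptions):
#
#   >>> import torch
#   >>> messages = torch.randn(4, 8)             # 4 edges, 8 channels
#   >>> scale = torch.tensor([1., 0., 1., 0.])   # per-edge gates
#   >>> baseline = torch.randn(8)                # learned replacement
#   >>> out = (scale.unsqueeze(-1) * messages +
#   ...        (1 - scale).unsqueeze(-1) * baseline.unsqueeze(0))
#   >>> out.shape
#   torch.Size([4, 8])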


class GraphMaskExplainer(ExplainerAlgorithm):
    r"""The GraphMask-Explainer model from the `"Interpreting Graph Neural
    Networks for NLP With Differentiable Edge Masking"
    <https://arxiv.org/abs/2010.00577>`_ paper for identifying layer-wise
    compact subgraph structures and node features that play a crucial role
    in the predictions made by a GNN.

    .. note::
        For an example of using :class:`GraphMaskExplainer`, see
        `examples/contrib/graphmask_explainer.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        contrib/graphmask_explainer.py>`_.

    Args:
        num_layers (int): The number of layers to use.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        penalty_scaling (int, optional): Scaling value of the penalty term.
            Value must lie between 0 and 10. (default: :obj:`5`)
        lambda_optimizer_lr (float, optional): The learning rate with which
            to optimize the Lagrange multiplier. (default: :obj:`1e-2`)
        init_lambda (float, optional): The initial value of the Lagrange
            multiplier. Value must lie between :obj:`0` and :obj:`1`.
            (default: :obj:`0.55`)
        allowance (float, optional): A float value between :obj:`0` and
            :obj:`1` denoting the tolerance level. (default: :obj:`0.03`)
        layer_type (str, optional): The type of GNN layer being used in the
            GNN model. (default: :obj:`GCN`)
        log (bool, optional): If set to :obj:`False`, will not log any
            learning progress. (default: :obj:`True`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.nn.models.GraphMaskExplainer.coeffs`.
    """
    coeffs = {
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }

    def __init__(
        self,
        num_layers: int,
        epochs: int = 100,
        lr: float = 0.01,
        penalty_scaling: int = 5,
        lambda_optimizer_lr: float = 1e-2,
        init_lambda: float = 0.55,
        allowance: float = 0.03,
        layer_type: str = 'GCN',
        allow_multiple_explanations: bool = False,
        log: bool = True,
        **kwargs,
    ):
        super().__init__()
        assert layer_type in ['GCN', 'GAT', 'FastRGCN']
        assert 0 <= penalty_scaling <= 10
        assert 0 <= init_lambda <= 1
        assert 0 <= allowance <= 1

        self.num_layers = num_layers
        self.init_lambda = init_lambda
        self.lambda_optimizer_lr = lambda_optimizer_lr
        self.penalty_scaling = penalty_scaling
        self.allowance = allowance
        self.layer_type = layer_type
        self.allow_multiple_explanations = allow_multiple_explanations
        self.epochs = epochs
        self.lr = lr
        self.log = log
        self.coeffs.update(kwargs)
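
    # Typical usage goes through the high-level
    # :class:`~torch_geometric.explain.Explainer` wrapper. A minimal sketch
    # (the model, data, and configuration values below are illustrative
    # assumptions, not requirements of this class):
    #
    #   >>> from torch_geometric.explain import Explainer
    #   >>> explainer = Explainer(
    #   ...     model=model,  # e.g., a two-layer GCN (num_layers=2)
    #   ...     algorithm=GraphMaskExplainer(num_layers=2, epochs=5),
    #   ...     explanation_type='model',
    #   ...     node_mask_type='attributes',
    #   ...     edge_mask_type='object',
    #   ...     model_config=dict(
    #   ...         mode='multiclass_classification',
    #   ...         task_level='node',
    #   ...         return_type='log_probs',
    #   ...     ),
    #   ... )
    #   >>> explanation = explainer(data.x, data.edge_index, index=10)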

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        hard_node_mask = None

        if self.model_config.task_level == ModelTaskLevel.node:
            hard_node_mask, hard_edge_mask = self._get_hard_masks(
                model, index, edge_index, num_nodes=x.size(0))

        self.train_explainer(model, x, edge_index, target=target, index=index,
                             **kwargs)
        node_mask = self._post_process_mask(self.node_feat_mask,
                                            hard_node_mask,
                                            apply_sigmoid=True)
        edge_mask = self.explain(model, index=index)
        edge_mask = edge_mask[:edge_index.size(1)]

        return Explanation(node_mask=node_mask, edge_mask=edge_mask)

    def supports(self) -> bool:
        return True

    def hard_concrete(self, input_element, summarize_penalty=True,
                      beta=1 / 3, gamma=-0.2, zeta=1.2, loc_bias=2,
                      min_val=0, max_val=1,
                      training=True) -> Tuple[Tensor, Tensor]:
        r"""Samples from the stretched hard-concrete distribution and
        computes its expected :math:`L_0` penalty."""
        input_element = input_element + loc_bias

        if training:
            u = torch.empty_like(input_element).uniform_(1e-6, 1.0 - 1e-6)

            s = sigmoid(
                (torch.log(u) - torch.log(1 - u) + input_element) / beta)

            penalty = sigmoid(input_element - beta * math.log(-gamma / zeta))
        else:
            s = sigmoid(input_element)
            penalty = torch.zeros_like(input_element)

        if summarize_penalty:
            penalty = penalty.mean()

        s = s * (zeta - gamma) + gamma

        clipped_s = s.clamp(min_val, max_val)

        clip_value = (torch.min(clipped_s) + torch.max(clipped_s)) / 2
        hard_concrete = (clipped_s > clip_value).float()
        # Straight-through estimator: binarize the forward pass while
        # keeping the gradient of the clipped relaxation.
        clipped_s = clipped_s + (hard_concrete - clipped_s).detach()

        return clipped_s, penalty
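
    # A standalone sketch of the stretched hard-concrete relaxation computed
    # above, using the default beta=1/3, gamma=-0.2, zeta=1.2 (the input
    # logits are illustrative):
    #
    #   >>> import torch
    #   >>> logits = torch.tensor([-1.5, 0.0, 1.5]) + 2   # `loc_bias` shift
    #   >>> u = torch.empty_like(logits).uniform_(1e-6, 1 - 1e-6)
    #   >>> s = torch.sigmoid((u.log() - (1 - u).log() + logits) / (1 / 3))
    #   >>> s = s * (1.2 + 0.2) - 0.2   # stretch to the interval (-0.2, 1.2)
    #   >>> s.clamp(0, 1)               # ... then clip, yielding exact 0s/1s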

    def set_masks(self, i_dim, j_dim, h_dim, x, device):
        r"""Sets the node masks and edge masks."""
        if self.layer_type == 'GCN' or self.layer_type == 'GAT':
            i_dim = j_dim

        (num_nodes, num_feat), std = x.size(), 0.1
        self.feat_mask_type = self.explainer_config.node_mask_type

        if self.feat_mask_type == MaskType.attributes:
            self.node_feat_mask = torch.nn.Parameter(
                torch.randn(num_nodes, num_feat, device=device) * std)
        elif self.feat_mask_type == MaskType.object:
            self.node_feat_mask = torch.nn.Parameter(
                torch.randn(num_nodes, 1, device=device) * std)
        else:
            self.node_feat_mask = torch.nn.Parameter(
                torch.randn(1, num_feat, device=device) * std)

        baselines, self.gates, full_biases = [], torch.nn.ModuleList(), []

        for v_dim, m_dim, h_dim in zip(i_dim, j_dim, h_dim):
            self.transform, self.layer_norm = [], []
            input_dims = [v_dim, m_dim, v_dim]
            for input_dim in input_dims:
                self.transform.append(
                    Linear(input_dim, h_dim, bias=False).to(device))
                self.layer_norm.append(LayerNorm(h_dim).to(device))

            self.transforms = torch.nn.ModuleList(self.transform)
            self.layer_norms = torch.nn.ModuleList(self.layer_norm)

            # Learned per-channel bias of size `h_dim` for this layer's gate.
            self.full_bias = Parameter(
                torch.empty(h_dim, dtype=torch.float, device=device))
            full_biases.append(self.full_bias)

            self.reset_parameters(input_dims, h_dim)

            self.non_linear = ReLU()
            self.output_layer = Linear(h_dim, 1).to(device)

            gate = [
                self.transforms, self.layer_norms, self.non_linear,
                self.output_layer
            ]
            self.gates.extend(gate)

            # Learned baseline vector of size `m_dim` used to replace
            # masked-out messages.
            baseline = torch.empty(m_dim, dtype=torch.float, device=device)
            stdv = 1. / math.sqrt(m_dim)
            baseline.uniform_(-stdv, stdv)
            baseline = torch.nn.Parameter(baseline)
            baselines.append(baseline)

        self.full_biases = torch.nn.ParameterList(full_biases)
        self.baselines = torch.nn.ParameterList(baselines)

        # All explainer parameters start frozen; `enable_layer` unfreezes
        # them one layer at a time during training.
        for parameter in self.parameters():
            parameter.requires_grad = False

    def enable_layer(self, layer: int):
        r"""Enables the gate parameters of the given layer for training."""
        for d in range(layer * 4, (layer * 4) + 4):
            for parameter in self.gates[d].parameters():
                parameter.requires_grad = True
        self.full_biases[layer].requires_grad = True
        self.baselines[layer].requires_grad = True

    def reset_parameters(self, input_dims, h_dim):
        r"""Resets all learnable parameters of the module."""
        fan_in = sum(input_dims)
        std = math.sqrt(2.0 / float(fan_in + h_dim))
        a = math.sqrt(3.0) * std

        for transform in self.transforms:
            init._no_grad_uniform_(transform.weight, -a, a)

        init.zeros_(self.full_bias)

        for layer_norm in self.layer_norms:
            layer_norm.reset_parameters()

    def _loss_regression(self, y_hat: Tensor, y: Tensor) -> Tensor:
        assert self.model_config.return_type == ModelReturnType.raw
        return F.mse_loss(y_hat, y)

    def _loss_binary_classification(self, y_hat: Tensor, y: Tensor) -> Tensor:
        if self.model_config.return_type == ModelReturnType.raw:
            loss_fn = F.binary_cross_entropy_with_logits
        elif self.model_config.return_type == ModelReturnType.probs:
            loss_fn = F.binary_cross_entropy
        else:
            assert False

        return loss_fn(y_hat.view_as(y), y.float())

    def _loss_multiclass_classification(
        self,
        y_hat: Tensor,
        y: Tensor,
    ) -> Tensor:
        if self.model_config.return_type == ModelReturnType.raw:
            loss_fn = F.cross_entropy
        elif self.model_config.return_type == ModelReturnType.probs:
            loss_fn = F.nll_loss
            y_hat = y_hat.log()
        elif self.model_config.return_type == ModelReturnType.log_probs:
            loss_fn = F.nll_loss
        else:
            assert False

        return loss_fn(y_hat, y)

    def _loss(self, y_hat: Tensor, y: Tensor, penalty) -> Tensor:
        if self.model_config.mode == ModelMode.binary_classification:
            loss = self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            loss = self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        else:
            assert False

        g = torch.relu(loss - self.allowance).mean()
        f = penalty * self.penalty_scaling

        loss = f + F.softplus(self.lambda_op) * g

        m = self.node_feat_mask.sigmoid()
        node_feat_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
        loss = loss + self.coeffs['node_feat_size'] * node_feat_reduce(m)
        ent = -m * torch.log(m + self.coeffs['EPS']) - (
            1 - m) * torch.log(1 - m + self.coeffs['EPS'])
        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def freeze_model(self, module: torch.nn.Module):
        r"""Freezes the parameters of the original GNN model."""
        for param in module.parameters():
            param.requires_grad = False

    def _set_flags(self, model: torch.nn.Module):
        # Monkey-patch `explain_message` onto every message-passing layer.
        for module in model.modules():
            if isinstance(module, MessagePassing):
                module.explain_message = explain_message.__get__(
                    module, MessagePassing)
                module.explain = True

    def _inject_messages(self, model: torch.nn.Module, message_scale,
                         message_replacement, set=False):
        i = 0
        for module in model.modules():
            if isinstance(module, MessagePassing):
                if not set:
                    module.message_scale = message_scale[i]
                    module.message_replacement = message_replacement[i]
                    i = i + 1
                else:
                    module.message_scale = None
                    module.message_replacement = None
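
    # `train_explainer` below minimizes a penalized objective of the form
    #
    #     penalty_scaling * L0_penalty + softplus(lambda) * g,
    #     g = relu(task_loss - allowance).mean(),
    #
    # where `lambda` is a Lagrange multiplier trained by gradient *ascent*
    # (its gradient is negated before the RMSprop step). A one-step sketch
    # of the objective with made-up values:
    #
    #   >>> import torch
    #   >>> import torch.nn.functional as F
    #   >>> task_loss, penalty = torch.tensor(0.9), torch.tensor(0.4)
    #   >>> lambda_op = torch.tensor(0.55, requires_grad=True)
    #   >>> g = torch.relu(task_loss - 0.03).mean()          # allowance=0.03
    #   >>> loss = 5 * penalty + F.softplus(lambda_op) * g   # scaling=5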

    def train_explainer(self, model: torch.nn.Module, x: Tensor,
                        edge_index: Tensor, *, target: Tensor,
                        index: Optional[Union[int, Tensor]] = None, **kwargs):
        r"""Trains the underlying explainer model."""
        if (not isinstance(index, Tensor) and not isinstance(index, int)
                and index is not None):
            raise ValueError("The 'index' parameter can only be a 'Tensor', "
                             "an 'int', or 'None'")

        self.freeze_model(model)
        self._set_flags(model)

        input_dims, output_dims = [], []
        for module in model.modules():
            if isinstance(module, MessagePassing):
                input_dims.append(module.in_channels)
                output_dims.append(module.out_channels)

        self.set_masks(input_dims, output_dims, output_dims, x, x.device)

        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        # Train the gates one layer at a time, from the last message-passing
        # layer down to the first.
        for layer in reversed(list(range(self.num_layers))):
            if self.log:
                pbar = tqdm(total=self.epochs)
                if self.model_config.task_level == ModelTaskLevel.node:
                    pbar.set_description(
                        f'Train explainer for node(s) {index} with layer '
                        f'{layer}')
                elif self.model_config.task_level == ModelTaskLevel.edge:
                    pbar.set_description(
                        f'Train explainer for edge-level task with layer '
                        f'{layer}')
                else:
                    pbar.set_description(
                        f'Train explainer for graph {index} with layer '
                        f'{layer}')
            self.enable_layer(layer)
            for epoch in range(self.epochs):
                with torch.no_grad():
                    model(x, edge_index, **kwargs)
                gates, total_penalty = [], 0
                latest_source_embeddings, latest_messages = [], []
                latest_target_embeddings = []
                for module in model.modules():
                    if isinstance(module, MessagePassing):
                        latest_source_embeddings.append(
                            module.latest_source_embeddings)
                        latest_messages.append(module.latest_messages)
                        latest_target_embeddings.append(
                            module.latest_target_embeddings)
                gate_input = [
                    latest_source_embeddings, latest_messages,
                    latest_target_embeddings
                ]
                for i in range(self.num_layers):
                    output = self.full_biases[i]
                    for j in range(len(gate_input)):
                        partial = self.gates[i * 4][j](gate_input[j][i])
                        result = self.gates[(i * 4) + 1][j](partial)
                        output = output + result
                    relu_output = self.gates[(i * 4) + 2](output /
                                                          len(gate_input))
                    sampling_weights = self.gates[(i * 4) +
                                                  3](relu_output).squeeze(
                                                      dim=-1)
                    sampling_weights, penalty = self.hard_concrete(
                        sampling_weights)
                    gates.append(sampling_weights)
                    total_penalty += penalty

                self._inject_messages(model, gates, self.baselines)

                self.lambda_op = torch.tensor(self.init_lambda,
                                              requires_grad=True)
                optimizer_lambda = torch.optim.RMSprop(
                    [self.lambda_op], lr=self.lambda_optimizer_lr,
                    centered=True)

                optimizer.zero_grad()
                optimizer_lambda.zero_grad()

                h = x * self.node_feat_mask.sigmoid()
                y_hat, y = model(x=h, edge_index=edge_index, **kwargs), target

                if (self.model_config.task_level == ModelTaskLevel.node
                        or self.model_config.task_level
                        == ModelTaskLevel.edge):
                    if index is not None:
                        y_hat, y = y_hat[index], y[index]

                self._inject_messages(model, gates, self.baselines, True)

                loss = self._loss(y_hat, y, total_penalty)

                loss.backward()
                optimizer.step()
                # Gradient ascent on the Lagrange multiplier.
                self.lambda_op.grad *= -1
                optimizer_lambda.step()

                if self.lambda_op.item() < -2:
                    self.lambda_op.data = torch.full_like(
                        self.lambda_op.data, -2)
                elif self.lambda_op.item() > 30:
                    self.lambda_op.data = torch.full_like(
                        self.lambda_op.data, 30)

                if self.log:
                    pbar.update(1)

            if self.log:
                pbar.close()
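
    # Per layer, the gate maps (source embedding, message, target embedding)
    # triples to one logit per edge: three Linear/LayerNorm branches are
    # summed with a bias, averaged, then passed through a ReLU and a scalar
    # output layer. A standalone sketch of that wiring, with assumed sizes:
    #
    #   >>> import torch
    #   >>> from torch.nn import LayerNorm, Linear, ReLU
    #   >>> E, D, H = 8, 16, 32   # edges, embedding dim, hidden dim
    #   >>> src, msg, tgt = (torch.randn(E, D) for _ in range(3))
    #   >>> transforms = [Linear(D, H, bias=False) for _ in range(3)]
    #   >>> norms = [LayerNorm(H) for _ in range(3)]
    #   >>> out = sum(n(t(z)) for t, n, z in
    #   ...           zip(transforms, norms, (src, msg, tgt)))
    #   >>> logits = Linear(H, 1)(ReLU()(out / 3)).squeeze(-1)  # shape: [E]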

    def explain(self, model: torch.nn.Module, *,
                index: Optional[Union[int, Tensor]] = None) -> Tensor:
        r"""Generates the layer-wise edge mask for the trained explainer."""
        if (not isinstance(index, Tensor) and not isinstance(index, int)
                and index is not None):
            raise ValueError("The 'index' parameter can only be a 'Tensor', "
                             "an 'int', or 'None'")

        self.freeze_model(model)
        self._set_flags(model)

        with torch.no_grad():
            latest_source_embeddings, latest_messages = [], []
            latest_target_embeddings = []
            for module in model.modules():
                if isinstance(module, MessagePassing):
                    latest_source_embeddings.append(
                        module.latest_source_embeddings)
                    latest_messages.append(module.latest_messages)
                    latest_target_embeddings.append(
                        module.latest_target_embeddings)
            gate_input = [
                latest_source_embeddings, latest_messages,
                latest_target_embeddings
            ]
            if self.log:
                pbar = tqdm(total=self.num_layers)
            for i in range(self.num_layers):
                if self.log:
                    pbar.set_description('Explain')
                output = self.full_biases[i]
                for j in range(len(gate_input)):
                    partial = self.gates[i * 4][j](gate_input[j][i])
                    result = self.gates[(i * 4) + 1][j](partial)
                    output = output + result
                relu_output = self.gates[(i * 4) + 2](output /
                                                      len(gate_input))
                sampling_weights = self.gates[(i * 4) +
                                              3](relu_output).squeeze(dim=-1)
                sampling_weights, _ = self.hard_concrete(
                    sampling_weights, training=False)
                if i == 0:
                    edge_weight = sampling_weights
                else:
                    # For GAT layers, pad the gates so that per-layer gate
                    # vectors of different sizes can be concatenated.
                    if (edge_weight.size(-1) != sampling_weights.size(-1)
                            and self.layer_type == 'GAT'):
                        sampling_weights = F.pad(
                            input=sampling_weights,
                            pad=(0, edge_weight.size(-1) -
                                 sampling_weights.size(-1), 0, 0),
                            mode='constant', value=0)
                    edge_weight = torch.cat((edge_weight, sampling_weights),
                                            0)
                if self.log:
                    pbar.update(1)
        if self.log:
            pbar.close()

        edge_mask = edge_weight.view(-1,
                                     edge_weight.size(0) // self.num_layers)
        edge_mask = torch.mean(edge_mask, 0)

        return edge_mask

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
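
# The final edge mask produced by `explain` averages the per-layer gates.
# For example, with `num_layers=2` over the same four edges (the gate values
# are illustrative):
#
#   >>> import torch
#   >>> edge_weight = torch.tensor([1., 0., 1., 1.,   # gates of layer 0
#   ...                             0., 0., 1., 1.])  # gates of layer 1
#   >>> edge_weight.view(-1, edge_weight.size(0) // 2).mean(dim=0)
#   tensor([0.5000, 0.0000, 1.0000, 1.0000])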