from math import sqrt
from typing import Optional
import torch
from tqdm import tqdm
from torch_geometric.nn.models.explainer import (
Explainer,
clear_masks,
set_masks,
)
# Small constant added inside `torch.log` calls so that log(0) is never
# evaluated when a mask value saturates at exactly 0 or 1.
EPS = 1e-15
class GNNExplainer(Explainer):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
    structures and small subsets of node features that play a crucial role in
    a GNN's node-predictions.

    .. note::

        For an example of using GNN-Explainer, see `examples/gnn_explainer.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        gnn_explainer.py>`_.

    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        num_hops (int, optional): The number of hops the :obj:`model` is
            aggregating information from.
            If set to :obj:`None`, will automatically try to detect this
            information based on the number of
            :class:`~torch_geometric.nn.conv.message_passing.MessagePassing`
            layers inside :obj:`model`. (default: :obj:`None`)
        return_type (str, optional): Denotes the type of output from
            :obj:`model`. Valid inputs are :obj:`"log_prob"` (the model
            returns the logarithm of probabilities), :obj:`"prob"` (the
            model returns probabilities), :obj:`"raw"` (the model returns raw
            scores) and :obj:`"regression"` (the model returns scalars).
            (default: :obj:`"log_prob"`)
        feat_mask_type (str, optional): Denotes the type of feature mask
            that will be learned. Valid inputs are :obj:`"feature"` (a single
            feature-level mask for all nodes), :obj:`"individual_feature"`
            (individual feature-level masks for each node), and
            :obj:`"scalar"` (scalar mask for each node).
            (default: :obj:`"feature"`)
        allow_edge_mask (bool, optional): If set to :obj:`False`, the edge
            mask will not be optimized. (default: :obj:`True`)
        log (bool, optional): If set to :obj:`False`, will not log any learning
            progress. (default: :obj:`True`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in :attr:`~torch_geometric.nn.models.GNNExplainer.coeffs`.
            Overrides apply to the created instance only; the class-level
            defaults are left untouched.
    """

    # Default weights/reductions of the regularization terms used in `_loss`.
    coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
    }

    def __init__(self, model, epochs: int = 100, lr: float = 0.01,
                 num_hops: Optional[int] = None, return_type: str = 'log_prob',
                 feat_mask_type: str = 'feature', allow_edge_mask: bool = True,
                 log: bool = True, **kwargs):
        super().__init__(model, lr, epochs, num_hops, return_type, log)
        assert feat_mask_type in ['feature', 'individual_feature', 'scalar']
        self.allow_edge_mask = allow_edge_mask
        self.feat_mask_type = feat_mask_type
        # Build a per-instance copy of the defaults before applying the
        # overrides. Updating the class-level dict in place would leak one
        # instance's hyper-parameters into every other instance.
        self.coeffs = {**self.coeffs, **kwargs}

    def _initialize_masks(self, x, edge_index, init="normal"):
        """Create freshly initialized learnable masks for the given graph.

        Sets :obj:`self.node_feat_mask` (shape depends on
        :obj:`self.feat_mask_type`) and, if :obj:`self.allow_edge_mask` is
        set, :obj:`self.edge_mask` with one entry per edge.

        Args:
            x (Tensor): The node feature matrix; must be two-dimensional.
            edge_index (LongTensor): The edge indices; its second dimension
                is taken as the number of edges.
            init (str, optional): Currently unused; kept for interface
                compatibility. (default: :obj:`"normal"`)
        """
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        if self.feat_mask_type == 'individual_feature':
            # One mask value per node and feature.
            self.node_feat_mask = torch.nn.Parameter(torch.randn(N, F) * std)
        elif self.feat_mask_type == 'scalar':
            # One mask value per node, broadcast over its features.
            self.node_feat_mask = torch.nn.Parameter(torch.randn(N, 1) * std)
        else:
            # One mask value per feature, shared by all nodes.
            self.node_feat_mask = torch.nn.Parameter(torch.randn(1, F) * std)

        # The edge mask uses a different standard deviation that shrinks
        # with the number of nodes.
        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        if self.allow_edge_mask:
            self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)

    def _clear_masks(self):
        """Remove the masks from the model and reset the learned masks."""
        clear_masks(self.model)
        # The attribute is `node_feat_mask` (singular); resetting a
        # differently named attribute would leave the stale mask in place.
        self.node_feat_mask = None
        self.edge_mask = None

    def _loss(self, log_logits, prediction, node_idx: Optional[int] = None):
        """Compute the explanation objective for the current masks.

        The objective is the prediction loss plus size and entropy
        regularizers on the sigmoid-activated masks, weighted by
        :obj:`self.coeffs`.

        Args:
            log_logits (Tensor): The model output for the masked input. For
                non-regression return types it is indexed as
                :obj:`[node, class]`.
            prediction (Tensor): The reference prediction to be preserved.
            node_idx (optional): Index of the node to explain. If
                :obj:`None` or negative, the whole output is used for
                regression and row :obj:`0` otherwise.
                (default: :obj:`None`)

        :rtype: :class:`Tensor`
        """
        # NOTE(review): presumably reached through the inherited `get_loss`;
        # `explain_node` passes the `mapping` value returned by
        # `self.subgraph` as the node index - confirm it is a tensor there.
        if self.return_type == 'regression':
            if node_idx is not None and node_idx >= 0:
                loss = torch.cdist(log_logits[node_idx], prediction[node_idx])
            else:
                loss = torch.cdist(log_logits, prediction)
        else:
            if node_idx is not None and node_idx >= 0:
                loss = -log_logits[node_idx, prediction[node_idx]]
            else:
                loss = -log_logits[0, prediction[0]]

        if self.allow_edge_mask:
            m = self.edge_mask.sigmoid()
            # The reduction is looked up by name on `torch`
            # (e.g. `torch.sum` or `torch.mean`).
            edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
            loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
            # Binary entropy of the mask; EPS guards against log(0).
            ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
            loss = loss + self.coeffs['edge_ent'] * ent.mean()

        m = self.node_feat_mask.sigmoid()
        node_feat_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
        loss = loss + self.coeffs['node_feat_size'] * node_feat_reduce(m)
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def explain_graph(self, x, edge_index, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role to explain the prediction made by the GNN for a graph.

        The returned masks are sigmoid-activated. If
        :obj:`allow_edge_mask` is :obj:`False`, the returned edge mask is a
        tensor of ones with one entry per edge.

        Args:
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        self._clear_masks()

        # all nodes belong to same graph
        batch = torch.zeros(x.shape[0], dtype=int, device=x.device)

        # Get the initial prediction (before any mask is attached).
        prediction = self.get_initial_prediction(x, edge_index, batch=batch,
                                                 **kwargs)

        self._initialize_masks(x, edge_index)
        self.to(x.device)

        if self.allow_edge_mask:
            set_masks(self.model, self.edge_mask, edge_index,
                      apply_sigmoid=True)
            parameters = [self.node_feat_mask, self.edge_mask]
        else:
            parameters = [self.node_feat_mask]
        optimizer = torch.optim.Adam(parameters, lr=self.lr)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.epochs)
            pbar.set_description('Explain graph')

        for _ in range(1, self.epochs + 1):
            optimizer.zero_grad()
            h = x * self.node_feat_mask.sigmoid()
            out = self.model(x=h, edge_index=edge_index, batch=batch, **kwargs)
            loss = self.get_loss(out, prediction, None)
            loss.backward()
            optimizer.step()

            if self.log:  # pragma: no cover
                pbar.update(1)

        if self.log:  # pragma: no cover
            pbar.close()

        node_feat_mask = self.node_feat_mask.detach().sigmoid().squeeze()
        if self.allow_edge_mask:
            edge_mask = self.edge_mask.detach().sigmoid()
        else:
            # Create the fallback mask on the input's device so that both
            # branches return a tensor living on the same device.
            edge_mask = torch.ones(edge_index.size(1), device=x.device)

        self._clear_masks()
        return node_feat_mask, edge_mask

    def explain_node(self, node_idx, x, edge_index, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role to explain the prediction made by the GNN for node
        :attr:`node_idx`.

        Optimization runs on the subgraph returned by :obj:`self.subgraph`.
        The returned edge mask has one entry per edge of the full graph and
        is zero for edges outside that subgraph. For the
        :obj:`"individual_feature"` and :obj:`"scalar"` mask types, the node
        feature mask has one row per node of the full graph and is zero for
        nodes outside the subgraph.

        Args:
            node_idx (int): The node to explain.
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        self._clear_masks()

        # Remember the full-graph sizes; `x` and `edge_index` are replaced by
        # their subgraph counterparts below.
        num_nodes = x.size(0)
        num_edges = edge_index.size(1)

        # Only operate on a k-hop subgraph around `node_idx`.
        x, edge_index, mapping, hard_edge_mask, subset, kwargs = \
            self.subgraph(node_idx, x, edge_index, **kwargs)

        # Get the initial prediction (before any mask is attached).
        prediction = self.get_initial_prediction(x, edge_index, **kwargs)

        self._initialize_masks(x, edge_index)
        self.to(x.device)

        if self.allow_edge_mask:
            set_masks(self.model, self.edge_mask, edge_index,
                      apply_sigmoid=True)
            parameters = [self.node_feat_mask, self.edge_mask]
        else:
            parameters = [self.node_feat_mask]
        optimizer = torch.optim.Adam(parameters, lr=self.lr)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.epochs)
            pbar.set_description(f'Explain node {node_idx}')

        for _ in range(1, self.epochs + 1):
            optimizer.zero_grad()
            h = x * self.node_feat_mask.sigmoid()
            out = self.model(x=h, edge_index=edge_index, **kwargs)
            loss = self.get_loss(out, prediction, mapping)
            loss.backward()
            optimizer.step()

            if self.log:  # pragma: no cover
                pbar.update(1)

        if self.log:  # pragma: no cover
            pbar.close()

        node_feat_mask = self.node_feat_mask.detach().sigmoid()
        # Per-node masks are scattered back to full-graph size.
        if self.feat_mask_type == 'individual_feature':
            new_mask = x.new_zeros(num_nodes, x.size(-1))
            new_mask[subset] = node_feat_mask
            node_feat_mask = new_mask
        elif self.feat_mask_type == 'scalar':
            new_mask = x.new_zeros(num_nodes, 1)
            new_mask[subset] = node_feat_mask
            node_feat_mask = new_mask
        node_feat_mask = node_feat_mask.squeeze()

        if self.allow_edge_mask:
            edge_mask = self.edge_mask.new_zeros(num_edges)
            edge_mask[hard_edge_mask] = self.edge_mask.detach().sigmoid()
        else:
            # Allocate on the input's device: a CPU tensor here would differ
            # from the learned-mask branch and would be indexed by a
            # `hard_edge_mask` that presumably lives on the input's device.
            edge_mask = torch.zeros(num_edges, device=x.device)
            edge_mask[hard_edge_mask] = 1

        self._clear_masks()

        return node_feat_mask, edge_mask

    def __repr__(self):
        return f'{self.__class__.__name__}()'