from inspect import signature
from math import sqrt
from typing import Optional
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import get_num_hops, k_hop_subgraph, to_networkx
def set_masks(model: torch.nn.Module, mask: Tensor, edge_index: Tensor,
              apply_sigmoid: bool = True):
    """Attach an edge mask to every :class:`MessagePassing` layer of `model`.

    Self-loops are excluded via the loop mask so that only genuine
    (non-loop) messages get scaled by `mask`.
    """
    # `True` for edges whose endpoints differ, i.e. everything but self-loops.
    non_loop = edge_index[0] != edge_index[1]
    for layer in model.modules():
        if not isinstance(layer, MessagePassing):
            continue
        layer._explain = True
        layer._edge_mask = mask
        layer._loop_mask = non_loop
        layer._apply_sigmoid = apply_sigmoid
def clear_masks(model: torch.nn.Module):
    """Reset the explanation state on every :class:`MessagePassing` layer."""
    for module in model.modules():
        if not isinstance(module, MessagePassing):
            continue
        module._explain = False
        module._edge_mask = None
        module._loop_mask = None
        module._apply_sigmoid = True
    # NOTE(review): returns the last module yielded by `model.modules()`
    # (which always yields at least `model` itself); callers appear to
    # ignore this value — confirm before relying on it.
    return module
class CaptumModel(torch.nn.Module):
    """Adapter that makes a PyG model callable by Captum attribution methods.

    Captum always prepends a sample dimension of size one to its inputs;
    this wrapper squeezes it away, installs edge masks on the wrapped
    model where required, and can narrow the output to a single
    node/link index via `output_idx`.
    """
    def __init__(self, model: torch.nn.Module, mask_type: str = "edge",
                 output_idx: Optional[int] = None):
        super().__init__()
        assert mask_type in ['edge', 'node', 'node_and_edge']

        self.mask_type = mask_type
        self.model = model
        self.output_idx = output_idx

    def forward(self, mask, *args):
        """Run the wrapped model under the Captum calling convention."""
        # Captum's attribution methods put the number of samples in
        # dimension 0; only single samples are supported here, so that
        # dimension is squeezed away below.
        assert mask.shape[0] == 1, "Dimension 0 of input should be 1"
        if self.mask_type == "edge":
            assert len(args) >= 2, "Expects at least x and edge_index as args."
        if self.mask_type == "node":
            assert len(args) >= 1, "Expects at least edge_index as args."
        if self.mask_type == "node_and_edge":
            assert args[0].shape[0] == 1, "Dimension 0 of input should be 1"
            assert len(args[1:]) >= 1, "Expects at least edge_index as args."

        # Install the edge mask on the wrapped model's layers:
        if self.mask_type == 'edge':
            set_masks(self.model, mask.squeeze(0), args[1],
                      apply_sigmoid=False)
        elif self.mask_type == 'node_and_edge':
            # Here `mask` is the node mask and `args[0]` the edge mask;
            # the edge mask is consumed before forwarding the rest.
            set_masks(self.model, args[0].squeeze(0), args[1],
                      apply_sigmoid=False)
            args = args[1:]

        if self.mask_type == 'edge':
            out = self.model(*args)
        elif self.mask_type == 'node':
            out = self.model(mask.squeeze(0), *args)
        else:
            out = self.model(mask[0], *args)

        # Undo the masking so the wrapped model is usable afterwards:
        if self.mask_type in ['edge', 'node_and_edge']:
            clear_masks(self.model)

        if self.output_idx is not None:
            out = out[self.output_idx].unsqueeze(0)
        return out
def to_captum(model: torch.nn.Module, mask_type: str = "edge",
              output_idx: Optional[int] = None) -> torch.nn.Module:
    r"""Converts a model to a model that can be used for
    `Captum.ai <https://captum.ai/>`_ attribution methods.

    .. code-block:: python

        from captum.attr import IntegratedGradients
        from torch_geometric.nn import GCN, from_captum

        model = GCN(...)
        ...  # Train the model.

        # Explain predictions for node `10`:
        output_idx = 10
        captum_model = to_captum(model, mask_type="edge",
                                 output_idx=output_idx)
        edge_mask = torch.ones(num_edges, requires_grad=True, device=device)

        ig = IntegratedGradients(captum_model)
        ig_attr = ig.attribute(edge_mask.unsqueeze(0),
                               target=int(y[output_idx]),
                               additional_forward_args=(x, edge_index),
                               internal_batch_size=1)

    .. note::
        For an example of using a Captum attribution method within PyG, see
        `examples/captum_explainability.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        captum_explainability.py>`_.

    Args:
        model (torch.nn.Module): The model to be explained.
        mask_type (str, optional): Denotes the type of mask to be created with
            a Captum explainer. Valid inputs are :obj:`"edge"`, :obj:`"node"`,
            and :obj:`"node_and_edge"`:

            1. :obj:`"edge"`: The inputs to the forward function should be an
               edge mask tensor of shape :obj:`[1, num_edges]`, a regular
               :obj:`x` matrix and a regular :obj:`edge_index` matrix.

            2. :obj:`"node"`: The inputs to the forward function should be a
               node feature tensor of shape :obj:`[1, num_nodes, num_features]`
               and a regular :obj:`edge_index` matrix.

            3. :obj:`"node_and_edge"`: The inputs to the forward function
               should be a node feature tensor of shape
               :obj:`[1, num_nodes, num_features]`, an edge mask tensor of
               shape :obj:`[1, num_edges]` and a regular :obj:`edge_index`
               matrix.

            For all mask types, additional arguments can be passed to the
            forward function as long as the first arguments are set as
            described. (default: :obj:`"edge"`)
        output_idx (int, optional): Index of the output element (node or link
            index) to be explained. With :obj:`output_idx` set, the forward
            function will return the output of the model for the element at
            the index specified. (default: :obj:`None`)
    """
    return CaptumModel(model, mask_type, output_idx)
class Explainer(torch.nn.Module):
    r"""An abstract class for integrating explainability into Graph Neural
    Networks, *e.g.* :class:`~torch_geometric.nn.GNNExplainer` and
    :class:`~torch_geometric.nn.PGExplainer`.
    It also provides general visualization methods for graph attributions.

    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`None`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`None`)
        num_hops (int, optional): The number of hops the :obj:`model` is
            aggregating information from.
            If set to :obj:`None`, will automatically try to detect this
            information based on the number of
            :class:`~torch_geometric.nn.conv.message_passing.MessagePassing`
            layers inside :obj:`model`. (default: :obj:`None`)
        return_type (str, optional): Denotes the type of output from
            :obj:`model`. Valid inputs are :obj:`"log_prob"` (the model
            returns the logarithm of probabilities), :obj:`"prob"` (the
            model returns probabilities), :obj:`"raw"` (the model returns raw
            scores) and :obj:`"regression"` (the model returns scalars).
            (default: :obj:`"log_prob"`)
        log (bool, optional): If set to :obj:`False`, will not log any learning
            progress. (default: :obj:`False`)
    """
    def __init__(self, model: torch.nn.Module, lr: Optional[float] = None,
                 epochs: Optional[int] = None, num_hops: Optional[int] = None,
                 return_type: str = 'log_prob', log: bool = False):
        super().__init__()
        assert return_type in ['log_prob', 'prob', 'raw', 'regression']

        self.model = model
        self.lr = lr
        self.epochs = epochs
        # Fall back to counting MessagePassing layers when not given:
        self.num_hops = num_hops or get_num_hops(self.model)
        self.return_type = return_type
        self.log = log

    def _flow(self) -> str:
        """Return the message-passing direction of the wrapped model."""
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def subgraph(self, node_idx: int, x: Tensor, edge_index: Tensor, **kwargs):
        r"""Returns the subgraph of the given node.

        Args:
            node_idx (int): The node to explain.
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (Tensor, Tensor, LongTensor, LongTensor, LongTensor, dict)
        """
        # Fix: `edge_index` has shape [2, num_edges], so the edge count is
        # its *second* dimension; `edge_index.size(0)` would always be 2.
        num_nodes, num_edges = x.size(0), edge_index.size(1)
        subset, edge_index, mapping, edge_mask = k_hop_subgraph(
            node_idx, self.num_hops, edge_index, relabel_nodes=True,
            num_nodes=num_nodes, flow=self._flow())

        x = x[subset]
        # Subset node-level / edge-level tensor kwargs along with the graph;
        # everything else is passed through unchanged.
        kwargs_new = {}
        for key, value in kwargs.items():
            if torch.is_tensor(value) and value.size(0) == num_nodes:
                kwargs_new[key] = value[subset]
            elif torch.is_tensor(value) and value.size(0) == num_edges:
                kwargs_new[key] = value[edge_mask]
            else:  # Fix: missing `else` made this overwrite the cases above.
                kwargs_new[key] = value  # TODO: this is not in PGExplainer
        return x, edge_index, mapping, edge_mask, subset, kwargs_new

    def _to_log_prob(self, x):
        """Normalize the model output to log-probabilities per `return_type`."""
        x = x.log_softmax(dim=-1) if self.return_type == 'raw' else x
        x = x.log() if self.return_type == 'prob' else x
        return x

    @torch.no_grad()
    def get_initial_prediction(self, x: Tensor, edge_index: Tensor,
                               batch: Optional[Tensor] = None, **kwargs):
        """Return the model's prediction (class index, or raw value for
        regression) on the unperturbed input."""
        if batch is not None:
            out = self.model(x, edge_index, batch, **kwargs)
        else:
            out = self.model(x, edge_index, **kwargs)
        if self.return_type == 'regression':
            prediction = out
        else:
            log_logits = self._to_log_prob(out)
            prediction = log_logits.argmax(dim=-1)
        return prediction

    def get_loss(self, out: Tensor, prediction: Tensor,
                 node_idx: Optional[int] = None, **kwargs):
        """Compute the explanation loss via the subclass-provided `_loss`."""
        # NOTE: `self._loss` is not defined here; concrete explainers
        # (e.g. GNNExplainer) are expected to implement it.
        if self.return_type == 'regression':
            loss = self._loss(out, prediction, node_idx, **kwargs)
        else:
            log_logits = self._to_log_prob(out)
            loss = self._loss(log_logits, prediction, node_idx, **kwargs)
        return loss

    def visualize_subgraph(self, node_idx: Optional[int], edge_index: Tensor,
                           edge_mask: Tensor, y: Optional[Tensor] = None,
                           threshold: Optional[float] = None,
                           edge_y: Optional[Tensor] = None,
                           node_alpha: Optional[Tensor] = None, seed: int = 10,
                           **kwargs):
        r"""Visualizes the subgraph given an edge mask :attr:`edge_mask`.

        Args:
            node_idx (int): The node id to explain.
                Set to :obj:`None` to explain a graph.
            edge_index (LongTensor): The edge indices.
            edge_mask (Tensor): The edge mask.
            y (Tensor, optional): The ground-truth node-prediction labels used
                as node colorings. All nodes will have the same color
                if :attr:`node_idx` is :obj:`-1`.(default: :obj:`None`).
            threshold (float, optional): Sets a threshold for visualizing
                important edges. If set to :obj:`None`, will visualize all
                edges with transparancy indicating the importance of edges.
                (default: :obj:`None`)
            edge_y (Tensor, optional): The edge labels used as edge colorings.
            node_alpha (Tensor, optional): Tensor of floats (0 - 1) indicating
                transparency of each node.
            seed (int, optional): Random seed of the :obj:`networkx` node
                placement algorithm. (default: :obj:`10`)
            **kwargs (optional): Additional arguments passed to
                :func:`nx.draw`.

        :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
        """
        import matplotlib.pyplot as plt
        import networkx as nx

        assert edge_mask.size(0) == edge_index.size(1)

        if node_idx is None or node_idx < 0:
            # Whole-graph explanation: keep every edge and every node.
            # Fix: the legacy `torch.BoolTensor(..., device=...)` constructor
            # raises for CUDA devices; `torch.ones` behaves identically on CPU.
            hard_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool,
                                        device=edge_mask.device)
            subset = torch.arange(edge_index.max().item() + 1,
                                  device=edge_index.device)
            y = None

        else:
            # Only operate on a k-hop subgraph around `node_idx`.
            subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
                node_idx, self.num_hops, edge_index, relabel_nodes=True,
                num_nodes=None, flow=self._flow())

        edge_mask = edge_mask[hard_edge_mask]

        if threshold is not None:
            # Binarize the mask: keep edges at or above the threshold.
            edge_mask = (edge_mask >= threshold).to(torch.float)

        if y is None:
            y = torch.zeros(edge_index.max().item() + 1,
                            device=edge_index.device)
        else:
            # Normalize labels to [0, 1] for use with a colormap.
            y = y[subset].to(torch.float) / y.max().item()

        if edge_y is None:
            edge_color = ['black'] * edge_index.size(1)
        else:
            colors = list(plt.rcParams['axes.prop_cycle'])
            edge_color = [
                colors[i % len(colors)]['color']
                for i in edge_y[hard_edge_mask]
            ]

        data = Data(edge_index=edge_index, att=edge_mask,
                    edge_color=edge_color, y=y, num_nodes=y.size(0)).to('cpu')
        G = to_networkx(data, node_attrs=['y'],
                        edge_attrs=['att', 'edge_color'])
        # Relabel subgraph nodes back to their original (global) indices:
        mapping = {k: i for k, i in enumerate(subset.tolist())}
        G = nx.relabel_nodes(G, mapping)

        # Forward only the kwargs each networkx drawing call understands:
        node_args = set(signature(nx.draw_networkx_nodes).parameters.keys())
        node_kwargs = {k: v for k, v in kwargs.items() if k in node_args}
        node_kwargs['node_size'] = kwargs.get('node_size') or 800
        node_kwargs['cmap'] = kwargs.get('cmap') or 'cool'

        label_args = set(signature(nx.draw_networkx_labels).parameters.keys())
        label_kwargs = {k: v for k, v in kwargs.items() if k in label_args}
        label_kwargs['font_size'] = kwargs.get('font_size') or 10

        pos = nx.spring_layout(G, seed=seed)
        ax = plt.gca()
        for source, target, data in G.edges(data=True):
            # Draw each edge as an arrow whose opacity encodes importance
            # (floored at 0.1 so faint edges remain visible).
            ax.annotate(
                '', xy=pos[target], xycoords='data', xytext=pos[source],
                textcoords='data', arrowprops=dict(
                    arrowstyle="->",
                    alpha=max(data['att'], 0.1),
                    color=data['edge_color'],
                    shrinkA=sqrt(node_kwargs['node_size']) / 2.0,
                    shrinkB=sqrt(node_kwargs['node_size']) / 2.0,
                    connectionstyle="arc3,rad=0.1",
                ))

        if node_alpha is None:
            nx.draw_networkx_nodes(G, pos, node_color=y.tolist(),
                                   **node_kwargs)
        else:
            node_alpha_subset = node_alpha[subset]
            assert ((node_alpha_subset >= 0) & (node_alpha_subset <= 1)).all()
            nx.draw_networkx_nodes(G, pos, alpha=node_alpha_subset.tolist(),
                                   node_color=y.tolist(), **node_kwargs)

        nx.draw_networkx_labels(G, pos, **label_kwargs)
        return ax, G