import copy
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from torch_geometric.data.data import Data, warn_or_raise
from torch_geometric.data.hetero_data import HeteroData
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.visualization import visualize_graph
class ExplanationMixin:
    r"""Mixin with the mask helpers shared by :class:`Explanation` and
    :class:`HeteroExplanation`.

    It relies on the host class to provide :meth:`keys`, :obj:`node_stores`
    and :obj:`edge_stores`.
    """
    @property
    def available_explanations(self) -> List[str]:
        """Returns the available explanation masks, *i.e.* all keys that end
        with :obj:`"_mask"`.
        """
        return [key for key in self.keys() if key.endswith('_mask')]

    def validate_masks(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the :class:`Explanation` masks.

        A :obj:`node_mask` needs to be two-dimensional, its first dimension
        needs to be :obj:`1` or the number of nodes of its store, and its
        second dimension needs to be :obj:`1` or the number of features
        (read from :obj:`x` if the store holds it, otherwise the check is
        trivially satisfied).
        An :obj:`edge_mask` needs to be one-dimensional with one entry per
        edge.
        Stores that do not hold the respective mask are skipped.

        Args:
            raise_on_error (bool, optional): Forwarded to
                :func:`warn_or_raise` for every violation that is found.
                (default: :obj:`True`)

        Returns:
            bool: :obj:`True` if no violation was found, :obj:`False`
            otherwise.
        """
        status = True

        for store in self.node_stores:
            if 'node_mask' not in store:
                continue

            if store.node_mask.dim() != 2:
                status = False
                warn_or_raise(
                    f"Expected a 'node_mask' with two dimensions (got "
                    f"{store.node_mask.dim()} dimensions)", raise_on_error)
                # The checks below index dimension 0 and 1, which fails with
                # an `IndexError` for masks with fewer dimensions. The mask
                # is already known to be invalid, so skip them:
                continue

            if store.node_mask.size(0) not in {1, store.num_nodes}:
                status = False
                warn_or_raise(
                    f"Expected a 'node_mask' with {store.num_nodes} nodes "
                    f"(got {store.node_mask.size(0)} nodes)", raise_on_error)

            if 'x' in store:
                num_features = store.x.size(-1)
            else:
                # No features to compare against:
                num_features = store.node_mask.size(-1)

            if store.node_mask.size(1) not in {1, num_features}:
                status = False
                warn_or_raise(
                    f"Expected a 'node_mask' with {num_features} features ("
                    f"got {store.node_mask.size(1)} features)", raise_on_error)

        for store in self.edge_stores:
            if 'edge_mask' not in store:
                continue

            if store.edge_mask.dim() != 1:
                status = False
                warn_or_raise(
                    f"Expected an 'edge_mask' with one dimension (got "
                    f"{store.edge_mask.dim()} dimensions)", raise_on_error)
                # `size(0)` is not available for zero-dimensional masks, and
                # the mask is already known to be invalid:
                continue

            if store.edge_mask.size(0) != store.num_edges:
                status = False
                warn_or_raise(
                    f"Expected an 'edge_mask' with {store.num_edges} edges "
                    f"(got {store.edge_mask.size(0)} edges)", raise_on_error)

        return status

    def _threshold_mask(
        self,
        mask: Optional[Tensor],
        threshold_config: ThresholdConfig,
    ) -> Optional[Tensor]:
        r"""Applies the thresholding described by :obj:`threshold_config` to
        a single mask.

        * :obj:`hard`: Entries greater than the threshold value become
          :obj:`1.0`, all other entries become :obj:`0.0`.
        * :obj:`topk`: The :obj:`value` largest entries (taken over the
          flattened mask) are kept, all other entries are set to zero.
        * :obj:`topk_hard`: Same as :obj:`topk`, but the kept entries are set
          to :obj:`1.0`.

        Args:
            mask (Tensor, optional): The mask to threshold.
            threshold_config (ThresholdConfig): The thresholding type and
                value.

        Returns:
            Optional[Tensor]: The thresholded mask with the shape of
            :obj:`mask`, or :obj:`None` if :obj:`mask` is :obj:`None`.

        Raises:
            AssertionError: If the threshold type is not handled.
        """
        if mask is None:
            return None

        if threshold_config.type == ThresholdType.hard:
            return (mask > threshold_config.value).float()

        if threshold_config.type in [
                ThresholdType.topk,
                ThresholdType.topk_hard,
        ]:
            # `torch.topk` cannot select more entries than there are, so the
            # case in which every entry is selected is handled separately:
            if threshold_config.value >= mask.numel():
                if threshold_config.type == ThresholdType.topk:
                    return mask
                else:
                    return torch.ones_like(mask)

            value, index = torch.topk(
                mask.flatten(),
                k=threshold_config.value,
            )

            out = torch.zeros_like(mask.flatten())
            if threshold_config.type == ThresholdType.topk:
                out[index] = value
            else:
                out[index] = 1.0
            return out.view(mask.size())

        # Raise explicitly rather than via `assert False`, which is stripped
        # when assertions are disabled (`python -O`) and would let the
        # function silently return `None`:
        raise AssertionError(
            f"Unsupported threshold type '{threshold_config.type}'")

    def threshold(
        self,
        *args,
        **kwargs,
    ) -> Union['Explanation', 'HeteroExplanation']:
        """Thresholds the explanation masks according to the thresholding
        method.

        Args:
            *args: Arguments passed to :class:`ThresholdConfig`.
            **kwargs: Keyword arguments passed to :class:`ThresholdConfig`.

        Returns:
            A shallow copy of this object holding the thresholded masks, or
            this object itself if no threshold configuration is obtained
            from the given arguments.
        """
        threshold_config = ThresholdConfig.cast(*args, **kwargs)

        if threshold_config is None:
            return self

        # Avoid modification of the original explanation:
        out = copy.copy(self)

        for store in out.node_stores:
            store.node_mask = self._threshold_mask(store.get('node_mask'),
                                                   threshold_config)

        for store in out.edge_stores:
            store.edge_mask = self._threshold_mask(store.get('edge_mask'),
                                                   threshold_config)

        return out
class Explanation(Data, ExplanationMixin):
    r"""Container for the explanations obtained on a homogeneous graph.

    An explanation is a regular :obj:`~torch_geometric.data.Data` object that
    can additionally carry node-level and edge-level attributions, and,
    if desired, the original graph.

    Args:
        node_mask (Tensor, optional): Node-level mask with shape
            :obj:`[num_nodes, 1]`, :obj:`[1, num_features]` or
            :obj:`[num_nodes, num_features]`. (default: :obj:`None`)
        edge_mask (Tensor, optional): Edge-level mask with shape
            :obj:`[num_edges]`. (default: :obj:`None`)
        **kwargs (optional): Additional attributes.
    """
    def validate(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the :class:`Explanation` object,
        *i.e.* both the underlying data and the explanation masks.
        """
        data_ok = super().validate(raise_on_error)
        masks_ok = self.validate_masks(raise_on_error)
        return data_ok & masks_ok

    def get_explanation_subgraph(self) -> 'Explanation':
        r"""Returns the induced subgraph that only keeps nodes and edges with
        non-zero (positive) attribution.
        """
        node_mask = self.get('node_mask')
        edge_mask = self.get('edge_mask')

        # A node is kept if its attributions sum up to a positive value:
        keep_nodes = None if node_mask is None else node_mask.sum(dim=-1) > 0
        keep_edges = None if edge_mask is None else edge_mask > 0

        return self._apply_masks(keep_nodes, keep_edges)

    def get_complement_subgraph(self) -> 'Explanation':
        r"""Returns the induced subgraph that only keeps nodes and edges
        without any attribution.
        """
        node_mask = self.get('node_mask')
        edge_mask = self.get('edge_mask')

        # A node is kept if its attributions sum up to exactly zero:
        keep_nodes = None if node_mask is None else node_mask.sum(dim=-1) == 0
        keep_edges = None if edge_mask is None else edge_mask == 0

        return self._apply_masks(keep_nodes, keep_edges)

    def _apply_masks(
        self,
        node_mask: Optional[Tensor] = None,
        edge_mask: Optional[Tensor] = None,
    ) -> 'Explanation':
        r"""Filters a shallow copy of this object by boolean masks.

        The edge mask is applied first (to :obj:`edge_index` and to every
        edge-level attribute), followed by the node mask via
        :meth:`subgraph`. A mask that is :obj:`None` is not applied.
        """
        result = copy.copy(self)

        if edge_mask is not None:
            for name, tensor in self.items():
                if name == 'edge_index':
                    # Edges are stored column-wise:
                    result.edge_index = tensor[:, edge_mask]
                    continue
                if self.is_edge_attr(name):
                    result[name] = tensor[edge_mask]

        if node_mask is None:
            return result
        return result.subgraph(node_mask)

    def visualize_feature_importance(
        self,
        path: Optional[str] = None,
        feat_labels: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ):
        r"""Plots the importance of each node feature as a bar plot, where
        the importance of a feature is its node mask summed over all nodes.

        Args:
            path (str, optional): The path to where the plot is saved.
                If set to :obj:`None`, will visualize the plot on-the-fly.
                (default: :obj:`None`)
            feat_labels (List[str], optional): The labels of features.
                If set to :obj:`None`, features are labeled by their index.
                (default :obj:`None`)
            top_k (int, optional): Top k features to plot. If :obj:`None`
                plots all features. (default: :obj:`None`)

        Raises:
            ValueError: If no :obj:`node_mask` is available, or if it does
                not hold more than one feature column.
        """
        node_mask = self.get('node_mask')

        if node_mask is None:
            raise ValueError(f"The attribute 'node_mask' is not available "
                             f"in '{self.__class__.__name__}' "
                             f"(got {self.available_explanations})")

        has_feature_columns = node_mask.dim() == 2 and node_mask.size(1) > 1
        if not has_feature_columns:
            raise ValueError(f"Cannot compute feature importance for "
                             f"object-level 'node_mask' "
                             f"(got shape {node_mask.size()})")

        if feat_labels is not None:
            labels = feat_labels
        else:
            labels = range(node_mask.size(1))

        return _visualize_score(node_mask.sum(dim=0), labels, path, top_k)

    def visualize_graph(
        self,
        path: Optional[str] = None,
        backend: Optional[str] = None,
        node_labels: Optional[List[str]] = None,
    ) -> None:
        r"""Draws the explanation graph, using the edge importance as edge
        opacity.

        Args:
            path (str, optional): The path to where the plot is saved.
                If set to :obj:`None`, will visualize the plot on-the-fly.
                (default: :obj:`None`)
            backend (str, optional): The graph drawing backend to use for
                visualization (:obj:`"graphviz"`, :obj:`"networkx"`).
                If set to :obj:`None`, will use the most appropriate
                visualization backend based on available system packages.
                (default: :obj:`None`)
            node_labels (list[str], optional): The labels/IDs of nodes.
                (default: :obj:`None`)

        Raises:
            ValueError: If no :obj:`edge_mask` is available.
        """
        importance = self.get('edge_mask')

        if importance is None:
            raise ValueError(f"The attribute 'edge_mask' is not available "
                             f"in '{self.__class__.__name__}' "
                             f"(got {self.available_explanations})")

        # Refers to the module-level function, not to this method:
        visualize_graph(
            self.edge_index,
            importance,
            path,
            backend,
            node_labels,
        )
class HeteroExplanation(HeteroData, ExplanationMixin):
    r"""Container for the explanations obtained on a heterogeneous graph.

    An explanation is a regular :obj:`~torch_geometric.data.HeteroData`
    object that can additionally carry node-level and edge-level
    attributions, and, if desired, the original graph.
    """
    def validate(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the :class:`HeteroExplanation`
        object, *i.e.* both the underlying data and the explanation masks.
        """
        data_ok = super().validate(raise_on_error)
        masks_ok = self.validate_masks(raise_on_error)
        return data_ok & masks_ok

    def get_explanation_subgraph(self) -> 'HeteroExplanation':
        r"""Returns the induced subgraph that only keeps nodes and edges with
        non-zero (positive) attribution.
        """
        node_masks = self.collect('node_mask', True)
        edge_masks = self.collect('edge_mask', True)

        # A node is kept if its attributions sum up to a positive value:
        keep_nodes = {
            node_type: mask.sum(dim=-1) > 0
            for node_type, mask in node_masks.items()
        }
        keep_edges = {
            edge_type: mask > 0
            for edge_type, mask in edge_masks.items()
        }

        return self._apply_masks(
            node_mask_dict=keep_nodes,
            edge_mask_dict=keep_edges,
        )

    def get_complement_subgraph(self) -> 'HeteroExplanation':
        r"""Returns the induced subgraph that only keeps nodes and edges
        without any attribution.
        """
        node_masks = self.collect('node_mask', True)
        edge_masks = self.collect('edge_mask', True)

        # A node is kept if its attributions sum up to exactly zero:
        keep_nodes = {
            node_type: mask.sum(dim=-1) == 0
            for node_type, mask in node_masks.items()
        }
        keep_edges = {
            edge_type: mask == 0
            for edge_type, mask in edge_masks.items()
        }

        return self._apply_masks(
            node_mask_dict=keep_nodes,
            edge_mask_dict=keep_edges,
        )

    def _apply_masks(
        self,
        node_mask_dict: Dict[NodeType, Tensor],
        edge_mask_dict: Dict[EdgeType, Tensor],
    ) -> 'HeteroExplanation':
        r"""Filters a shallow copy of this object by boolean masks.

        For every edge type in :obj:`edge_mask_dict`, the mask is applied to
        :obj:`edge_index` and to every edge-level attribute of that type.
        Afterwards, the node masks are applied via :meth:`subgraph`.
        """
        result = copy.copy(self)

        for edge_type, edge_mask in edge_mask_dict.items():
            source_store = self[edge_type]
            for name, tensor in source_store.items():
                if name == 'edge_index':
                    # Edges are stored column-wise:
                    result[edge_type].edge_index = tensor[:, edge_mask]
                    continue
                if source_store.is_edge_attr(name):
                    result[edge_type][name] = tensor[edge_mask]

        return result.subgraph(node_mask_dict)

    def visualize_feature_importance(
        self,
        path: Optional[str] = None,
        feat_labels: Optional[Dict[NodeType, List[str]]] = None,
        top_k: Optional[int] = None,
    ):
        r"""Plots the importance of each node feature as a bar plot, where
        the importance of a feature is the node mask of its node type summed
        over all nodes of that type. Bars are labeled as
        :obj:`"{node_type}#{label}"`.

        Args:
            path (str, optional): The path to where the plot is saved.
                If set to :obj:`None`, will visualize the plot on-the-fly.
                (default: :obj:`None`)
            feat_labels (Dict[NodeType, List[str]], optional): The labels of
                features for each node type. If set to :obj:`None`, features
                are labeled by their index. (default :obj:`None`)
            top_k (int, optional): Top k features to plot. If :obj:`None`
                plots all features. (default: :obj:`None`)

        Raises:
            ValueError: If any node mask is not two-dimensional.
        """
        node_mask_dict = self.node_mask_dict

        for node_mask in node_mask_dict.values():
            if node_mask.dim() != 2:
                raise ValueError(f"Cannot compute feature importance for "
                                 f"object-level 'node_mask' "
                                 f"(got shape {node_mask.size()})")

        if feat_labels is None:
            feat_labels = {
                node_type: range(node_mask.size(1))
                for node_type, node_mask in node_mask_dict.items()
            }

        # Scores and labels are concatenated in the same node type order:
        per_type_scores = [
            node_mask.sum(dim=0) for node_mask in node_mask_dict.values()
        ]
        score = torch.cat(per_type_scores, dim=0)

        all_feat_labels = []
        for node_type in node_mask_dict.keys():
            all_feat_labels.extend(f'{node_type}#{label}'
                                   for label in feat_labels[node_type])

        return _visualize_score(score, all_feat_labels, path, top_k)
def _visualize_score(
score: torch.Tensor,
labels: List[str],
path: Optional[str] = None,
top_k: Optional[int] = None,
):
import matplotlib.pyplot as plt
import pandas as pd
if len(labels) != score.numel():
raise ValueError(f"The number of labels (got {len(labels)}) must "
f"match the number of scores (got {score.numel()})")
score = score.cpu().numpy()
df = pd.DataFrame({'score': score}, index=labels)
df = df.sort_values('score', ascending=False)
df = df.round(decimals=3)
if top_k is not None:
df = df.head(top_k)
title = f"Feature importance for top {len(df)} features"
else:
title = f"Feature importance for {len(df)} features"
ax = df.plot(
kind='barh',
figsize=(10, 7),
title=title,
ylabel='Feature label',
xlim=[0, float(df['score'].max()) + 0.3],
legend=False,
)
plt.gca().invert_yaxis()
ax.bar_label(container=ax.containers[0], label_type='edge')
if path is not None:
plt.savefig(path)
else:
plt.show()
plt.close()