# Source code for torch_geometric.explain.explanation

import copy
from typing import List, Optional

from torch import Tensor

from torch_geometric.data.data import Data, warn_or_raise


class Explanation(Data):
    r"""Holds all the obtained explanations of a homogeneous graph.

    The explanation object is a :obj:`~torch_geometric.data.Data` object and
    can hold node-attribution, edge-attribution and feature-attribution.
    It can also hold the original graph if needed.

    Args:
        node_mask (Tensor, optional): Node-level mask with shape
            :obj:`[num_nodes]`. (default: :obj:`None`)
        edge_mask (Tensor, optional): Edge-level mask with shape
            :obj:`[num_edges]`. (default: :obj:`None`)
        node_feat_mask (Tensor, optional): Node-level feature mask with shape
            :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`)
        edge_feat_mask (Tensor, optional): Edge-level feature mask with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        **kwargs (optional): Additional attributes.
    """
    def __init__(
        self,
        node_mask: Optional[Tensor] = None,
        edge_mask: Optional[Tensor] = None,
        node_feat_mask: Optional[Tensor] = None,
        edge_feat_mask: Optional[Tensor] = None,
        **kwargs,
    ):
        super().__init__(
            node_mask=node_mask,
            edge_mask=edge_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
            **kwargs,
        )

    @property
    def available_explanations(self) -> List[str]:
        r"""Returns the available explanation masks, i.e. the names of all
        stored attributes that end in :obj:`"_mask"` and are not
        :obj:`None`."""
        return [
            key for key in self.keys
            if key.endswith('_mask') and self[key] is not None
        ]

    def validate(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the explanation.

        Runs the base :class:`~torch_geometric.data.Data` validation and then
        checks that every stored mask agrees in size with the graph:
        :obj:`node_mask` with the number of nodes, :obj:`edge_mask` with the
        number of edges, and :obj:`node_feat_mask` / :obj:`edge_feat_mask`
        with the shape of :obj:`x` / :obj:`edge_attr` when present (otherwise
        with the number of nodes / edges).

        Args:
            raise_on_error (bool, optional): Whether a failed check raises an
                error rather than only warning. Forwarded both to the base
                class validation and to :func:`warn_or_raise`.
                (default: :obj:`True`)

        Returns:
            bool: :obj:`True` if every check passed, :obj:`False` otherwise.
        """
        # Forward `raise_on_error` so the base-class checks follow the same
        # error policy as the mask checks below (previously the base checks
        # always used their default and raised even when the caller asked
        # for warnings only).
        status = super().validate(raise_on_error=raise_on_error)

        if 'node_mask' in self and self.num_nodes != self.node_mask.size(0):
            status = False
            warn_or_raise(
                f"Expected a 'node_mask' with {self.num_nodes} nodes "
                f"(got {self.node_mask.size(0)} nodes)", raise_on_error)

        if 'edge_mask' in self and self.num_edges != self.edge_mask.size(0):
            status = False
            warn_or_raise(
                f"Expected an 'edge_mask' with {self.num_edges} edges "
                f"(got {self.edge_mask.size(0)} edges)", raise_on_error)

        if 'node_feat_mask' in self:
            # Prefer a full shape comparison against `x` when it exists;
            # otherwise only the leading (node) dimension can be checked.
            if 'x' in self and self.x.size() != self.node_feat_mask.size():
                status = False
                warn_or_raise(
                    f"Expected a 'node_feat_mask' of shape "
                    f"{list(self.x.size())} (got shape "
                    f"{list(self.node_feat_mask.size())})", raise_on_error)
            elif self.num_nodes != self.node_feat_mask.size(0):
                status = False
                warn_or_raise(
                    f"Expected a 'node_feat_mask' with {self.num_nodes} nodes "
                    f"(got {self.node_feat_mask.size(0)} nodes)",
                    raise_on_error)

        if 'edge_feat_mask' in self:
            # Same strategy as above, against `edge_attr` / edge count.
            if ('edge_attr' in self
                    and self.edge_attr.size() != self.edge_feat_mask.size()):
                status = False
                warn_or_raise(
                    f"Expected an 'edge_feat_mask' of shape "
                    f"{list(self.edge_attr.size())} (got shape "
                    f"{list(self.edge_feat_mask.size())})", raise_on_error)
            elif self.num_edges != self.edge_feat_mask.size(0):
                status = False
                warn_or_raise(
                    f"Expected an 'edge_feat_mask' with {self.num_edges} "
                    f"edges (got {self.edge_feat_mask.size(0)} edges)",
                    raise_on_error)

        return status

    def get_explanation_subgraph(self) -> 'Explanation':
        r"""Returns the induced subgraph, in which all nodes and edges with
        zero attribution are masked out.

        Only elements with a strictly positive :obj:`node_mask` /
        :obj:`edge_mask` value are kept; a missing mask leaves that element
        type untouched.
        """
        return self._apply_masks(
            node_mask=self.node_mask > 0 if 'node_mask' in self else None,
            edge_mask=self.edge_mask > 0 if 'edge_mask' in self else None,
        )

    def get_complement_subgraph(self) -> 'Explanation':
        r"""Returns the induced subgraph, in which all nodes and edges with
        any attribution are masked out.

        Only elements whose :obj:`node_mask` / :obj:`edge_mask` value is
        exactly zero are kept; a missing mask leaves that element type
        untouched.
        """
        return self._apply_masks(
            node_mask=self.node_mask == 0 if 'node_mask' in self else None,
            edge_mask=self.edge_mask == 0 if 'edge_mask' in self else None,
        )

    def _apply_masks(
        self,
        node_mask: Optional[Tensor] = None,
        edge_mask: Optional[Tensor] = None,
    ) -> 'Explanation':
        r"""Returns a filtered copy of this explanation.

        Args:
            node_mask (Tensor, optional): Boolean mask selecting the nodes to
                keep; filtering is delegated to :meth:`subgraph`.
                (default: :obj:`None`)
            edge_mask (Tensor, optional): Boolean mask selecting the edges to
                keep; applied to :obj:`edge_index` (columns) and to every
                other edge-level attribute (first dimension).
                (default: :obj:`None`)
        """
        # Shallow-copy so the reassignments below do not touch `self`.
        out = copy.copy(self)

        if edge_mask is not None:
            for key, value in self.items():
                if key == 'edge_index':
                    # `edge_index` stores edges along its second dimension.
                    out.edge_index = value[:, edge_mask]
                elif self.is_edge_attr(key):
                    out[key] = value[edge_mask]

        if node_mask is not None:
            out = out.subgraph(node_mask)

        return out