torch_geometric.explain

Warning

This module is in active development and may not be stable. Access requires installing PyG from master.

Philosophy

This module provides a set of tools to explain the predictions of a PyG model or to explain the underlying phenomenon of a dataset (see the “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” paper for more details).

We represent explanations using the torch_geometric.explain.Explanation class, which is a Data object containing masks for the nodes, edges, features and any attributes of the data.

The torch_geometric.explain.Explainer class is designed to handle all explainability parameters (see the torch_geometric.explain.config.ExplainerConfig class for more details):

  • which algorithm from the torch_geometric.explain.algorithm module to use (e.g., GNNExplainer)

  • the type of explanation to compute (e.g., explanation_type="phenomenon" or explanation_type="model")

  • the different types of masks for nodes and edges (e.g., mask="object" or mask="attributes")

  • any postprocessing of the masks (e.g., threshold_type="topk" or threshold_type="hard")

This class allows the user to easily compare different explainability methods and to easily switch between different types of masks, while making sure the high-level framework stays the same.
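For example, a minimal end-to-end sketch (assuming a trained node-level classifier model and a data object; these names are placeholders):

    from torch_geometric.explain import Explainer, GNNExplainer

    explainer = Explainer(
        model=model,                  # a trained torch.nn.Module
        algorithm=GNNExplainer(epochs=200),
        explanation_type='model',     # explain the model prediction
        node_mask_type='attributes',  # mask each feature of each node
        edge_mask_type='object',      # mask each edge
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',
            return_type='log_probs',
        ),
    )

    explanation = explainer(data.x, data.edge_index)
    print(explanation.node_mask, explanation.edge_mask)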

Explainer

class Explainer(model: Module, algorithm: ExplainerAlgorithm, explanation_type: Union[ExplanationType, str], model_config: Union[ModelConfig, Dict[str, Any]], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None, threshold_config: Optional[ThresholdConfig] = None)[source]

Bases: object

An explainer class for instance-level explanations of Graph Neural Networks.

Parameters
  • model (torch.nn.Module) – The model to explain.

  • algorithm (ExplainerAlgorithm) – The explanation algorithm.

  • explanation_type (ExplanationType or str) –

    The type of explanation to compute. The possible values are:

    • "model": Explains the model prediction.

    • "phenomenon": Explains the phenomenon that the model is trying to predict.

    In practice, this means that the explanation algorithm will compute its loss with respect to either the model output ("model") or the target output ("phenomenon").

  • model_config (ModelConfig or dict) – The model configuration. See ModelConfig for available options.

  • node_mask_type (MaskType or str, optional) –

    The type of mask to apply on nodes. The possible values are (default: None):

    • None: Will not apply any mask on nodes.

    • "object": Will mask each node individually.

    • "common_attributes": Will mask each feature with a single mask shared across all nodes.

    • "attributes": Will mask each feature of each node individually.

  • edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Has the same possible values as node_mask_type. (default: None)

  • threshold_config (ThresholdConfig, optional) – The threshold configuration. See ThresholdConfig for available options. (default: None)

get_prediction(*args, **kwargs) Tensor[source]

Returns the prediction of the model on the input graph.

If the model mode is "regression", the prediction is returned as a scalar value. If the model mode is "multiclass_classification" or "binary_classification", the prediction is returned as the predicted class label.

Parameters
  • *args – Arguments passed to the model.

  • **kwargs (optional) – Additional keyword arguments passed to the model.

get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs) Tensor[source]

Returns the prediction of the model on the input graph with node and edge masks applied.
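For example, one can re-evaluate the model under the masks of a previously computed explanation (a sketch reusing the explainer and explanation from the example above):

    masked_pred = explainer.get_masked_prediction(
        data.x, data.edge_index,
        node_mask=explanation.node_mask,  # soft node/feature importances
        edge_mask=explanation.edge_mask,  # soft edge importances
    )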

__call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs) Union[Explanation, HeteroExplanation][source]

Computes the explanation of the GNN for the given inputs and target.

Note

If you get an error message like “Trying to backward through the graph a second time”, make sure that the target you provided was computed with torch.no_grad().

Parameters
  • x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.

  • edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.

  • target (torch.Tensor, optional) – The target of the model. If the explanation type is "phenomenon", the target has to be provided. If the explanation type is "model", the target should be set to None and will get automatically inferred. (default: None)

  • index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)

  • **kwargs – Additional arguments passed to the GNN.
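For instance, a sketch of explaining the prediction for node 10 against ground-truth labels (assumes data.y holds the labels and the explainer was created with explanation_type="phenomenon"):

    explanation = explainer(
        data.x,
        data.edge_index,
        target=data.y,  # required for explanation_type="phenomenon"
        index=10,       # explain the output for node 10 only
    )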

get_target(prediction: Tensor) Tensor[source]

Returns the target of the model from a given prediction.

If the model mode is of type "regression", the prediction is returned as it is. If the model mode is of type "multiclass_classification" or "binary_classification", the prediction is returned as the predicted class label.
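A sketch of the typical pairing of get_prediction() and get_target(), computing the prediction under torch.no_grad() as advised in the note above:

    import torch

    with torch.no_grad():
        pred = explainer.get_prediction(data.x, data.edge_index)
    target = explainer.get_target(pred)  # e.g., predicted class labels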

class ExplainerConfig(explanation_type: Union[ExplanationType, str], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None)[source]

Configuration class to store and validate high level explanation parameters.

Parameters
  • explanation_type (ExplanationType or str) –

    The type of explanation to compute. The possible values are:

    • "model": Explains the model prediction.

    • "phenomenon": Explains the phenomenon that the model is trying to predict.

    In practice, this means that the explanation algorithm will compute its loss with respect to either the model output ("model") or the target output ("phenomenon").

  • node_mask_type (MaskType or str, optional) –

    The type of mask to apply on nodes. The possible values are (default: None):

    • None: Will not apply any mask on nodes.

    • "object": Will mask each node individually.

    • "common_attributes": Will mask each feature with a single mask shared across all nodes.

    • "attributes": Will mask each feature of each node individually.

  • edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Has the same possible values as node_mask_type. (default: None)
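A minimal construction sketch:

    from torch_geometric.explain.config import ExplainerConfig

    config = ExplainerConfig(
        explanation_type='model',
        node_mask_type='attributes',
        edge_mask_type='object',
    )

In most cases, these parameters are instead passed directly to Explainer, which builds the configuration internally.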

class ModelConfig(mode: Union[ModelMode, str], task_level: Union[ModelTaskLevel, str], return_type: Optional[Union[ModelReturnType, str]] = None)[source]

Configuration class to store model parameters.

Parameters
  • mode (ModelMode or str) –

    The mode of the model. The possible values are:

    • "binary_classification": A binary classification model.

    • "multiclass_classification": A multiclass classification model.

    • "regression": A regression model.

  • task_level (ModelTaskLevel or str) –

    The task-level of the model. The possible values are:

    • "node": A node-level prediction model.

    • "edge": An edge-level prediction model.

    • "graph": A graph-level prediction model.

  • return_type (ModelReturnType or str, optional) –

    The return type of the model. The possible values are (default: None):

    • "raw": The model returns raw values.

    • "probs": The model returns probabilities.

    • "log_probs": The model returns log-probabilities.

class ThresholdConfig(threshold_type: Union[ThresholdType, str], value: Union[float, int])[source]

Configuration class to store and validate threshold parameters.

Parameters
  • threshold_type (ThresholdType or str) –

    The type of threshold to apply. The possible values are:

    • None: No threshold is applied.

    • "hard": A hard threshold is applied to each mask. The elements of the mask with a value below the value are set to 0, the others are set to 1.

    • "topk": A soft threshold is applied to each mask. The top obj:value elements of each mask are kept, the others are set to 0.

    • "topk_hard": Same as "topk" but values are set to 1 for all elements which are kept.

  • value (int or float, optional) – The value to use when thresholding. (default: None)
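A minimal construction sketch, keeping only the ten most important elements of each mask:

    from torch_geometric.explain.config import ThresholdConfig

    threshold_config = ThresholdConfig(threshold_type='topk', value=10)

The resulting object can be passed to Explainer via its threshold_config argument.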

Explanations

class Explanation(x: Optional[Tensor] = None, edge_index: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, y: Optional[Tensor] = None, pos: Optional[Tensor] = None, **kwargs)[source]

Bases: Data, ExplanationMixin

Holds all the obtained explanations of a homogeneous graph.

The explanation object is a Data object and can hold node attributions and edge attributions. It can also hold the original graph if needed.

Parameters
  • node_mask (Tensor, optional) – Node-level mask with shape [num_nodes, 1], [1, num_features] or [num_nodes, num_features]. (default: None)

  • edge_mask (Tensor, optional) – Edge-level mask with shape [num_edges]. (default: None)

  • **kwargs (optional) – Additional attributes.

validate(raise_on_error: bool = True) bool[source]

Validates the correctness of the Explanation object.

get_explanation_subgraph() Explanation[source]

Returns the induced subgraph, in which all nodes and edges with zero attribution are masked out.

get_complement_subgraph() Explanation[source]

Returns the induced subgraph, in which all nodes and edges with any attribution are masked out.

visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None)[source]

Creates a bar plot of the node feature importance by summing up self.node_mask across all nodes.

Parameters
  • path (str, optional) – The path to where the plot is saved. If set to None, will visualize the plot on-the-fly. (default: None)

  • feat_labels (List[str], optional) – The labels of features. (default: None)

  • top_k (int, optional) – The number of top features to plot. If set to None, will plot all features. (default: None)

visualize_graph(path: Optional[str] = None, backend: Optional[str] = None)[source]

Visualizes the explanation graph with edge opacity corresponding to edge importance.

Parameters
  • path (str, optional) – The path to where the plot is saved. If set to None, will visualize the plot on-the-fly. (default: None)

  • backend (str, optional) – The graph drawing backend to use for visualization ("graphviz", "networkx"). If set to None, will use the most appropriate visualization backend based on available system packages. (default: None)
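Putting it together, a sketch of common post-processing on a computed Explanation (file names are placeholders):

    explanation.validate()

    # Drop all nodes and edges with zero attribution:
    subgraph = explanation.get_explanation_subgraph()

    # Bar plot of the ten most important node features:
    explanation.visualize_feature_importance('feature_importance.png', top_k=10)

    # Draw the graph with edge opacity corresponding to edge importance:
    explanation.visualize_graph('explanation.png', backend='networkx')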

class HeteroExplanation(_mapping: Optional[Dict[str, Any]] = None, **kwargs)[source]

Bases: HeteroData, ExplanationMixin

Holds all the obtained explanations of a heterogeneous graph.

The explanation object is a HeteroData object and can hold node attributions and edge attributions. It can also hold the original graph if needed.

validate(raise_on_error: bool = True) bool[source]

Validates the correctness of the Explanation object.

get_explanation_subgraph() HeteroExplanation[source]

Returns the induced subgraph, in which all nodes and edges with zero attribution are masked out.

get_complement_subgraph() HeteroExplanation[source]

Returns the induced subgraph, in which all nodes and edges with any attribution are masked out.

Explainer Algorithms

ExplainerAlgorithm

An abstract base class for implementing explainer algorithms.

DummyExplainer

A dummy explainer that returns random explanations (useful for testing purposes).

GNNExplainer

The GNN-Explainer model from the "GNNExplainer: Generating Explanations for Graph Neural Networks" paper for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.

CaptumExplainer

A Captum-based explainer for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.

PGExplainer

The PGExplainer model from the "Parameterized Explainer for Graph Neural Network" paper.

AttentionExplainer

An explainer that uses the attention coefficients produced by an attention-based GNN (e.g., GATConv, GATv2Conv, or TransformerConv) as edge explanation.
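Since all algorithms implement the ExplainerAlgorithm interface, switching methods only requires changing the algorithm argument. A sketch using CaptumExplainer (requires the captum package; "IntegratedGradients" is one of its supported attribution methods):

    from torch_geometric.explain import CaptumExplainer, Explainer

    explainer = Explainer(
        model=model,
        algorithm=CaptumExplainer('IntegratedGradients'),
        explanation_type='model',
        node_mask_type='attributes',
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',
            return_type='log_probs',
        ),
    )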

Explanation Metrics

The quality of an explanation can be judged by a variety of different methods. PyG supports the following metrics out-of-the-box:

groundtruth_metrics

Compares and evaluates an explanation mask with the ground-truth explanation mask.

fidelity

Evaluates the fidelity of an Explainer given an Explanation, as described in the "GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" paper.

characterization_score

Returns the componentwise characterization score as described in the "GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" paper.
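Following the paper, the score combines the positive and negative fidelities fid_+ and fid_- with user-chosen weights w_+ and w_- (a sketch of the formula as given there):

    \textrm{charact} = \frac{w_{+} + w_{-}}{\frac{w_{+}}{\textrm{fid}_{+}} + \frac{w_{-}}{1 - \textrm{fid}_{-}}}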

fidelity_curve_auc

Returns the AUC for the fidelity curve as described in the "GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" paper.

unfaithfulness

Evaluates how faithful an Explanation is to an underlying GNN predictor, as described in the "Evaluating Explainability for Graph Neural Networks" paper.
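A sketch of computing these metrics for a given explainer/explanation pair (reusing the objects from the examples above):

    from torch_geometric.explain.metric import (
        characterization_score,
        fidelity,
        unfaithfulness,
    )

    pos_fidelity, neg_fidelity = fidelity(explainer, explanation)
    score = characterization_score(pos_fidelity, neg_fidelity)
    u = unfaithfulness(explainer, explanation, top_k=10)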