torch_geometric.explain

Warning

This module is in active development and may not be stable. Access requires installing PyTorch Geometric from master.

Philosophy

This module provides a set of tools to explain the predictions of a PyG model or to explain the underlying phenomenon of a dataset (see the “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” paper for more details).

We represent explanations using the torch_geometric.explain.Explanation class, which is a Data object containing masks for the nodes, edges, features and any attributes of the data.

The torch_geometric.explain.Explainer class is designed to handle all explainability parameters (see the torch_geometric.explain.config.ExplainerConfig class for more details):

  • which algorithm from the torch_geometric.explain.algorithm module to use (e.g., GNNExplainer)

  • the type of explanation to compute (e.g., explanation_type="phenomenon" or explanation_type="model")

  • the different types of masks for nodes and edges (e.g., node_mask_type="object" or node_mask_type="attributes")

  • any postprocessing of the masks (e.g., threshold_type="topk" or threshold_type="hard")

This class makes it easy to compare different explainability methods and to switch between different types of masks, while keeping the high-level framework the same.

Explainer

class Explainer(model: Module, algorithm: ExplainerAlgorithm, explainer_config: ExplainerConfig, model_config: ModelConfig, threshold_config: Optional[ThresholdConfig] = None)

An explainer class for instance-level explanations of Graph Neural Networks.

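Parameters
  • model (torch.nn.Module) – The model to explain.

  • algorithm (ExplainerAlgorithm) – The explainer algorithm to use.

  • explainer_config (ExplainerConfig) – The explainer configuration.

  • model_config (ModelConfig) – The model configuration.

  • threshold_config (ThresholdConfig, optional) – The threshold configuration. (default: None)
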
__call__(x: Tensor, edge_index: Tensor, *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, target_index: Optional[int] = None, **kwargs) → Explanation

Computes the explanation of the GNN for the given inputs and target.

Note

If you get an error message like “Trying to backward through the graph a second time”, make sure that the target you provided was computed with torch.no_grad().

Parameters
  • x (torch.Tensor) – The input node features.

  • edge_index (torch.Tensor) – The input edge indices.

  • target (torch.Tensor, optional) – The target of the model. If the explanation type is "phenomenon", the target has to be provided. If the explanation type is "model", the target should be set to None and is inferred automatically. (default: None)

  • index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)

  • target_index (int, optional) – The index of the model outputs to reference in case the model returns a list of tensors, e.g., in a multi-task learning scenario. Should be kept to None in case the model only returns a single output tensor. (default: None)

  • **kwargs (optional) – Additional keyword arguments passed to the GNN.
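The target rules above can be sketched in plain Python. This is a minimal illustration, not PyG's implementation: resolve_target and predict are hypothetical names, plain lists stand in for tensors, and rejecting an explicit target for "model" explanations is an assumption made for the sketch.

```python
from typing import Callable, List, Optional


def resolve_target(
    explanation_type: str,
    target: Optional[List[int]],
    predict: Callable[[], List[int]],
) -> List[int]:
    """Hypothetical helper mirroring the documented target rules."""
    if explanation_type == "phenomenon":
        # Explaining the phenomenon: the ground-truth target must be supplied.
        if target is None:
            raise ValueError("'phenomenon' explanations require a target")
        return target
    if explanation_type == "model":
        # Explaining the model: the target is inferred from the model's own
        # prediction. This sketch rejects an explicit target outright.
        if target is not None:
            raise ValueError("leave target as None for 'model' explanations")
        return predict()
    raise ValueError(f"unknown explanation_type: {explanation_type!r}")


def predict() -> List[int]:
    # Stand-in "model" that always predicts these class labels.
    return [1, 0, 2]


print(resolve_target("model", None, predict))            # [1, 0, 2]
print(resolve_target("phenomenon", [0, 0, 1], predict))  # [0, 0, 1]
```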

get_prediction(*args, **kwargs) → Tensor

Returns the prediction of the model on the input graph.

If the model mode is "regression", the prediction is returned as a scalar value. If the model mode is "classification", the prediction is returned as the predicted class label.

Parameters
  • *args – Arguments passed to the model.

  • **kwargs (optional) – Additional keyword arguments passed to the model.
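Here is a minimal sketch of the mode-dependent behaviour of get_prediction. It assumes per-row outputs given as plain Python lists, and to_prediction is a hypothetical helper, not part of torch_geometric.explain.

```python
from typing import List, Union


def to_prediction(outputs: List[List[float]], mode: str) -> List[Union[int, float]]:
    """Hypothetical helper: turn per-row model outputs into predictions.

    Each row holds the model output for one node, edge or graph.
    """
    if mode == "classification":
        # Predicted class label = index of the largest entry in each row.
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]
    if mode == "regression":
        # Assumes a single regression target, i.e. one value per row.
        return [row[0] for row in outputs]
    raise ValueError(f"unknown mode: {mode!r}")


print(to_prediction([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], "classification"))  # [1, 0]
print(to_prediction([[0.7], [-1.5]], "regression"))  # [0.7, -1.5]
```

Softmax and log-softmax preserve the ordering within each row. The predicted class label is therefore the same whether the model returns raw scores, probabilities or log-probabilities.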

class ExplainerConfig(explanation_type: Union[ExplanationType, str], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None)

Configuration class to store and validate high-level explanation parameters.

Parameters
  • explanation_type (ExplanationType or str) –

    The type of explanation to compute. The possible values are:

    • "model": Explains the model prediction.

    • "phenomenon": Explains the phenomenon that the model is trying to predict.

    In practice, this determines whether the explanation algorithm computes its loss with respect to the model's prediction ("model") or with respect to the target ("phenomenon").

  • node_mask_type (MaskType or str, optional) –

    The type of mask to apply on nodes. The possible values are (default: None):

    • None: Will not apply any mask on nodes.

    • "object": Will mask each node.

    • "common_attributes": Will mask each feature, using one mask value per feature that is shared across all nodes.

    • "attributes": Will mask each feature of each node individually.

  • edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Same types as node_mask_type. (default: None)
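A small dataclass can illustrate the validation and mask types above. ExplainerConfigSketch and node_mask_shape are hypothetical. The shapes for "object" and "attributes" follow the node_mask and node_feat_mask shapes documented for the Explanation class. The [1, num_features] shape for "common_attributes" is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

EXPLANATION_TYPES = {"model", "phenomenon"}
MASK_TYPES = {None, "object", "common_attributes", "attributes"}


@dataclass
class ExplainerConfigSketch:
    """Hypothetical stand-in for ExplainerConfig that validates on creation."""

    explanation_type: str
    node_mask_type: Optional[str] = None
    edge_mask_type: Optional[str] = None

    def __post_init__(self) -> None:
        if self.explanation_type not in EXPLANATION_TYPES:
            raise ValueError(f"invalid explanation_type: {self.explanation_type!r}")
        for name in ("node_mask_type", "edge_mask_type"):
            value = getattr(self, name)
            if value not in MASK_TYPES:
                raise ValueError(f"invalid {name}: {value!r}")


def node_mask_shape(
    mask_type: Optional[str], num_nodes: int, num_features: int
) -> Optional[Tuple[int, ...]]:
    """Assumed shape of the node mask learned for each mask type."""
    if mask_type is None:
        return None  # no mask on nodes
    if mask_type == "object":
        return (num_nodes,)  # one value per node
    if mask_type == "common_attributes":
        return (1, num_features)  # one value per feature, shared by all nodes
    if mask_type == "attributes":
        return (num_nodes, num_features)  # one value per feature of every node
    raise ValueError(f"invalid mask_type: {mask_type!r}")


config = ExplainerConfigSketch("model", node_mask_type="attributes")
print(node_mask_shape(config.node_mask_type, num_nodes=5, num_features=3))  # (5, 3)
```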

class ModelConfig(mode: Union[ModelMode, str], task_level: Union[ModelTaskLevel, str], return_type: Optional[Union[ModelReturnType, str]] = None)

Configuration class to store model parameters.

Parameters
  • mode (ModelMode or str) –

    The mode of the model. The possible values are:

    • "classification": A classification model.

    • "regression": A regression model.

  • task_level (ModelTaskLevel or str) –

    The task-level of the model. The possible values are:

    • "node": A node-level prediction model.

    • "edge": An edge-level prediction model.

    • "graph": A graph-level prediction model.

  • return_type (ModelReturnType or str, optional) –

    The return type of the model. The possible values are (default: None):

    • "raw": The model returns raw values.

    • "probs": The model returns probabilities.

    • "log_probs": The model returns log-probabilities.
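The following sketch shows what the three return types mean by converting one output row into class probabilities. to_probs is a hypothetical helper. Treating "raw" outputs as unnormalized class scores (logits) is an assumption that only holds for classification models.

```python
import math
from typing import List


def to_probs(row: List[float], return_type: str) -> List[float]:
    """Hypothetical helper: convert one output row to class probabilities."""
    if return_type == "probs":
        return list(row)  # already probabilities
    if return_type == "log_probs":
        return [math.exp(v) for v in row]  # undo the logarithm
    if return_type == "raw":
        # Assumes raw outputs are unnormalized class scores (logits):
        # apply a numerically stable softmax.
        shift = max(row)
        exps = [math.exp(v - shift) for v in row]
        total = sum(exps)
        return [e / total for e in exps]
    raise ValueError(f"unknown return_type: {return_type!r}")


print(to_probs([0.0, 0.0], "raw"))  # [0.5, 0.5]
```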

class ThresholdConfig(threshold_type: Union[ThresholdType, str], value: Union[float, int])

Configuration class to store and validate threshold parameters.

Parameters
  • threshold_type (ThresholdType or str) –

    The type of threshold to apply. The possible values are:

    • None: No threshold is applied.

    • "hard": A hard threshold is applied to each mask. Mask elements with a value below value are set to 0; all others are set to 1.

    • "topk": A soft threshold is applied to each mask. The k = value largest elements of each mask keep their original values; all others are set to 0.

    • "topk_hard": Same as "topk", but all kept elements are set to 1.

  • value (int or float, optional) – The value to use when thresholding. (default: None)
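The threshold types above can be sketched on a single mask stored as a Python list. apply_threshold is a hypothetical helper written from the descriptions above. PyG applies its thresholds to tensors, and it may handle ties, or values exactly equal to value, differently.

```python
from typing import List, Optional, Union


def apply_threshold(
    mask: List[float],
    threshold_type: Optional[str],
    value: Optional[Union[float, int]] = None,
) -> List[float]:
    """Hypothetical helper applying the documented threshold types to one mask."""
    if threshold_type is None:
        return list(mask)  # no thresholding
    if threshold_type == "hard":
        # Elements below `value` become 0, all others become 1.
        return [0.0 if v < value else 1.0 for v in mask]
    if threshold_type in ("topk", "topk_hard"):
        k = int(value)
        # Indices of the k largest entries; the stable sort keeps earlier indices on ties.
        order = sorted(range(len(mask)), key=lambda i: mask[i], reverse=True)
        keep = set(order[:k])
        if threshold_type == "topk":
            # Soft: kept entries retain their original values.
            return [v if i in keep else 0.0 for i, v in enumerate(mask)]
        # Hard: kept entries are set to 1.
        return [1.0 if i in keep else 0.0 for i in range(len(mask))]
    raise ValueError(f"unknown threshold_type: {threshold_type!r}")


mask = [0.1, 0.9, 0.5, 0.7]
print(apply_threshold(mask, "hard", 0.6))     # [0.0, 1.0, 0.0, 1.0]
print(apply_threshold(mask, "topk", 2))       # [0.0, 0.9, 0.0, 0.7]
print(apply_threshold(mask, "topk_hard", 2))  # [0.0, 1.0, 0.0, 1.0]
```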

Explanations

class Explanation(node_mask: Optional[Tensor] = None, edge_mask: Optional[Tensor] = None, node_feat_mask: Optional[Tensor] = None, edge_feat_mask: Optional[Tensor] = None, **kwargs)

Holds all the obtained explanations of a homogeneous graph.

The explanation object is a Data object and can hold node attributions, edge attributions and feature attributions. It can also hold the original graph if needed.

Parameters
  • node_mask (Tensor, optional) – Node-level mask with shape [num_nodes]. (default: None)

  • edge_mask (Tensor, optional) – Edge-level mask with shape [num_edges]. (default: None)

  • node_feat_mask (Tensor, optional) – Node-level feature mask with shape [num_nodes, num_node_features]. (default: None)

  • edge_feat_mask (Tensor, optional) – Edge-level feature mask with shape [num_edges, num_edge_features]. (default: None)

  • **kwargs (optional) – Additional attributes.

property available_explanations: List[str]

Returns the available explanation masks.

validate(raise_on_error: bool = True) → bool

Validates the correctness of the explanation.
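Below is a rough idea of the kind of checks validate can perform, based only on the mask shapes documented for Explanation. validate_masks is a hypothetical function that operates on nested lists. The exact set of checks PyG runs is not specified here.

```python
from typing import List, Optional


def validate_masks(
    num_nodes: int,
    num_edges: int,
    node_mask: Optional[List[float]] = None,
    edge_mask: Optional[List[float]] = None,
    node_feat_mask: Optional[List[List[float]]] = None,
    num_node_features: Optional[int] = None,
    raise_on_error: bool = True,
) -> bool:
    """Hypothetical shape checks based on the documented mask shapes."""
    errors = []
    if node_mask is not None and len(node_mask) != num_nodes:
        errors.append(f"node_mask has {len(node_mask)} entries, expected {num_nodes}")
    if edge_mask is not None and len(edge_mask) != num_edges:
        errors.append(f"edge_mask has {len(edge_mask)} entries, expected {num_edges}")
    if node_feat_mask is not None:
        if len(node_feat_mask) != num_nodes:
            errors.append(f"node_feat_mask has {len(node_feat_mask)} rows, expected {num_nodes}")
        if num_node_features is not None and any(
            len(row) != num_node_features for row in node_feat_mask
        ):
            errors.append(f"node_feat_mask rows must have {num_node_features} entries")
    # Mirror the raise_on_error flag: either raise or report via the return value.
    if errors and raise_on_error:
        raise ValueError("; ".join(errors))
    return not errors


print(validate_masks(3, 2, node_mask=[0.2, 0.0, 0.9], edge_mask=[1.0, 0.0]))  # True
print(validate_masks(3, 2, edge_mask=[1.0], raise_on_error=False))            # False
```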

get_explanation_subgraph() → Explanation

Returns the induced subgraph, in which all nodes and edges with zero attribution are masked out.

get_complement_subgraph() → Explanation

Returns the induced subgraph, in which all nodes and edges with any attribution are masked out.
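The two subgraph methods can be illustrated on an edge list. This sketch makes three assumptions: a non-zero attribution means "part of the explanation", an edge survives only if both of its endpoints survive, and node indices are not relabeled. The helper names are hypothetical.

```python
from typing import List, Optional, Tuple

Edge = Tuple[int, int]


def _masked_subgraph(
    num_nodes: int,
    edges: List[Edge],
    node_mask: Optional[List[float]],
    edge_mask: Optional[List[float]],
    keep_attributed: bool,
) -> Tuple[List[int], List[Edge]]:
    def keep(score: float) -> bool:
        # Explanation subgraph keeps non-zero attributions; the complement keeps zeros.
        return (score != 0) == keep_attributed

    nodes = [i for i in range(num_nodes) if node_mask is None or keep(node_mask[i])]
    node_set = set(nodes)
    # Induced subgraph: an edge survives only if it passes the edge mask
    # and both of its endpoints survive the node mask.
    kept_edges = [
        (u, v)
        for e, (u, v) in enumerate(edges)
        if (edge_mask is None or keep(edge_mask[e])) and u in node_set and v in node_set
    ]
    return nodes, kept_edges


def explanation_subgraph(
    num_nodes: int, edges: List[Edge],
    node_mask: Optional[List[float]] = None, edge_mask: Optional[List[float]] = None,
) -> Tuple[List[int], List[Edge]]:
    return _masked_subgraph(num_nodes, edges, node_mask, edge_mask, keep_attributed=True)


def complement_subgraph(
    num_nodes: int, edges: List[Edge],
    node_mask: Optional[List[float]] = None, edge_mask: Optional[List[float]] = None,
) -> Tuple[List[int], List[Edge]]:
    return _masked_subgraph(num_nodes, edges, node_mask, edge_mask, keep_attributed=False)


edges = [(0, 1), (1, 2), (2, 3)]
node_mask = [0.9, 0.4, 0.0, 0.0]
edge_mask = [0.8, 0.0, 0.0]
print(explanation_subgraph(4, edges, node_mask, edge_mask))  # ([0, 1], [(0, 1)])
print(complement_subgraph(4, edges, node_mask, edge_mask))   # ([2, 3], [(2, 3)])
```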

Explainer Algorithms

  • ExplainerAlgorithm – Abstract base class for explainer algorithms.

  • DummyExplainer – A dummy explainer for testing purposes.

  • GNNExplainer – The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.

class ExplainerAlgorithm

Abstract base class for explainer algorithms.

abstract forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, target_index: Optional[int] = None, **kwargs) → Explanation

Computes the explanation.

Parameters
  • model (torch.nn.Module) – The model to explain.

  • x (torch.Tensor) – The input node features.

  • edge_index (torch.Tensor) – The input edge indices.

  • target (torch.Tensor) – The target of the model.

  • index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)

  • target_index (int, optional) – The index of the model outputs to reference in case the model returns a list of tensors, e.g., in a multi-task learning scenario. Should be kept to None in case the model only returns a single output tensor. (default: None)

  • **kwargs (optional) – Additional keyword arguments passed to model.

abstract supports() → bool

Checks if the explainer supports the user-defined settings provided in self.explainer_config and self.model_config.

property explainer_config: ExplainerConfig

Returns the connected explainer configuration.

property model_config: ModelConfig

Returns the connected model configuration.

connect(explainer_config: ExplainerConfig, model_config: ModelConfig)

Connects an explainer and model configuration to the explainer algorithm.
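The contract above can be mirrored in plain Python: an abstract forward and supports, plus connect and the two configuration properties. AlgorithmSketch and UniformEdgeExplainer are hypothetical, and the configurations are plain dictionaries. Raising in connect and in the unconnected properties are choices made for this sketch.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class AlgorithmSketch(ABC):
    """Hypothetical mirror of the ExplainerAlgorithm contract (not PyG code)."""

    @abstractmethod
    def forward(self, model: Any, x: Any, edge_index: List[List[int]], *,
                target: Any, **kwargs: Any) -> Dict[str, List[float]]:
        """Computes the explanation masks."""

    @abstractmethod
    def supports(self) -> bool:
        """Checks whether the connected configurations are supported."""

    @property
    def explainer_config(self) -> Dict[str, Any]:
        if not hasattr(self, "_explainer_config"):
            raise ValueError("no explainer configuration connected")
        return self._explainer_config

    @property
    def model_config(self) -> Dict[str, Any]:
        if not hasattr(self, "_model_config"):
            raise ValueError("no model configuration connected")
        return self._model_config

    def connect(self, explainer_config: Dict[str, Any],
                model_config: Dict[str, Any]) -> None:
        self._explainer_config = explainer_config
        self._model_config = model_config
        # Sketch choice: reject unsupported settings as soon as they are connected.
        if not self.supports():
            raise ValueError(f"{type(self).__name__} does not support these settings")


class UniformEdgeExplainer(AlgorithmSketch):
    """Toy algorithm: gives every edge the same attribution of 1.0."""

    def supports(self) -> bool:
        cfg = self.explainer_config
        return cfg.get("edge_mask_type") == "object" and cfg.get("node_mask_type") is None

    def forward(self, model: Any, x: Any, edge_index: List[List[int]], *,
                target: Any, **kwargs: Any) -> Dict[str, List[float]]:
        num_edges = len(edge_index[0])  # edge_index is [sources, targets]
        return {"edge_mask": [1.0] * num_edges}


algo = UniformEdgeExplainer()
algo.connect({"explanation_type": "model", "edge_mask_type": "object"},
             {"mode": "classification", "task_level": "node"})
print(algo.forward(None, None, [[0, 1], [1, 2]], target=None))  # {'edge_mask': [1.0, 1.0]}
```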

class DummyExplainer

A dummy explainer for testing purposes.

class GNNExplainer(epochs: int = 100, lr: float = 0.01, **kwargs)

The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.

The following configurations are currently supported:

Parameters
  • epochs (int, optional) – The number of epochs to train. (default: 100)

  • lr (float, optional) – The learning rate to apply. (default: 0.01)

  • **kwargs (optional) – Additional hyper-parameters to override default settings in coeffs.
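The core GNNExplainer idea can be sketched without a GNN: learn a soft mask by gradient descent while penalizing large and indecisive masks. This is a toy under explicit assumptions. A linear stand-in score replaces the model, gradients are written by hand, and plain gradient descent is used. The size and entropy coefficients are illustrative, not PyG's coeffs defaults. Only epochs and lr correspond to the documented parameters.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def learn_edge_mask(
    edge_weights: List[float],
    epochs: int = 100,
    lr: float = 0.01,
    size_coeff: float = 0.005,  # illustrative value, not a PyG default
    ent_coeff: float = 1.0,     # illustrative value, not a PyG default
) -> List[float]:
    """Toy GNNExplainer-style optimisation of a soft edge mask (hypothetical).

    Stand-in model score: sum_e edge_weights[e] * m_e, with m_e = sigmoid(theta_e).
    Loss: -score + size_coeff * sum_e m_e + ent_coeff * mean_e H(m_e),
    where H is the binary entropy. Gradients are derived by hand.
    """
    n = len(edge_weights)
    theta = [0.0] * n  # every edge starts at mask value 0.5
    for _ in range(epochs):
        for e, w in enumerate(edge_weights):
            m = sigmoid(theta[e])
            dm = m * (1.0 - m)                           # d m / d theta
            grad = (size_coeff - w) * dm                 # prediction term + size penalty
            grad += (ent_coeff / n) * (-theta[e]) * dm   # dH/dm = -theta when m = sigmoid(theta)
            theta[e] -= lr * grad                        # plain gradient-descent step
    return [sigmoid(t) for t in theta]


# Edge 0 helps the stand-in prediction, edge 1 is neutral, edge 2 hurts it.
print(learn_edge_mask([2.0, 0.0, -1.0]))
```

Edges whose contribution outweighs the size penalty drift toward a mask value of 1, and the rest drift toward 0. With 100 epochs at lr=0.01, the toy masks move only slightly away from 0.5, so their ranking carries the signal rather than their absolute values.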