torch_geometric.nn.models.to_captum_model

class to_captum_model(model: Module, mask_type: Union[str, MaskLevelType] = MaskLevelType.edge, output_idx: Optional[int] = None, metadata: Optional[Tuple[List[str], List[Tuple[str, str, str]]]] = None)

Converts a model into a wrapper model that can be used with Captum attribution methods.
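
Conceptually, Captum perturbs the tensors passed as `inputs` (each with a leading batch dimension) and forwards everything else unchanged through `additional_forward_args`, so the wrapped model has to accept the mask first and can select a single output row. The pure-Python sketch below illustrates that idea for an edge mask on a toy sum-aggregation "model". `toy_model` and `CaptumStyleWrapper` are hypothetical names, plain lists stand in for tensors, and this is not PyG's implementation; like the examples below, which pass `internal_batch_size=1`, it handles a batch of exactly one mask.

```python
from typing import List, Optional, Sequence, Tuple

Matrix = List[List[float]]


def toy_model(x: Matrix, edge_index: Tuple[Sequence[int], Sequence[int]],
              edge_weight: Sequence[float]) -> Matrix:
    """Hypothetical stand-in for a GNN: each node sums weighted source features."""
    src, dst = edge_index
    out = [[0.0] * len(x[0]) for _ in x]
    for e, (s, d) in enumerate(zip(src, dst)):
        for f in range(len(x[s])):
            out[d][f] += edge_weight[e] * x[s][f]
    return out


class CaptumStyleWrapper:
    """Hypothetical wrapper: the edge mask comes first, with a batch dim of 1."""

    def __init__(self, model, output_idx: Optional[int] = None):
        self.model = model
        self.output_idx = output_idx

    def __call__(self, edge_mask: Matrix, x: Matrix, edge_index) -> Matrix:
        assert len(edge_mask) == 1, "expects a batch of exactly one mask"
        out = self.model(x, edge_index, edge_mask[0])  # drop the batch dim
        if self.output_idx is not None:
            out = [out[self.output_idx]]  # keep a batch dim of 1 on the output
        return out


x = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
edge_index = ([0, 1, 2], [2, 2, 0])  # edges 0->2, 1->2, 2->0
wrapped = CaptumStyleWrapper(toy_model, output_idx=2)
print(wrapped([[1.0, 1.0, 1.0]], x, edge_index))  # [[1.0, 2.0]]
print(wrapped([[1.0, 0.0, 1.0]], x, edge_index))  # [[1.0, 0.0]]
```

Masking out edge 1->2 removes its message from node 2's output, which is exactly the signal an attribution method measures.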

Sample code for homogeneous graphs:

from captum.attr import IntegratedGradients

from torch_geometric.data import Data
from torch_geometric.nn import GCN
from torch_geometric.nn import to_captum_model, to_captum_input

data = Data(x=(...), edge_index=(...))
model = GCN(...)
...  # Train the model.

# Explain predictions for node `10`:
mask_type = "edge"
output_idx = 10
captum_model = to_captum_model(model, mask_type, output_idx)
inputs, additional_forward_args = to_captum_input(
    data.x, data.edge_index, mask_type)

ig = IntegratedGradients(captum_model)
ig_attr = ig.attribute(inputs=inputs,
                       target=int(data.y[output_idx]),
                       additional_forward_args=additional_forward_args,
                       internal_batch_size=1)
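
The `ig.attribute` call above estimates Integrated Gradients for each mask entry: IG_i = (m_i - m'_i) * integral over a from 0 to 1 of df(m' + a(m - m'))/dm_i, where m is the input mask and m' the baseline (all zeros here, which is also Captum's default baseline). The pure-Python sketch below approximates that integral with a midpoint Riemann sum and finite-difference gradients; `integrated_gradients` and `toy_score` are hypothetical, and Captum itself uses autograd and, by default, Gauss-Legendre quadrature.

```python
from typing import Callable, List, Sequence


def integrated_gradients(f: Callable[[Sequence[float]], float],
                         inputs: Sequence[float],
                         baseline: Sequence[float],
                         n_steps: int = 50,
                         eps: float = 1e-6) -> List[float]:
    """Midpoint Riemann-sum approximation of Integrated Gradients.

    Gradients are estimated with central finite differences, so `f` only
    needs to be a plain Python function of a flat list of floats.
    """
    avg_grads = [0.0] * len(inputs)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th interval
        point = [b + alpha * (x - b) for x, b in zip(inputs, baseline)]
        for i in range(len(point)):
            up = list(point)
            up[i] += eps
            down = list(point)
            down[i] -= eps
            avg_grads[i] += (f(up) - f(down)) / (2 * eps) / n_steps
    # Scale the path-averaged gradient by the distance from the baseline.
    return [g * (x - b) for g, x, b in zip(avg_grads, inputs, baseline)]


def toy_score(edge_mask: Sequence[float]) -> float:
    # Hypothetical prediction score of one node as a function of three edge weights.
    return 2.0 * edge_mask[0] + edge_mask[1] * edge_mask[2] + 0.5 * edge_mask[2]


attr = integrated_gradients(toy_score, [1.0, 1.0, 1.0], [0.0, 0.0, 0.0])
print([round(a, 4) for a in attr])  # [2.0, 0.5, 1.0]
```

The attributions sum to toy_score(ones) - toy_score(zeros) = 3.5 (the completeness property), and the interaction term m1 * m2 is split evenly between edges 1 and 2.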

Sample code for heterogeneous graphs:

from captum.attr import IntegratedGradients

from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv
from torch_geometric.nn import (captum_output_to_dicts,
                                to_captum_model, to_captum_input)

data = HeteroData(...)
model = HeteroConv(...)
...  # Train the model.

# Explain predictions for node `10`:
mask_type = "edge"
metadata = data.metadata()
output_idx = 10
captum_model = to_captum_model(model, mask_type, output_idx, metadata)
inputs, additional_forward_args = to_captum_input(
    data.x_dict, data.edge_index_dict, mask_type)

ig = IntegratedGradients(captum_model)
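y = ...  # Labels of the node type whose predictions are explained.
ig_attr = ig.attribute(inputs=inputs,
                       target=int(y[output_idx]),
                       additional_forward_args=additional_forward_args,
                       internal_batch_size=1)
# Returns a (node attribution dict, edge attribution dict) tuple:
_, edge_attr_dict = captum_output_to_dicts(ig_attr, mask_type, metadata)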
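
For heterogeneous graphs the attributions come back as one tensor per node type and/or edge type, each with a leading batch dimension. The sketch below shows the idea of regrouping them by type, assuming the tuple is ordered as node types followed by edge types, in metadata order. `attributions_to_dicts` is a hypothetical pure-Python stand-in (nested lists instead of tensors) and the "paper"/"author" types are made up for illustration; it is not PyG's implementation.

```python
from typing import Dict, List, Optional, Sequence, Tuple

EdgeType = Tuple[str, str, str]


def attributions_to_dicts(
    attrs: Sequence[list],
    mask_type: str,
    metadata: Tuple[List[str], List[EdgeType]],
) -> Tuple[Optional[Dict[str, list]], Optional[Dict[EdgeType, list]]]:
    """Hypothetical sketch: regroup per-type attributions by type name."""
    node_types, edge_types = metadata
    attrs = [a[0] for a in attrs]  # drop the leading batch dimension of 1
    node_dict, edge_dict = None, None
    if mask_type == "node":
        assert len(attrs) == len(node_types)
        node_dict = dict(zip(node_types, attrs))
    elif mask_type == "edge":
        assert len(attrs) == len(edge_types)
        edge_dict = dict(zip(edge_types, attrs))
    elif mask_type == "node_and_edge":
        assert len(attrs) == len(node_types) + len(edge_types)
        node_dict = dict(zip(node_types, attrs[:len(node_types)]))
        edge_dict = dict(zip(edge_types, attrs[len(node_types):]))
    else:
        raise ValueError(f"unknown mask_type: {mask_type!r}")
    return node_dict, edge_dict


metadata = (["paper", "author"],
            [("author", "writes", "paper"), ("paper", "cites", "paper")])
edge_attrs = ([[0.1, 0.7]], [[0.3, 0.0, 0.2]])  # one batch row per edge type
_, edge_dict = attributions_to_dicts(edge_attrs, "edge", metadata)
print(edge_dict[("paper", "cites", "paper")])  # [0.3, 0.0, 0.2]
```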

Note

For an example of using a Captum attribution method within PyG, see examples/explain/captum_explainer.py.

Parameters:
  • model (torch.nn.Module) – The model to be explained.

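  • mask_type (str, optional) – Denotes the type of mask to be created with a Captum explainer: "node" attributes importance to node features, "edge" to edges, and "node_and_edge" to both. Valid inputs are "edge", "node", and "node_and_edge". (default: "edge")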

  • output_idx (int, optional) – Index of the output element (node or link index) to be explained. If set, the forward function returns only the model output for the element at this index. (default: None)

  • metadata (Metadata, optional) – The metadata of the heterogeneous graph, as returned by HeteroData.metadata(). Only required when explaining a HeteroData object. (default: None)
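
The mask_type decides which values Captum perturbs (the `inputs`) and which are passed through unchanged (`additional_forward_args`). The pure-Python sketch below shows that split for a homogeneous graph, assuming node features are perturbed directly while edges get a mask initialized to ones, both with a leading batch dimension of 1. `split_captum_inputs` is a hypothetical helper with lists in place of tensors, not the actual signature or implementation of `to_captum_input`.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def split_captum_inputs(x: Matrix,
                        edge_index: Tuple[Sequence[int], Sequence[int]],
                        mask_type: str) -> Tuple[tuple, tuple]:
    """Hypothetical sketch: return (inputs, additional_forward_args)."""
    num_edges = len(edge_index[0])
    edge_mask = [[1.0] * num_edges]  # batch dim of 1; every edge fully "on"
    batched_x = [x]                  # node features with a batch dim of 1
    if mask_type == "node":
        return (batched_x,), (edge_index,)
    if mask_type == "edge":
        return (edge_mask,), (x, edge_index)
    if mask_type == "node_and_edge":
        return (batched_x, edge_mask), (edge_index,)
    raise ValueError(f"unknown mask_type: {mask_type!r}")


x = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
edge_index = ([0, 1, 2], [2, 2, 0])
inputs, extra = split_captum_inputs(x, edge_index, "edge")
print(inputs)      # ([[1.0, 1.0, 1.0]],)
print(len(extra))  # 2
```

Starting the edge mask at ones means the unperturbed forward pass reproduces the original prediction, while a zero baseline corresponds to removing every edge.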