from typing import Optional, Union
import torch
from torch_geometric.explain.algorithm.captum import (
CaptumHeteroModel,
CaptumModel,
MaskLevelType,
)
from torch_geometric.typing import Metadata
def to_captum_model(
    model: torch.nn.Module,
    mask_type: Union[str, MaskLevelType] = MaskLevelType.edge,
    output_idx: Optional[int] = None,
    metadata: Optional[Metadata] = None,
) -> Union[CaptumModel, CaptumHeteroModel]:
    r"""Converts a model to a model that can be used for
    `Captum <https://captum.ai/>`_ attribution methods.

    Sample code for homogeneous graphs:

    .. code-block:: python

        from captum.attr import IntegratedGradients

        from torch_geometric.data import Data
        from torch_geometric.nn import GCN
        from torch_geometric.nn import to_captum_model, to_captum_input

        data = Data(x=(...), edge_index=(...), y=(...))
        model = GCN(...)
        ...  # Train the model.

        # Explain predictions for node `10`:
        mask_type = "edge"
        output_idx = 10
        captum_model = to_captum_model(model, mask_type, output_idx)
        inputs, additional_forward_args = to_captum_input(
            data.x, data.edge_index, mask_type)

        ig = IntegratedGradients(captum_model)
        ig_attr = ig.attribute(
            inputs=inputs,
            target=int(data.y[output_idx]),
            additional_forward_args=additional_forward_args,
            internal_batch_size=1,
        )

    Sample code for heterogeneous graphs:

    .. code-block:: python

        from captum.attr import IntegratedGradients

        from torch_geometric.data import HeteroData
        from torch_geometric.nn import HeteroConv
        from torch_geometric.nn import (captum_output_to_dicts,
                                        to_captum_model, to_captum_input)

        data = HeteroData(...)
        model = HeteroConv(...)
        ...  # Train the model.

        # Explain predictions for node `10` of the node type of interest:
        mask_type = "edge"
        metadata = data.metadata()  # Note: `metadata` is a method.
        output_idx = 10
        y = data[...].y  # Labels of the node type being explained.
        captum_model = to_captum_model(model, mask_type, output_idx, metadata)
        inputs, additional_forward_args = to_captum_input(
            data.x_dict, data.edge_index_dict, mask_type)

        ig = IntegratedGradients(captum_model)
        ig_attr = ig.attribute(
            inputs=inputs,
            target=int(y[output_idx]),
            additional_forward_args=additional_forward_args,
            internal_batch_size=1,
        )
        edge_attr_dict = captum_output_to_dicts(ig_attr, mask_type, metadata)

    .. note::
        For an example of using a :captum:`Captum` attribution method within
        :pyg:`PyG`, see `examples/explain/captum_explainer.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        explain/captum_explainer.py>`_.

    Args:
        model (torch.nn.Module): The model to be explained.
        mask_type (str or MaskLevelType, optional): Denotes the type of mask
            to be created with a :captum:`Captum` explainer. Valid inputs are
            :obj:`"edge"`, :obj:`"node"`, and :obj:`"node_and_edge"`.
            (default: :obj:`"edge"`)
        output_idx (int, optional): Index of the output element (node or link
            index) to be explained. With :obj:`output_idx` set, the forward
            function will return the output of the model for the element at
            the index specified. (default: :obj:`None`)
        metadata (Metadata, optional): The metadata of the heterogeneous graph.
            Only required if explaining a
            :class:`~torch_geometric.data.HeteroData` object.
            (default: :obj:`None`)

    Returns:
        A :class:`CaptumModel` wrapping :obj:`model` if :obj:`metadata` is
        :obj:`None`, otherwise a :class:`CaptumHeteroModel` that additionally
        receives :obj:`metadata`.
    """
    # The presence of `metadata` is the sole signal used to choose between
    # the homogeneous and the heterogeneous wrapper.
    if metadata is None:
        return CaptumModel(model, mask_type, output_idx)
    return CaptumHeteroModel(model, mask_type, output_idx, metadata)