torch_geometric.explain.algorithm.CaptumExplainer
- class CaptumExplainer(attribution_method: Union[str, Any], **kwargs)
Bases: ExplainerAlgorithm
A Captum-based explainer for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.
This explainer algorithm uses Captum to compute attributions.
Currently, the following attribution methods are supported:
captum.attr.IntegratedGradients
captum.attr.Saliency
captum.attr.InputXGradient
captum.attr.Deconvolution
captum.attr.ShapleyValueSampling
captum.attr.GuidedBackprop
- Parameters:
attribution_method (Attribution or str) – The Captum attribution method to use. Can be a string naming a class in captum.attr (e.g. "IntegratedGradients") or the captum.attr class itself.
**kwargs – Additional arguments for the Captum attribution method.
- forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) → Union[Explanation, HeteroExplanation]
Computes the explanation.
- Parameters:
model (torch.nn.Module) – The model to explain.
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor) – The target of the model.
index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)
**kwargs (optional) – Additional keyword arguments passed to model.