torch_geometric.nn.models.captum_output_to_dicts
captum_output_to_dicts(captum_attrs: Tuple[Tensor, ...], mask_type: Union[str, MaskLevelType], metadata: Tuple[List[str], List[Tuple[str, str, str]]])
Converts the output of Captum attribution methods, which is a tuple of attribution tensors, into two dictionaries holding node and edge attribution tensors. This function is used when explaining HeteroData objects. See to_captum_model() for example usage.

Parameters:
- captum_attrs (tuple[torch.Tensor]) – The output of the attribution method.
- mask_type (str or MaskLevelType) – Denotes the type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge":
  1. "edge": captum_attrs contains only edge attributions. The returned tuple has no node attributions (None) and an edge attribution dictionary with edge types as keys and edge mask tensors of shape [num_edges] as values.
  2. "node": captum_attrs contains only node attributions. The returned tuple has a node attribution dictionary with node types as keys and node mask tensors of shape [num_nodes, num_features] as values, and no edge attributions (None).
  3. "node_and_edge": captum_attrs contains both node and edge attributions, and the returned tuple holds both a node attribution dictionary and an edge attribution dictionary.
- metadata (Metadata) – The metadata of the heterogeneous graph, i.e., a tuple (node_types, edge_types) as returned by HeteroData.metadata().
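The mapping from a flat attribution tuple to per-type dictionaries can be sketched in plain Python. This is not PyG's implementation: split_attributions and the type aliases are hypothetical, nested lists stand in for tensors so the sketch runs without torch, and it assumes that (a) each attribution carries a leading batch dimension of size 1 that is dropped, and (b) for "node_and_edge" the node attributions come first in the tuple, ordered like metadata[0], followed by the edge attributions, ordered like metadata[1].

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple

# Hypothetical aliases mirroring the documented metadata layout.
NodeType = str
EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[NodeType], List[EdgeType]]


def split_attributions(
    attrs: Sequence[Any],
    mask_type: str,
    metadata: Metadata,
) -> Tuple[Optional[Dict[NodeType, Any]], Optional[Dict[EdgeType, Any]]]:
    """Hypothetical sketch, not PyG's code: map attributions to type dicts."""
    node_types, edge_types = metadata
    # Assumption: every attribution has a leading batch dimension of size 1;
    # indexing with [0] drops it (works for nested lists and torch tensors).
    attrs = [a[0] for a in attrs]

    # Number of attributions each mask type expects.
    expected = {
        "node": len(node_types),
        "edge": len(edge_types),
        "node_and_edge": len(node_types) + len(edge_types),
    }
    if mask_type not in expected:
        raise ValueError(f"unsupported mask_type: {mask_type!r}")
    if len(attrs) != expected[mask_type]:
        raise ValueError(
            f"expected {expected[mask_type]} attributions, got {len(attrs)}")

    if mask_type == "node":
        return dict(zip(node_types, attrs)), None
    if mask_type == "edge":
        return None, dict(zip(edge_types, attrs))
    # Assumed "node_and_edge" ordering: node attributions, then edge ones.
    n = len(node_types)
    return dict(zip(node_types, attrs[:n])), dict(zip(edge_types, attrs[n:]))


# Toy example: two node types and one edge type.
metadata = (["paper", "author"], [("author", "writes", "paper")])
attrs = (
    [[[0.1, 0.2], [0.3, 0.4]]],  # "paper": batch of 1, 2 nodes x 2 features
    [[[0.5, 0.6]]],              # "author": batch of 1, 1 node x 2 features
    [[0.7, 0.8, 0.9]],           # ("author", "writes", "paper"): 3 edges
)
node_dict, edge_dict = split_attributions(attrs, "node_and_edge", metadata)
print(node_dict)  # {'paper': [[0.1, 0.2], [0.3, 0.4]], 'author': [[0.5, 0.6]]}
print(edge_dict)  # {('author', 'writes', 'paper'): [0.7, 0.8, 0.9]}
```

Passing only the first two entries with "node", or only the last one with "edge", yields a single dictionary paired with None, matching the three cases listed for mask_type. With real Captum output the entries would be torch tensors, and the same zip-and-slice logic applies.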