torch_geometric.nn.models.captum_output_to_dicts

captum_output_to_dicts(captum_attrs: Tuple[Tensor, ...], mask_type: Union[str, MaskLevelType], metadata: Tuple[List[str], List[Tuple[str, str, str]]])


Convert the output of Captum attribution methods, which is a tuple of attributions, into two dictionaries holding node and edge attribution tensors. This function is used when explaining HeteroData objects. See to_captum_model() for example usage.
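At its core, the conversion pairs each entry of the flat attribution tuple with a node or edge type from the graph's metadata. The sketch below illustrates this for the case where both node and edge attributions are present. It uses plain strings in place of tensors, and it assumes (this is not stated above) that the tuple lists node attributions first and edge attributions second, each in the order given by the metadata; the node and edge type names are made up.

```python
from typing import Dict, List, Tuple

# Hypothetical heterogeneous metadata: (node types, edge types).
metadata: Tuple[List[str], List[Tuple[str, str, str]]] = (
    ["author", "paper"],
    [("author", "writes", "paper")],
)

# Stand-ins for attribution tensors, assumed ordered node types first,
# then edge types (mirroring the order of ``metadata``).
captum_attrs = ("attr_author", "attr_paper", "attr_writes")

node_types, edge_types = metadata
num_node_types = len(node_types)

# Pair each attribution with the type at the same position.
node_attr_dict: Dict[str, str] = dict(
    zip(node_types, captum_attrs[:num_node_types]))
edge_attr_dict: Dict[Tuple[str, str, str], str] = dict(
    zip(edge_types, captum_attrs[num_node_types:]))

print(node_attr_dict)  # {'author': 'attr_author', 'paper': 'attr_paper'}
print(edge_attr_dict)  # {('author', 'writes', 'paper'): 'attr_writes'}
```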

Parameters:
  • captum_attrs (tuple[torch.Tensor]) – The output of attribution methods.

  • mask_type (str) –

    Denotes the type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge":

    1. "edge": captum_attrs contains only edge attributions. The returned tuple has no node attributions, and an edge attribution dictionary with edge types as keys and edge mask tensors of shape [num_edges] as values.

    2. "node": captum_attrs contains only node attributions. The returned tuple has a node attribution dictionary with node types as keys and node mask tensors of shape [num_nodes, num_features] as values, and no edge attributions.

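    3. "node_and_edge": captum_attrs contains both node and edge attributions. The returned tuple has both a node attribution dictionary and an edge attribution dictionary, structured as in the two cases above.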

  • metadata (Metadata) – The metadata of the heterogeneous graph.
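The return structure for each mask_type can be summarized with a small sketch. split_captum_attrs is a hypothetical helper, not part of PyG: it uses plain Python objects in place of tensors, accepts only string mask types (not MaskLevelType), and assumes the attribution tuple lists node attributions first and edge attributions second, each in metadata order. A side that was not requested comes back as None, matching the "no node attributions" and "no edge attributions" cases described above.

```python
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar

# T stands in for torch.Tensor so the sketch runs without PyTorch.
T = TypeVar("T")
NodeType = str
EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[NodeType], List[EdgeType]]


def split_captum_attrs(
    attrs: Sequence[T],
    mask_type: str,
    metadata: Metadata,
) -> Tuple[Optional[Dict[NodeType, T]], Optional[Dict[EdgeType, T]]]:
    """Hypothetical helper mirroring the documented return structure.

    Assumes ``attrs`` lists node attributions first and edge attributions
    second, each in the order given by ``metadata``.
    """
    node_types, edge_types = metadata
    use_nodes = mask_type in ("node", "node_and_edge")
    use_edges = mask_type in ("edge", "node_and_edge")
    if not (use_nodes or use_edges):
        raise ValueError(f"invalid mask_type: {mask_type!r}")

    # Number of attributions expected for each side of the split.
    num_node = len(node_types) if use_nodes else 0
    num_edge = len(edge_types) if use_edges else 0
    if len(attrs) != num_node + num_edge:
        raise ValueError(
            f"expected {num_node + num_edge} attributions, got {len(attrs)}")

    # A side that was not requested is returned as None.
    node_dict = dict(zip(node_types, attrs[:num_node])) if use_nodes else None
    edge_dict = dict(zip(edge_types, attrs[num_node:])) if use_edges else None
    return node_dict, edge_dict


metadata = (["author", "paper"], [("author", "writes", "paper")])
print(split_captum_attrs(("w",), "edge", metadata))
# (None, {('author', 'writes', 'paper'): 'w'})
print(split_captum_attrs(("a", "p"), "node", metadata))
# ({'author': 'a', 'paper': 'p'}, None)
```

The length check makes a mismatch between the attribution tuple and the metadata fail loudly instead of silently dropping entries, which zip alone would do.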