torch_geometric.nn.models.to_captum_input

to_captum_input(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], mask_type: Union[str, MaskLevelType], *args)

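Converts the node features x and edge indices edge_index into the format expected by Captum attribution methods, according to mask_type. Returns a tuple (inputs, additional_forward_args) that can be passed as the inputs and additional_forward_args arguments of Captum's attribute() functions. See to_captum_model() for example usage.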
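The conversion for a homogeneous graph can be sketched as follows. This is a minimal, hypothetical illustration that assumes the PyG 2.x behavior: for node masks the Captum input is x with a leading batch dimension of 1, for edge masks it is an all-ones mask of shape [1, num_edges] and x moves into the forward arguments, and edge_index plus any *args always end up in additional_forward_args. Shapes stand in for tensors, and captum_input_layout is an illustrative name, not a PyG function.

```python
from typing import Any, List, Tuple

Shape = Tuple[int, ...]


def captum_input_layout(
    x_shape: Shape, num_edges: int, mask_type: str, *args: Any
) -> Tuple[Tuple[Shape, ...], Tuple[Any, ...]]:
    """Hypothetical stand-in for to_captum_input on a homogeneous graph.

    Inputs are represented by their shapes and forward arguments by labels;
    the layout is an assumption about PyG, not its actual source.
    """
    inputs: List[Shape] = []
    forward_args: List[Any] = []
    node_input = (1, *x_shape)  # assumed: x.unsqueeze(0), batch dim of 1
    edge_mask = (1, num_edges)  # assumed: all-ones edge mask, batch dim of 1
    if mask_type == "node":
        inputs.append(node_input)
    elif mask_type == "edge":
        inputs.append(edge_mask)
        forward_args.append("x")  # x is not attributed, so it is passed through
    elif mask_type == "node_and_edge":
        inputs.extend([node_input, edge_mask])
    else:
        raise ValueError(f"invalid mask_type: {mask_type!r}")
    forward_args.append("edge_index")  # edge_index is always a forward argument
    forward_args.extend(args)  # extra model arguments are appended last
    return tuple(inputs), tuple(forward_args)


inputs, forward_args = captum_input_layout((5, 3), 8, "edge", "batch")
print(inputs)        # ((1, 8),)
print(forward_args)  # ('x', 'edge_index', 'batch')
```

Captum treats the first dimension of each input tensor as the batch dimension, which is why the sketch assumes a leading dimension of 1 on every input.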

Parameters:
  • x (torch.Tensor or Dict[NodeType, torch.Tensor]) – The node features. For heterogeneous graphs, a dictionary mapping each node type to its node features.

  • edge_index (torch.Tensor or Dict[EdgeType, torch.Tensor]) – The edge indices. For heterogeneous graphs, a dictionary mapping each edge type to its edge index.

  • mask_type (str) – The type of mask the Captum explainer attributes over: "node" (node features), "edge" (edges), or "node_and_edge" (both).

  • *args – Additional arguments to the forward() of the model being explained. They are appended to additional_forward_args after edge_index.
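For heterogeneous graphs, the same idea extends to one Captum input per node type and/or per edge type. The sketch below is again hypothetical: it assumes PyG iterates over the dictionaries in insertion order, builds a node input per node type and an all-ones mask per edge type, and passes the full feature dictionary through when only edges are masked. hetero_captum_input_layout is an illustrative name, not part of PyG.

```python
from typing import Any, Dict, List, Tuple

Shape = Tuple[int, ...]
EdgeType = Tuple[str, str, str]


def hetero_captum_input_layout(
    x_shapes: Dict[str, Shape],
    num_edges: Dict[EdgeType, int],
    mask_type: str,
    *args: Any,
) -> Tuple[Tuple[Shape, ...], Tuple[Any, ...]]:
    """Hypothetical stand-in for to_captum_input on a heterogeneous graph."""
    # Assumed: one x[key].unsqueeze(0) input per node type, in dict order.
    node_inputs: List[Shape] = [(1, *shape) for shape in x_shapes.values()]
    # Assumed: one all-ones mask of shape [1, num_edges] per edge type.
    edge_inputs: List[Shape] = [(1, n) for n in num_edges.values()]
    forward_args: List[Any] = []
    if mask_type == "node":
        inputs = node_inputs
    elif mask_type == "edge":
        inputs = edge_inputs
        forward_args.append("x_dict")  # feature dict is passed through unchanged
    elif mask_type == "node_and_edge":
        inputs = node_inputs + edge_inputs
    else:
        raise ValueError(f"invalid mask_type: {mask_type!r}")
    forward_args.append("edge_index_dict")
    forward_args.extend(args)  # extra model arguments come last
    return tuple(inputs), tuple(forward_args)


x_shapes = {"paper": (4, 16), "author": (2, 8)}
num_edges = {("author", "writes", "paper"): 5}
print(hetero_captum_input_layout(x_shapes, num_edges, "node_and_edge"))
# (((1, 4, 16), (1, 2, 8), (1, 5)), ('edge_index_dict',))
```

Keeping one input per type lets Captum return a separate attribution for each node or edge type, in the same order as the inputs.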