torch_geometric.nn.models.to_captum_input
to_captum_input(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], mask_type: Union[str, MaskLevelType], *args)
Given x, edge_index and mask_type, converts them into the format used by Captum attribution methods. Returns inputs and additional_forward_args as required by Captum’s attribute functions. See to_captum_model() for example usage.
Parameters:
x (torch.Tensor or Dict[NodeType, torch.Tensor]) – The node features. For heterogeneous graphs, this is a dictionary holding the node features for each node type.
edge_index (torch.Tensor or Dict[EdgeType, torch.Tensor]) – The edge indices. For heterogeneous graphs, this is a dictionary holding the edge index for each edge type.
mask_type (str) – Denotes the type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge".
*args – Additional forward arguments of the model being explained, which will be added to additional_forward_args.