class ExplainerAlgorithm

Bases: Module

An abstract base class for implementing explainer algorithms.

abstract forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) → Union[Explanation, HeteroExplanation]

Computes the explanation.

Parameters:

  • model (torch.nn.Module) – The model to explain.

  • x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.

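  • edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph. In the heterogeneous case, the dictionary is keyed by edge type, i.e., a (source node type, relation, destination node type) tuple.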

  • target (torch.Tensor) – The target of the model.

  • index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)

  • **kwargs (optional) – Additional keyword arguments passed to the model.
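To illustrate what a concrete forward() has to do with these arguments, here is a dependency-free sketch of one simple attribution technique, feature occlusion: zero out each node feature in turn and measure how far the model output at index drifts from target. This is not how any particular PyG algorithm works. To stay runnable without PyTorch, it assumes plain Python lists instead of torch.Tensor for x and edge_index and returns a dict with a "node_mask" entry instead of an Explanation. The names occlusion_forward and toy_model are hypothetical.

```python
from typing import Callable, Dict, List, Optional, Sequence, Union

# Hypothetical stand-ins for torch.Tensor inputs: node features are a list of
# float rows, and edge_index is [sources, targets] as two lists of node ids.
Features = List[List[float]]
EdgeIndex = List[List[int]]
Model = Callable[[Features, EdgeIndex], List[float]]


def occlusion_forward(
    model: Model,
    x: Features,
    edge_index: EdgeIndex,
    *,
    target: Sequence[float],
    index: Optional[Union[int, Sequence[int]]] = None,
) -> Dict[str, Features]:
    """Score every node feature by occluding (zeroing) it, one at a time."""
    # index=None explains every output; an int or a sequence selects a subset.
    if index is None:
        indices = list(range(len(target)))
    elif isinstance(index, int):
        indices = [index]
    else:
        indices = list(index)

    def loss(features: Features) -> float:
        # Squared error between model output and target at the selected indices.
        out = model(features, edge_index)
        return sum((out[i] - target[i]) ** 2 for i in indices)

    base = loss(x)
    node_mask: Features = []
    for n, row in enumerate(x):
        scores = []
        for f in range(len(row)):
            occluded = [r[:] for r in x]  # copy so the caller's x is untouched
            occluded[n][f] = 0.0
            # Positive score: removing this feature moves the output away from target.
            scores.append(loss(occluded) - base)
        node_mask.append(scores)
    return {"node_mask": node_mask}


def toy_model(x: Features, edge_index: EdgeIndex) -> List[float]:
    # Each node outputs its own first feature plus the first feature of its in-neighbors.
    out = [row[0] for row in x]
    for src, dst in zip(*edge_index):
        out[dst] += x[src][0]
    return out


x = [[1.0, 5.0], [2.0, 7.0]]
edge_index = [[0], [1]]  # a single edge: node 0 -> node 1
mask = occlusion_forward(toy_model, x, edge_index, target=[1.0, 3.0], index=1)
print(mask)  # {'node_mask': [[1.0, 0.0], [4.0, 0.0]]}
```

The second feature column scores 0.0 for both nodes because toy_model never reads it, and node 1's own feature matters more for output 1 than its neighbor's. A real implementation would operate on tensors (and on dictionaries of tensors for heterogeneous graphs) and wrap the masks in an Explanation or HeteroExplanation.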

abstract supports() → bool

Checks whether the explainer supports the user-defined settings provided in self.explainer_config and self.model_config.
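A supports() override typically inspects both configurations and rejects combinations the algorithm cannot handle. The sketch below assumes a hypothetical algorithm that only produces per-feature node masks for regression models. ExplainerConfigStub, ModelConfigStub and FeatureOcclusionSketch are hypothetical, string-based stand-ins used to keep the example free of PyG; logging the reason for a rejection is a choice of this sketch.

```python
import logging
from dataclasses import dataclass


@dataclass
class ExplainerConfigStub:
    # Hypothetical stand-in for ExplainerConfig, using plain strings.
    explanation_type: str  # e.g. "model" or "phenomenon"
    node_mask_type: str    # e.g. "object", "common_attributes", "attributes"


@dataclass
class ModelConfigStub:
    # Hypothetical stand-in for ModelConfig.
    mode: str  # e.g. "regression", "binary_classification", "multiclass_classification"


class FeatureOcclusionSketch:
    """Hypothetical algorithm: per-feature node masks on regression models only."""

    def __init__(self, explainer_config: ExplainerConfigStub,
                 model_config: ModelConfigStub) -> None:
        # Assigned directly for brevity; in PyG these come from connect().
        self.explainer_config = explainer_config
        self.model_config = model_config

    def supports(self) -> bool:
        # Reject any node mask granularity other than one score per feature.
        if self.explainer_config.node_mask_type != "attributes":
            logging.error("node_mask_type=%r is not supported",
                          self.explainer_config.node_mask_type)
            return False
        # Reject classification models; only regression outputs are explained.
        if self.model_config.mode != "regression":
            logging.error("mode=%r is not supported", self.model_config.mode)
            return False
        return True


algo = FeatureOcclusionSketch(ExplainerConfigStub("model", "attributes"),
                              ModelConfigStub("regression"))
print(algo.supports())  # True
```

Returning False rather than raising keeps supports() a pure query; the caller decides whether an unsupported configuration is an error.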

property explainer_config: ExplainerConfig

Returns the connected explainer configuration.

property model_config: ModelConfig

Returns the connected model configuration.

connect(explainer_config: ExplainerConfig, model_config: ModelConfig)

Connects an explainer and model configuration to the explainer algorithm.
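The relationship between connect() and the two properties can be sketched without PyG as follows. Everything here is hypothetical: the stub config classes, AlgorithmSketch and RegressionOnly, as well as two design choices of the sketch, namely raising ValueError when a property is read before connect() and calling supports() inside connect() so that unsupported settings fail at setup time.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExplainerConfigStub:  # hypothetical stand-in for ExplainerConfig
    explanation_type: str


@dataclass
class ModelConfigStub:  # hypothetical stand-in for ModelConfig
    mode: str


class AlgorithmSketch(ABC):
    def __init__(self) -> None:
        # Nothing is connected until connect() is called.
        self._explainer_config: Optional[ExplainerConfigStub] = None
        self._model_config: Optional[ModelConfigStub] = None

    @abstractmethod
    def supports(self) -> bool:
        ...

    @property
    def explainer_config(self) -> ExplainerConfigStub:
        # Reading the config before connect() is treated as a usage error.
        if self._explainer_config is None:
            raise ValueError(f"{type(self).__name__} is not connected; call connect() first")
        return self._explainer_config

    @property
    def model_config(self) -> ModelConfigStub:
        if self._model_config is None:
            raise ValueError(f"{type(self).__name__} is not connected; call connect() first")
        return self._model_config

    def connect(self, explainer_config: ExplainerConfigStub,
                model_config: ModelConfigStub) -> None:
        self._explainer_config = explainer_config
        self._model_config = model_config
        # Sketch's choice: validate immediately instead of inside forward().
        if not self.supports():
            raise ValueError(f"{type(self).__name__} does not support these settings")


class RegressionOnly(AlgorithmSketch):
    def supports(self) -> bool:
        return self.model_config.mode == "regression"


algo = RegressionOnly()
algo.connect(ExplainerConfigStub("model"), ModelConfigStub("regression"))
print(algo.model_config.mode)  # regression
```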