torch_geometric.explain.algorithm.PGExplainer

class PGExplainer(epochs: int, lr: float = 0.003, **kwargs)

Bases: ExplainerAlgorithm

The PGExplainer model from the “Parameterized Explainer for Graph Neural Network” paper.

Internally, it uses a neural network to identify the subgraph structures that play a crucial role in the predictions made by a GNN. Importantly, PGExplainer must be trained via train() before it can generate explanations:

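from torch_geometric.explain import Explainer, PGExplainer
from torch_geometric.explain.config import ModelConfig

explainer = Explainer(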
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=ModelConfig(...),
)

# Train against a variety of node-level or graph-level predictions:
for epoch in range(30):
    for index in [...]:  # Indices to train against.
        loss = explainer.algorithm.train(epoch, model, x, edge_index,
                                         target=target, index=index)

# Get the final explanations:
explanation = explainer(x, edge_index, target=target, index=0)

Parameters:
  • epochs (int) – The number of epochs to train.

  • lr (float, optional) – The learning rate to apply. (default: 0.003).

  • **kwargs (optional) – Additional hyper-parameters to override default settings in coeffs.
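
The override behaviour described for **kwargs can be pictured with a plain dictionary. The following is a minimal sketch and not the library's implementation: DEFAULT_COEFFS and merge_coeffs are hypothetical names, and the key names and default values are assumptions that should be checked against PGExplainer.coeffs in the installed version. Rejecting unknown keys is an addition of this sketch to catch typos; the library itself may accept them silently.

```python
import copy

# Hypothetical stand-in for the default hyper-parameter dictionary.
# Key names and values are assumptions made for illustration only.
DEFAULT_COEFFS = {
    "edge_size": 0.05,
    "edge_ent": 1.0,
    "temp": [5.0, 2.0],
    "bias": 0.01,
}


def merge_coeffs(defaults, **overrides):
    """Return a deep copy of `defaults` with `overrides` applied."""
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        # Extra safety check added by this sketch, not documented behaviour.
        raise KeyError(f"Unknown hyper-parameter(s): {unknown}")
    merged = copy.deepcopy(defaults)  # leave the defaults untouched
    merged.update(overrides)
    return merged


# Override a single setting; all other defaults are kept.
coeffs = merge_coeffs(DEFAULT_COEFFS, edge_size=0.1)
```

Under the same assumptions, the equivalent call on the real class would pass the override directly as a keyword argument next to epochs and lr.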

reset_parameters()

Resets all learnable parameters of the module.

train(epoch: int, model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)

Trains the underlying explainer model. Must be called before explanations can be generated.

Parameters:
  • epoch (int) – The current epoch of the training phase.

  • model (torch.nn.Module) – The model to explain.

  • x (torch.Tensor) – The input node features of a homogeneous graph.

  • edge_index (torch.Tensor) – The input edge indices of a homogeneous graph.

  • target (torch.Tensor) – The target of the model.

  • index (int or torch.Tensor, optional) – The index of the model output to explain. Needs to be a single index. (default: None)

  • **kwargs (optional) – Additional keyword arguments passed to model.
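
Because train() accepts only a single index per call, training against several outputs means looping over them in every epoch. The helper below is a sketch of that loop. train_explainer is a hypothetical name, not part of the library, and it assumes that train() returns a value convertible to float, as the loss assignment in the class example suggests. The helper is duck-typed: any object exposing train() with the signature documented above will do, such as explainer.algorithm.

```python
from statistics import mean


def train_explainer(algorithm, model, x, edge_index, target, indices, epochs):
    """Run the training loop and return the mean loss of each epoch.

    `algorithm` must expose
    train(epoch, model, x, edge_index, *, target, index).
    `epochs` should match the value given to the PGExplainer constructor.
    """
    indices = list(indices)
    if not indices:
        raise ValueError("`indices` must contain at least one index")

    history = []
    for epoch in range(epochs):
        # One train() call per index, since only a single index is accepted.
        losses = [
            float(algorithm.train(epoch, model, x, edge_index,
                                  target=target, index=index))
            for index in indices
        ]
        history.append(mean(losses))
    return history
```

Tracking the per-epoch mean makes it easy to see whether the explainer network is still improving before explanations are requested.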

forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) → Explanation

Computes the explanation.

Parameters:
  • model (torch.nn.Module) – The model to explain.

  • x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.

  • edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.

  • target (torch.Tensor) – The target of the model.

  • index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)

  • **kwargs (optional) – Additional keyword arguments passed to model.

supports() → bool

Checks whether the explainer supports the user-defined settings provided in self.explainer_config and self.model_config.