torch_geometric.nn.models.GAE

class GAE(encoder: Module, decoder: Optional[Module] = None)

Bases: Module

The Graph Auto-Encoder model from the “Variational Graph Auto-Encoders” paper based on user-defined encoder and decoder models.

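Parameters:
  • encoder (torch.nn.Module) – The encoder module.

  • decoder (torch.nn.Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)
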
forward(*args, **kwargs) → Tensor

Alias for encode().

reset_parameters()

Resets all learnable parameters of the module.

encode(*args, **kwargs) → Tensor

Runs the encoder and computes node-wise latent variables.

decode(*args, **kwargs) → Tensor

Runs the decoder and computes edge probabilities.

recon_loss(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Optional[Tensor] = None) → Tensor

Given latent variables z, computes the binary cross-entropy loss for the positive edges pos_edge_index and for negative edges, which are taken from neg_edge_index or, if it is not given, obtained through negative sampling.
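Written out, the loss looks roughly as follows. This is a sketch assuming the decoder returns probabilities \(p_{ij}\) (for the default InnerProductDecoder, \(p_{ij} = \sigma(\mathbf{z}_i^{\top} \mathbf{z}_j)\)), with \(\mathcal{E}^{+}\) the columns of pos_edge_index, \(\mathcal{E}^{-}\) the negative edges, and \(\epsilon\) a small constant added inside the logarithms for numerical stability:

\[
\mathcal{L}(\mathbf{Z}) = -\frac{1}{|\mathcal{E}^{+}|} \sum_{(i,j) \in \mathcal{E}^{+}} \log\left(p_{ij} + \epsilon\right) \;-\; \frac{1}{|\mathcal{E}^{-}|} \sum_{(i,j) \in \mathcal{E}^{-}} \log\left(1 - p_{ij} + \epsilon\right)
\]

Because each term is a mean over its own edge set, the positive and negative terms carry equal weight regardless of how many edges each set contains.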

Parameters:
  • z (torch.Tensor) – The latent space \(\mathbf{Z}\).

  • pos_edge_index (torch.Tensor) – The positive edges to train against.

  • neg_edge_index (torch.Tensor, optional) – The negative edges to train against. If not given, uses negative sampling to calculate negative edges. (default: None)

test(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) → Tuple[Tensor, Tensor]

Given latent variables z, positive edges pos_edge_index and negative edges neg_edge_index, computes area under the ROC curve (AUC) and average precision (AP) scores.

Parameters:
  • z (torch.Tensor) – The latent space \(\mathbf{Z}\).

  • pos_edge_index (torch.Tensor) – The positive edges to evaluate against.

  • neg_edge_index (torch.Tensor) – The negative edges to evaluate against.