torch_geometric.nn.models.GAE
- class GAE(encoder: Module, decoder: Optional[Module] = None)
Bases: Module
The Graph Auto-Encoder model from the “Variational Graph Auto-Encoders” paper, built from user-defined encoder and decoder models.
- Parameters:
encoder (torch.nn.Module) – The encoder module.
decoder (torch.nn.Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)
- encode(*args, **kwargs) → Tensor
Runs the encoder and computes node-wise latent variables.
- Return type: Tensor
- decode(*args, **kwargs) → Tensor
Runs the decoder and computes edge probabilities.
- Return type: Tensor
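To illustrate what an edge probability means under the default inner-product decoder, the pure-Python sketch below scores an edge (i, j) as the sigmoid of the inner product of the two endpoint embeddings. The helper name edge_probability and the toy embeddings are hypothetical. This is not the library's code and does not apply to a custom decoder.

```python
import math


def edge_probability(z, i, j):
    """Sigmoid of the inner product between the embeddings of nodes i and j."""
    dot = sum(a * b for a, b in zip(z[i], z[j]))
    return 1.0 / (1.0 + math.exp(-dot))


# Toy latent space: nodes 0 and 1 are aligned, node 2 points the opposite way.
z = [[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]]

p_same = edge_probability(z, 0, 1)      # sigmoid(1), above 0.5
p_opposite = edge_probability(z, 0, 2)  # sigmoid(-1), below 0.5
```

Nodes whose embeddings point in similar directions receive a probability above 0.5, while opposing embeddings fall below it.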
- recon_loss(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Optional[Tensor] = None) → Tensor
Given latent variables z, computes the binary cross entropy loss for positive edges pos_edge_index and negative sampled edges.
- Parameters:
z (torch.Tensor) – The latent space \(\mathbf{Z}\).
pos_edge_index (torch.Tensor) – The positive edges to train against.
neg_edge_index (torch.Tensor, optional) – The negative edges to train against. If not given, uses negative sampling to calculate negative edges. (default: None)
- Return type: Tensor
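The binary cross entropy described above can be sketched in plain Python, starting from decoded edge probabilities. The function name recon_loss_sketch and the small EPS constant (added to keep the logarithm finite) are assumptions made for this illustration, and the probabilities are made up. It shows the shape of the loss, not the library's exact implementation.

```python
import math

EPS = 1e-15  # assumed stabilizer so that log(0) is never evaluated


def recon_loss_sketch(pos_probs, neg_probs):
    """Mean -log(p) over positive edges plus mean -log(1 - p) over negative edges."""
    pos_loss = -sum(math.log(p + EPS) for p in pos_probs) / len(pos_probs)
    neg_loss = -sum(math.log(1.0 - p + EPS) for p in neg_probs) / len(neg_probs)
    return pos_loss + neg_loss


# A decoder that ranks true edges high and sampled non-edges low gets a small loss.
good = recon_loss_sketch(pos_probs=[0.9, 0.8], neg_probs=[0.1, 0.2])
# Swapping the two groups simulates a decoder that gets every edge wrong.
bad = recon_loss_sketch(pos_probs=[0.1, 0.2], neg_probs=[0.9, 0.8])
```

The loss drops as positive edges approach probability 1 and negative edges approach probability 0.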
- test(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) → Tuple[Tensor, Tensor]
Given latent variables z, positive edges pos_edge_index and negative edges neg_edge_index, computes area under the ROC curve (AUC) and average precision (AP) scores.
- Parameters:
z (torch.Tensor) – The latent space \(\mathbf{Z}\).
pos_edge_index (torch.Tensor) – The positive edges to evaluate against.
neg_edge_index (torch.Tensor) – The negative edges to evaluate against.
- Return type: Tuple[Tensor, Tensor]
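For intuition about the two reported metrics, the pure-Python sketch below computes AUC and AP from decoded scores of positive and negative edges. The helper names and toy scores are hypothetical, and the AP helper assumes no tied scores. It illustrates the definitions only and is not the routine the library calls.

```python
def auc_score(pos_scores, neg_scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count half."""
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


def average_precision(pos_scores, neg_scores):
    """Mean of the precision measured at the rank of each positive edge."""
    ranked = sorted(
        [(s, 1) for s in pos_scores] + [(s, 0) for s in neg_scores],
        key=lambda pair: pair[0],
        reverse=True,
    )
    hits, total = 0, 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            total += hits / rank
    return total / len(pos_scores)


pos = [0.9, 0.6]  # decoded scores for held-out positive edges
neg = [0.7, 0.2]  # decoded scores for sampled negative edges

auc = auc_score(pos, neg)         # 3 of 4 pairs are ordered correctly
ap = average_precision(pos, neg)  # precision 1/1 and 2/3 at the two positives
```

Both scores lie in [0, 1], and higher values mean positive edges are ranked above negative ones more consistently.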