torch_geometric.nn.models.DeepGraphInfomax
- class DeepGraphInfomax(hidden_channels: int, encoder: Module, summary: Callable, corruption: Callable)
Bases: torch.nn.Module
The Deep Graph Infomax model from the “Deep Graph Infomax” paper, built from a user-defined encoder \(\mathcal{E}\), a user-defined summary (readout) model \(\mathcal{R}\), and a corruption function \(\mathcal{C}\).
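For orientation, the training objective proposed in the paper can be written as a binary cross-entropy that separates patches of the real graph from patches of its corruption. In the notation below, which is borrowed from the paper and not defined elsewhere on this page, \(\mathbf{X}\) and \(\mathbf{A}\) are the node features and adjacency, \(\vec{h}_i\) is the \(i\)-th of \(N\) rows of \(\mathcal{E}(\mathbf{X}, \mathbf{A})\), \(\tilde{\vec{h}}_j\) is the \(j\)-th of \(M\) rows of \(\mathcal{E}(\tilde{\mathbf{X}}, \tilde{\mathbf{A}})\), and \(\mathcal{D}\) is a bilinear discriminator with weight \(\mathbf{W}\). The exact normalization used by a given implementation may differ.

\[
\vec{s} = \mathcal{R}\big(\mathcal{E}(\mathbf{X}, \mathbf{A})\big), \qquad
(\tilde{\mathbf{X}}, \tilde{\mathbf{A}}) = \mathcal{C}(\mathbf{X}, \mathbf{A}), \qquad
\mathcal{D}(\vec{h}, \vec{s}) = \sigma\big(\vec{h}^{\top} \mathbf{W} \vec{s}\big)
\]
\[
\mathcal{L} = -\frac{1}{N + M} \left( \sum_{i=1}^{N} \log \mathcal{D}(\vec{h}_i, \vec{s}) + \sum_{j=1}^{M} \log\big(1 - \mathcal{D}(\tilde{\vec{h}}_j, \vec{s})\big) \right)
\]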
- Parameters:
hidden_channels (int) – The latent space dimensionality.
encoder (torch.nn.Module) – The encoder module \(\mathcal{E}\).
summary (callable) – The readout function \(\mathcal{R}\).
corruption (callable) – The corruption function \(\mathcal{C}\).
- forward(*args, **kwargs) → Tuple[Tensor, Tensor, Tensor]
Returns the latent space of the input arguments, the latent space of their corruptions, and the summary representation of the input.
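To make the three returned tensors concrete, here is a plain-PyTorch sketch of how they relate to \(\mathcal{E}\), \(\mathcal{C}\), and \(\mathcal{R}\). The helper dgi_forward is hypothetical, handles positional arguments only, and is not the library's implementation; the dense identity adjacency and the linear encoder are toy assumptions chosen so the output is easy to check.

```python
import torch


def dgi_forward(encoder, summary, corruption, *args):
    """Hypothetical sketch of the three outputs of forward()."""
    pos_z = encoder(*args)                  # patches of the real input
    cor = corruption(*args)
    cor = cor if isinstance(cor, tuple) else (cor,)
    neg_z = encoder(*cor)                   # patches of the corrupted input
    s = summary(pos_z, *args)               # summary of the real patches
    return pos_z, neg_z, s


torch.manual_seed(0)
num_nodes, in_dim, hid_dim = 5, 3, 4
adj = torch.eye(num_nodes)                  # toy dense adjacency: self-loops only
lin = torch.nn.Linear(in_dim, hid_dim)


def encoder(x, adj):
    return torch.relu(adj @ lin(x))


def corruption(x, adj):
    return x[torch.randperm(x.size(0))], adj


def summary(z, *args):
    return torch.sigmoid(z.mean(dim=0))


x = torch.randn(num_nodes, in_dim)
pos_z, neg_z, s = dgi_forward(encoder, summary, corruption, x, adj)
print(pos_z.shape, neg_z.shape, s.shape)
# prints: torch.Size([5, 4]) torch.Size([5, 4]) torch.Size([4])
```

With this toy encoder each node is encoded independently, so neg_z is just a row permutation of pos_z; with a real message-passing encoder, shuffling features also changes each node's neighborhood context, which is what makes the negatives informative.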
- discriminate(z: Tensor, summary: Tensor, sigmoid: bool = True) → Tensor
Given the patch-summary pair z and summary, computes the probability scores assigned to this patch-summary pair.
- Parameters:
z (torch.Tensor) – The latent space.
summary (torch.Tensor) – The summary vector.
sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)