class DeepGraphInfomax(hidden_channels: int, encoder: Module, summary: Callable, corruption: Callable)

Bases: Module

The Deep Graph Infomax model from the “Deep Graph Infomax” paper, built from a user-defined encoder \(\mathcal{E}\), a user-defined summary (readout) model \(\mathcal{R}\), and a corruption function \(\mathcal{C}\).

  • hidden_channels (int) – The latent space dimensionality.

  • encoder (torch.nn.Module) – The encoder module \(\mathcal{E}\).

  • summary (callable) – The readout function \(\mathcal{R}\).

  • corruption (callable) – The corruption function \(\mathcal{C}\).

forward(*args, **kwargs) → Tuple[Tensor, Tensor, Tensor]

Returns the latent representations of the input arguments, the latent representations of their corruptions, and the summary representation.
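The way the three components combine can be illustrated without the library. The following is a minimal NumPy sketch, not the module's implementation: `W_enc`, the one-layer `encoder`, the row-shuffling `corruption`, and the sigmoid-of-mean `summary` are hypothetical stand-ins, and the return order (real embeddings, corrupted embeddings, summary) is an assumption based on the description above. For simplicity the stand-in `summary` takes only the embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for the three user-supplied components.
W_enc = rng.normal(size=(5, 8))  # 5 input features -> 8 hidden channels

def encoder(x):
    # A single linear layer with ReLU; a real encoder would usually be a GNN.
    return np.maximum(x @ W_enc, 0.0)

def corruption(x):
    # Shuffle the feature rows so features no longer belong to their nodes.
    return x[rng.permutation(x.shape[0])]

def summary(z):
    # Readout: logistic sigmoid of the mean node embedding.
    return 1.0 / (1.0 + np.exp(-z.mean(axis=0)))

def forward(x):
    pos_z = encoder(x)              # embeddings of the real input
    neg_z = encoder(corruption(x))  # embeddings of the corrupted input
    s = summary(pos_z)              # summary of the real embeddings
    return pos_z, neg_z, s

x = rng.normal(size=(10, 5))  # 10 nodes with 5 features each
pos_z, neg_z, s = forward(x)
```

Because this toy encoder ignores graph structure, the corrupted embeddings are just a row permutation of the real ones; with a message-passing encoder the two would differ more substantially.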


reset_parameters()

Resets all learnable parameters of the module.

discriminate(z: Tensor, summary: Tensor, sigmoid: bool = True) → Tensor

Given the patch-summary pair z and summary, computes the probability scores assigned to this patch-summary pair.

  • z (torch.Tensor) – The latent space.

  • summary (torch.Tensor) – The summary vector.

  • sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)
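As an illustration of what such a score looks like, here is a small NumPy sketch assuming the bilinear discriminator \(\sigma(z_i^\top W s)\) described in the Deep Graph Infomax paper. The function name mirrors the method above, but the explicit `weight` argument and the random inputs are assumptions made for this example only.

```python
import numpy as np

def discriminate(z, summary, weight, sigmoid=True):
    # Bilinear score z_i^T W s, computed for every row z_i of z.
    value = z @ (weight @ summary)
    if sigmoid:
        # Map raw scores to probabilities in (0, 1).
        return 1.0 / (1.0 + np.exp(-value))
    return value

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 3))   # 4 patches, 3 hidden channels
s = rng.normal(size=3)        # summary vector
W = rng.normal(size=(3, 3))   # hypothetical discriminator weight

probs = discriminate(z, s, W)                  # one probability per patch
logits = discriminate(z, s, W, sigmoid=False)  # raw scores, no sigmoid
```

Passing `sigmoid=False` is useful when the raw scores feed a loss that applies the sigmoid itself.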

loss(pos_z: Tensor, neg_z: Tensor, summary: Tensor) → Tensor

Computes the mutual information maximization objective.
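In the paper, this objective takes the form of a binary cross-entropy over discriminator outputs: real patch-summary pairs should score near 1 and corrupted pairs near 0. The sketch below assumes that form and operates directly on probabilities; `dgi_loss` and the `EPS` value are hypothetical names and choices for this example, not the library's.

```python
import numpy as np

EPS = 1e-15  # guards against log(0); the value is an assumption for this sketch

def dgi_loss(pos_prob, neg_prob):
    # Real pairs are pushed towards probability 1.
    pos_loss = -np.log(pos_prob + EPS).mean()
    # Corrupted pairs are pushed towards probability 0.
    neg_loss = -np.log(1.0 - neg_prob + EPS).mean()
    return pos_loss + neg_loss

# A discriminator that separates real from corrupted pairs well...
good = dgi_loss(np.array([0.9, 0.8]), np.array([0.1, 0.2]))
# ...versus one that outputs 0.5 for everything.
chance = dgi_loss(np.array([0.5, 0.5]), np.array([0.5, 0.5]))
```

A discriminator at chance level yields roughly \(2 \ln 2\), so values clearly below that suggest the embeddings carry information about the summary.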

test(train_z: Tensor, train_y: Tensor, test_z: Tensor, test_y: Tensor, solver: str = 'lbfgs', multi_class: str = 'auto', *args, **kwargs) → float

Evaluates latent space quality via a logistic regression downstream task.
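The evaluation idea can be reproduced by hand: fit a logistic regression on frozen training embeddings and report accuracy on held-out embeddings. The sketch below assumes scikit-learn's `LogisticRegression` as the classifier and uses synthetic, well-separated embeddings; `make_split` and all data here are hypothetical and exist only to make the example runnable.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

def make_split(n):
    # Synthetic "embeddings": two Gaussian clusters, one per class.
    y = rng.integers(0, 2, size=n)
    z = rng.normal(size=(n, 4)) + 3.0 * (2 * y[:, None] - 1)
    return z, y

train_z, train_y = make_split(200)
test_z, test_y = make_split(100)

# Fit on the training embeddings, then score on the held-out embeddings.
clf = LogisticRegression(solver="lbfgs").fit(train_z, train_y)
accuracy = clf.score(test_z, test_y)  # mean accuracy as a float in [0, 1]
```

With real model output, the embeddings would first be detached from the computation graph and converted to NumPy arrays before being passed to the classifier.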