import copy
from typing import Callable, Tuple
import torch
from torch import Tensor
from torch.nn import Module, Parameter
from torch_geometric.nn.inits import reset, uniform
EPS = 1e-15


class DeepGraphInfomax(torch.nn.Module):
    r"""The Deep Graph Infomax model from the
    `"Deep Graph Infomax" <https://arxiv.org/abs/1809.10341>`_ paper,
    based on a user-defined encoder :math:`\mathcal{E}`, summary (readout)
    model :math:`\mathcal{R}`, and corruption function :math:`\mathcal{C}`.

    Args:
hidden_channels (int): The latent space dimensionality.
encoder (torch.nn.Module): The encoder module :math:`\mathcal{E}`.
summary (callable): The readout function :math:`\mathcal{R}`.
corruption (callable): The corruption function :math:`\mathcal{C}`.
"""
def __init__(
self,
hidden_channels: int,
encoder: Module,
summary: Callable,
corruption: Callable,
):
super().__init__()
self.hidden_channels = hidden_channels
self.encoder = encoder
self.summary = summary
self.corruption = corruption
self.weight = Parameter(torch.empty(hidden_channels, hidden_channels))
self.reset_parameters()

    def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
reset(self.encoder)
reset(self.summary)
uniform(self.hidden_channels, self.weight)

    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
"""Returns the latent space for the input arguments, their
corruptions and their summary representation.
"""
        pos_z = self.encoder(*args, **kwargs)

        # The corruption function may return a single tensor or a tuple that
        # mirrors the encoder's positional and keyword arguments:
        cor = self.corruption(*args, **kwargs)
        cor = cor if isinstance(cor, tuple) else (cor, )
        cor_args = cor[:len(args)]
        cor_kwargs = copy.copy(kwargs)
        for key, value in zip(kwargs.keys(), cor[len(args):]):
            cor_kwargs[key] = value

        neg_z = self.encoder(*cor_args, **cor_kwargs)
        summary = self.summary(pos_z, *args, **kwargs)

        # For a node-level encoder, `pos_z` and `neg_z` hold the clean and
        # corrupted node embeddings, and `summary` is the readout over them.
        return pos_z, neg_z, summary

    def discriminate(self, z: Tensor, summary: Tensor,
                     sigmoid: bool = True) -> Tensor:
r"""Given the patch-summary pair :obj:`z` and :obj:`summary`, computes
the probability scores assigned to this patch-summary pair.
Args:
z (torch.Tensor): The latent space.
summary (torch.Tensor): The summary vector.
sigmoid (bool, optional): If set to :obj:`False`, does not apply
the logistic sigmoid function to the output.
(default: :obj:`True`)
"""
        # Bilinear discriminator: D(z, s) = sigmoid(z^T * W * s).
        summary = summary.t() if summary.dim() > 1 else summary
        value = torch.matmul(z, torch.matmul(self.weight, summary))
        return torch.sigmoid(value) if sigmoid else value
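
    # For example (hedged: `pos_z` and `summary` as produced by `forward`):
    #
    #     scores = model.discriminate(pos_z, summary)  # probabilities in (0, 1)
    #     logits = model.discriminate(pos_z, summary, sigmoid=False)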

    def loss(self, pos_z: Tensor, neg_z: Tensor, summary: Tensor) -> Tensor:
r"""Computes the mutual information maximization objective."""
        # Binary cross-entropy over positive (clean) and negative (corrupted)
        # patch-summary pairs:
        pos_loss = -torch.log(
            self.discriminate(pos_z, summary, sigmoid=True) + EPS).mean()
        neg_loss = -torch.log(
            1 - self.discriminate(neg_z, summary, sigmoid=True) + EPS).mean()
return pos_loss + neg_loss
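
    # A training-step sketch (hedged: `model`, `data`, and `optimizer` are
    # assumed to exist; they are not defined in this module):
    #
    #     optimizer.zero_grad()
    #     pos_z, neg_z, summary = model(data.x, data.edge_index)
    #     loss = model.loss(pos_z, neg_z, summary)
    #     loss.backward()
    #     optimizer.step()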

    def test(
self,
train_z: Tensor,
train_y: Tensor,
test_z: Tensor,
test_y: Tensor,
solver: str = 'lbfgs',
*args,
**kwargs,
) -> float:
r"""Evaluates latent space quality via a logistic regression downstream
task.
"""
        from sklearn.linear_model import LogisticRegression

        # Fit a logistic regression probe on the frozen train embeddings and
        # report test accuracy; extra arguments are forwarded to scikit-learn:
        clf = LogisticRegression(*args, solver=solver, **kwargs).fit(
            train_z.detach().cpu().numpy(),
            train_y.detach().cpu().numpy(),
        )
        return clf.score(
            test_z.detach().cpu().numpy(),
            test_y.detach().cpu().numpy(),
        )

    def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.hidden_channels})'
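

# An evaluation sketch (hedged: assumes a Planetoid-style `data` object with
# `train_mask`/`test_mask`; `max_iter` is forwarded to scikit-learn's
# LogisticRegression):
#
#     model.eval()
#     with torch.no_grad():
#         z, _, _ = model(data.x, data.edge_index)
#     acc = model.test(z[data.train_mask], data.y[data.train_mask],
#                      z[data.test_mask], data.y[data.test_mask],
#                      max_iter=150)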