from typing import Optional
import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.nn.inits import reset
class DenseGINConv(torch.nn.Module):
    r"""Dense-adjacency counterpart of
    :class:`torch_geometric.nn.conv.GINConv`.

    The layer computes
    :math:`\textrm{nn}\left((1 + \epsilon) \cdot \mathbf{X} +
    \mathbf{A}\mathbf{X}\right)`, where the :math:`(1 + \epsilon) \cdot
    \mathbf{X}` term is only added when :obj:`add_loop` is :obj:`True`.

    Args:
        nn (torch.nn.Module): Module applied to the aggregated node
            features.
        eps (float, optional): Initial value of :math:`\epsilon`.
            (default: :obj:`0.0`)
        train_eps (bool, optional): If set to :obj:`True`,
            :math:`\epsilon` is stored as a trainable parameter, otherwise
            as a (non-trainable) buffer. (default: :obj:`False`)
    """
    def __init__(
        self,
        nn: Module,
        eps: float = 0.0,
        train_eps: bool = False,
    ):
        super().__init__()

        self.nn = nn
        self.initial_eps = eps
        # The storage is allocated uninitialized here; the actual value is
        # written by `reset_parameters()` below.
        if train_eps:
            self.eps = torch.nn.Parameter(torch.empty(1))
        else:
            self.register_buffer('eps', torch.empty(1))
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None,
                add_loop: bool = True) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
                The adjacency tensor is broadcastable in the batch dimension,
                resulting in a shared adjacency matrix for the complete batch.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)

        Returns:
            torch.Tensor: The output of :obj:`nn` on the aggregated features,
            multiplied by :obj:`mask` (if given) so that rows of invalid
            nodes are zeroed out.
        """
        # Promote un-batched inputs to a batch of size one.
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        N = adj.size(1)

        # `matmul` broadcasts over the batch dimension, so a single shared
        # adjacency matrix is applied to every graph in `x`.
        out = torch.matmul(adj, x)
        if add_loop:
            out = (1 + self.eps) * x + out

        out = self.nn(out)

        if mask is not None:
            # Infer the mask's batch size from the mask itself rather than
            # from `adj`: `adj` may be shared across the batch (batch size
            # one) while `mask` still holds one row per graph.
            out = out * mask.view(-1, N, 1).to(x.dtype)

        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'