Source code for torch_geometric.nn.dense.dense_gin_conv

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module

from torch_geometric.nn.inits import reset


class DenseGINConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GINConv`."""
    def __init__(
        self,
        nn: Module,
        eps: float = 0.0,
        train_eps: bool = False,
    ):
        super().__init__()
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            # Epsilon is learned jointly with the model parameters:
            self.eps = torch.nn.Parameter(torch.empty(1))
        else:
            # Epsilon stays fixed at its initial value:
            self.register_buffer('eps', torch.empty(1))
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None,
                add_loop: bool = True) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
                The adjacency tensor is broadcastable in the batch dimension,
                resulting in a shared adjacency matrix for the complete batch.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        # Promote unbatched inputs to a batch of size one:
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        # Aggregate neighbor features via a dense matrix product:
        out = torch.matmul(adj, x)
        if add_loop:
            # GIN update: weight the central node's features by (1 + eps):
            out = (1 + self.eps) * x + out

        out = self.nn(out)

        if mask is not None:
            # Zero out the features of invalid (padded) nodes:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'
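
A minimal usage sketch (not part of the library source): it wires the layer to a small MLP and runs a batched dense forward pass. The tensor shapes, the `Sequential` MLP, and the random inputs are illustrative assumptions; only `DenseGINConv` and its `forward(x, adj, mask)` signature come from the code above.

# Hypothetical usage example; shapes and the MLP are illustrative choices.
import torch
from torch.nn import Linear, ReLU, Sequential

from torch_geometric.nn import DenseGINConv

B, N, F_in, F_out = 2, 6, 16, 32  # batch size, max nodes, feature dims

# The MLP applied after neighborhood aggregation; any Module works here.
mlp = Sequential(Linear(F_in, F_out), ReLU(), Linear(F_out, F_out))
conv = DenseGINConv(mlp, eps=0.0, train_eps=True)

x = torch.randn(B, N, F_in)                    # dense node features
adj = torch.randint(0, 2, (B, N, N)).float()   # dense adjacency matrices
mask = torch.ones(B, N, dtype=torch.bool)      # all N nodes are valid

out = conv(x, adj, mask)                       # shape: [B, N, F_out]

Since `adj` is broadcastable in the batch dimension, a single shared adjacency matrix of shape `[N, N]` may be passed instead of one per graph.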