# Source code for torch_geometric.nn.models.pmlp

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn import SimpleConv
from torch_geometric.nn.dense.linear import Linear

class PMLP(torch.nn.Module):
    r"""The P(ropagational)MLP model from the `"Graph Neural Networks are
    Inherently Good Generalizers: Insights by Bridging GNNs and MLPs"
    <https://arxiv.org/abs/2212.09034>`_ paper.
    :class:`PMLP` is identical to a standard MLP during training, but then
    adopts a GNN architecture during testing.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        num_layers (int): The number of layers.
        dropout (float, optional): Dropout probability of each hidden
            embedding. (default: :obj:`0.`)
        norm (bool, optional): If set to :obj:`False`, will not apply batch
            normalization. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the module will not
            learn additive biases. (default: :obj:`True`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        dropout: float = 0.,
        norm: bool = True,
        bias: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.bias = bias

        # One input layer, `num_layers - 2` hidden layers, one output layer.
        # NOTE(review): for `num_layers < 2` two layers are still created
        # here while `forward` only applies the first `num_layers` of them;
        # confirm whether such values should be rejected.
        self.lins = torch.nn.ModuleList()
        self.lins.append(Linear(in_channels, hidden_channels, self.bias))
        for _ in range(self.num_layers - 2):
            lin = Linear(hidden_channels, hidden_channels, self.bias)
            self.lins.append(lin)
        self.lins.append(Linear(hidden_channels, out_channels, self.bias))

        # A single parameter-free batch norm (no affine weights, no running
        # statistics) is shared by all hidden layers.
        self.norm = None
        if norm:
            self.norm = torch.nn.BatchNorm1d(
                hidden_channels,
                affine=False,
                track_running_stats=False,
            )

        # Mean aggregation over neighbors, with the root node combined via
        # the 'self_loop' option; only applied outside of training.
        self.conv = SimpleConv(aggr='mean', combine_root='self_loop')

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for lin in self.lins:
            torch.nn.init.xavier_uniform_(lin.weight, gain=1.414)
            if self.bias:
                torch.nn.init.zeros_(lin.bias)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: Optional[Tensor] = None,
    ) -> torch.Tensor:
        r"""Runs the forward pass of the module.

        In training mode, the model behaves like a plain MLP and
        :obj:`edge_index` is ignored. Outside of training mode, the output of
        every linear transformation is additionally propagated through
        :obj:`self.conv` before the bias is added.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor, optional): The edge indices. Required
                whenever the module is not in training mode.
                (default: :obj:`None`)

        Raises:
            ValueError: If the module is not in training mode and
                :obj:`edge_index` is :obj:`None`.
        """
        if not self.training and edge_index is None:
            raise ValueError(f"'edge_index' needs to be present during "
                             f"inference in '{self.__class__.__name__}'")

        for i in range(self.num_layers):
            # Apply the weight and the bias separately so that message
            # passing can be inserted in between at inference time.
            x = x @ self.lins[i].weight.t()
            if not self.training:
                x = self.conv(x, edge_index)
            if self.bias:
                x = x + self.lins[i].bias
            # Normalization, non-linearity and dropout are skipped after the
            # last layer.
            if i != self.num_layers - 1:
                if self.norm is not None:
                    x = self.norm(x)
                x = x.relu()
                x = F.dropout(x, p=self.dropout, training=self.training)

        return x

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_layers={self.num_layers})')