from typing import List
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Identity
from torch_geometric.nn.dense.linear import Linear


class MLP(torch.nn.Module):
r"""A multi-layer perception (MLP) model.
Args:
channel_list (List[int]): List of input, intermediate and output
channels. :obj:`len(channel_list) - 1` denotes the number of layers
of the MLP.
dropout (float, optional): Dropout probability of each hidden
embedding. (default: :obj:`0.`)
batch_norm (bool, optional): If set to :obj:`False`, will not make use
of batch normalization. (default: :obj:`True`)
relu_first (bool, optional): If set to :obj:`True`, ReLU activation is
applied before batch normalization. (default: :obj:`False`)
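
    A minimal usage sketch (the layer sizes and tensor shapes below are
    illustrative assumptions, not taken from the original documentation)::

        mlp = MLP([16, 32, 8], dropout=0.5)
        x = torch.randn(4, 16)   # batch of 4 samples with 16 features each
        out = mlp(x)             # output has shape [4, 8]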
"""

    def __init__(self, channel_list: List[int], dropout: float = 0.,
                 batch_norm: bool = True, relu_first: bool = False):
        super().__init__()
        assert len(channel_list) >= 2

        self.channel_list = channel_list
        self.dropout = dropout
        self.relu_first = relu_first
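
        # One linear layer for every consecutive pair of channel sizes: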
        self.lins = torch.nn.ModuleList()
        for dims in zip(channel_list[:-1], channel_list[1:]):
            self.lins.append(Linear(*dims))
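
        # One normalization layer per hidden embedding (the output layer is
        # left unnormalized); Identity stands in when batch_norm is disabled.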
        self.norms = torch.nn.ModuleList()
        for dim in channel_list[1:-1]:
            self.norms.append(BatchNorm1d(dim) if batch_norm else Identity())

        self.reset_parameters()

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()
        for norm in self.norms:
            if hasattr(norm, 'reset_parameters'):
                norm.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        """"""
        x = self.lins[0](x)
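        # For each hidden embedding: apply activation and normalization
        # (order controlled by `relu_first`), then dropout, then the next
        # linear layer. The final linear output is returned as-is.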
        for lin, norm in zip(self.lins[1:], self.norms):
            if self.relu_first:
                x = x.relu_()
            x = norm(x)
            if not self.relu_first:
                x = x.relu_()
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = lin(x)
        return x

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})'