# Source code for torch_geometric.nn.models.mlp

from typing import List

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Identity

from torch_geometric.nn.dense.linear import Linear


class MLP(torch.nn.Module):
    r"""A multi-layer perceptron (MLP) model.

    The input is passed through :obj:`len(channel_list) - 1` linear layers.
    Between two consecutive linear layers, the hidden embedding goes through
    batch normalization (or an identity map when :obj:`batch_norm=False`) and
    a ReLU activation, in the order selected by :obj:`relu_first`, followed
    by dropout. The output of the final linear layer is returned as-is,
    without normalization, activation or dropout.

    Args:
        channel_list (List[int]): List of input, intermediate and output
            channels. :obj:`len(channel_list) - 1` denotes the number of
            layers of the MLP.
        dropout (float, optional): Dropout probability of each hidden
            embedding. (default: :obj:`0.`)
        batch_norm (bool, optional): If set to :obj:`False`, will not make
            use of batch normalization. (default: :obj:`True`)
        relu_first (bool, optional): If set to :obj:`True`, ReLU activation
            is applied before batch normalization. (default: :obj:`False`)
    """
    def __init__(self, channel_list: List[int], dropout: float = 0.,
                 batch_norm: bool = True, relu_first: bool = False):
        super().__init__()
        assert len(channel_list) >= 2
        self.channel_list = channel_list
        self.dropout = dropout
        self.relu_first = relu_first

        # One linear layer per consecutive (in, out) channel pair.
        self.lins = torch.nn.ModuleList()
        for in_channels, out_channels in zip(channel_list[:-1],
                                             channel_list[1:]):
            self.lins.append(Linear(in_channels, out_channels))

        # One normalization per hidden embedding, i.e. every channel count
        # except the input and the output. Iterate the ints directly:
        # wrapping the slice in ``zip(...)`` would yield 1-tuples like
        # ``(64,)`` and make ``BatchNorm1d.num_features`` a tuple.
        self.norms = torch.nn.ModuleList()
        for hidden_channels in channel_list[1:-1]:
            self.norms.append(
                BatchNorm1d(hidden_channels) if batch_norm else Identity())

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes the parameters of all linear and norm layers."""
        for lin in self.lins:
            lin.reset_parameters()
        for norm in self.norms:
            # ``Identity`` (used when ``batch_norm=False``) has no
            # ``reset_parameters``, so only call it where it exists.
            if hasattr(norm, 'reset_parameters'):
                norm.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        r"""Applies the MLP to :obj:`x` and returns the final-layer output.

        NOTE(review): :obj:`x` is presumably 2-D with shape
        ``[N, channel_list[0]]``, so that the linear layers (last dim) and
        ``BatchNorm1d`` (dim 1) act on the same channel axis. Other ranks
        are not verified here.
        """
        x = self.lins[0](x)
        for lin, norm in zip(self.lins[1:], self.norms):
            if self.relu_first:
                x = x.relu_()
            x = norm(x)
            if not self.relu_first:
                x = x.relu_()
            x = F.dropout(x, p=self.dropout, training=self.training)
            # Invoke the module rather than ``lin.forward`` so registered
            # forward/pre-forward hooks run, as they do for ``lins[0]``.
            x = lin(x)
        return x

    def __repr__(self) -> str:
        # e.g. ``MLP(16, 32, 7)`` for ``channel_list=[16, 32, 7]``.
        return (f'{self.__class__.__name__}('
                f'{str(self.channel_list)[1:-1]})')