import warnings
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Identity
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.resolver import (
activation_resolver,
normalization_resolver,
)
from torch_geometric.typing import NoneType
[docs]class MLP(torch.nn.Module):
r"""A Multi-Layer Perception (MLP) model.
There exists two ways to instantiate an :class:`MLP`:
1. By specifying explicit channel sizes, *e.g.*,
.. code-block:: python
mlp = MLP([16, 32, 64, 128])
creates a three-layer MLP with **differently** sized hidden layers.
1. By specifying fixed hidden channel sizes over a number of layers,
*e.g.*,
.. code-block:: python
mlp = MLP(in_channels=16, hidden_channels=32,
out_channels=128, num_layers=3)
creates a three-layer MLP with **equally** sized hidden layers.
Args:
channel_list (List[int] or int, optional): List of input, intermediate
and output channels such that :obj:`len(channel_list) - 1` denotes
the number of layers of the MLP (default: :obj:`None`)
in_channels (int, optional): Size of each input sample.
Will override :attr:`channel_list`. (default: :obj:`None`)
hidden_channels (int, optional): Size of each hidden sample.
Will override :attr:`channel_list`. (default: :obj:`None`)
out_channels (int, optional): Size of each output sample.
Will override :attr:`channel_list`. (default: :obj:`None`)
num_layers (int, optional): The number of layers.
Will override :attr:`channel_list`. (default: :obj:`None`)
dropout (float or List[float], optional): Dropout probability of each
hidden embedding. If a list is provided, sets the dropout value per
layer. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
norm (str or Callable, optional): The normalization function to
use. (default: :obj:`"batch_norm"`)
norm_kwargs (Dict[str, Any], optional): Arguments passed to the
respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
plain_last (bool, optional): If set to :obj:`False`, will apply
non-linearity, batch normalization and dropout to the last layer as
well. (default: :obj:`True`)
bias (bool or List[bool], optional): If set to :obj:`False`, the module
will not learn additive biases. If a list is provided, sets the
bias per layer. (default: :obj:`True`)
**kwargs (optional): Additional deprecated arguments of the MLP layer.
"""
def __init__(
self,
channel_list: Optional[Union[List[int], int]] = None,
*,
in_channels: Optional[int] = None,
hidden_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: Optional[int] = None,
dropout: Union[float, List[float]] = 0.,
act: Union[str, Callable, None] = "relu",
act_first: bool = False,
act_kwargs: Optional[Dict[str, Any]] = None,
norm: Union[str, Callable, None] = "batch_norm",
norm_kwargs: Optional[Dict[str, Any]] = None,
plain_last: bool = True,
bias: Union[bool, List[bool]] = True,
**kwargs,
):
super().__init__()
# Backward compatibility:
act_first = act_first or kwargs.get("relu_first", False)
batch_norm = kwargs.get("batch_norm", None)
if batch_norm is not None and isinstance(batch_norm, bool):
warnings.warn("Argument `batch_norm` is deprecated, "
"please use `norm` to specify normalization layer.")
norm = 'batch_norm' if batch_norm else None
batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None)
norm_kwargs = batch_norm_kwargs or {}
if isinstance(channel_list, int):
in_channels = channel_list
if in_channels is not None:
assert num_layers >= 1
channel_list = [hidden_channels] * (num_layers - 1)
channel_list = [in_channels] + channel_list + [out_channels]
assert isinstance(channel_list, (tuple, list))
assert len(channel_list) >= 2
self.channel_list = channel_list
self.act = activation_resolver(act, **(act_kwargs or {}))
self.act_first = act_first
self.plain_last = plain_last
if isinstance(dropout, float):
dropout = [dropout] * (len(channel_list) - 1)
if plain_last:
dropout[-1] = 0.
if len(dropout) != len(channel_list) - 1:
raise ValueError(
f"Number of dropout values provided ({len(dropout)} does not "
f"match the number of layers specified "
f"({len(channel_list)-1})")
self.dropout = dropout
if isinstance(bias, bool):
bias = [bias] * (len(channel_list) - 1)
if len(bias) != len(channel_list) - 1:
raise ValueError(
f"Number of bias values provided ({len(bias)}) does not match "
f"the number of layers specified ({len(channel_list)-1})")
self.lins = torch.nn.ModuleList()
iterator = zip(channel_list[:-1], channel_list[1:], bias)
for in_channels, out_channels, _bias in iterator:
self.lins.append(Linear(in_channels, out_channels, bias=_bias))
self.norms = torch.nn.ModuleList()
iterator = channel_list[1:-1] if plain_last else channel_list[1:]
for hidden_channels in iterator:
if norm is not None:
norm_layer = normalization_resolver(
norm,
hidden_channels,
**(norm_kwargs or {}),
)
else:
norm_layer = Identity()
self.norms.append(norm_layer)
self.reset_parameters()
@property
def in_channels(self) -> int:
r"""Size of each input sample."""
return self.channel_list[0]
@property
def out_channels(self) -> int:
r"""Size of each output sample."""
return self.channel_list[-1]
@property
def num_layers(self) -> int:
r"""The number of layers."""
return len(self.channel_list) - 1
[docs] def reset_parameters(self):
for lin in self.lins:
lin.reset_parameters()
for norm in self.norms:
if hasattr(norm, 'reset_parameters'):
norm.reset_parameters()
[docs] def forward(self, x: Tensor, return_emb: NoneType = None) -> Tensor:
""""""
for i, (lin, norm) in enumerate(zip(self.lins, self.norms)):
x = lin(x)
if self.act is not None and self.act_first:
x = self.act(x)
x = norm(x)
if self.act is not None and not self.act_first:
x = self.act(x)
x = F.dropout(x, p=self.dropout[i], training=self.training)
emb = x
if self.plain_last:
x = self.lins[-1](x)
x = F.dropout(x, p=self.dropout[-1], training=self.training)
return (x, emb) if isinstance(return_emb, bool) else x
def __repr__(self) -> str:
return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})'