Source code for torch_geometric.nn.models.mlp

import inspect
import warnings
from typing import Any, Callable, Dict, Final, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Identity

from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import NoneType


[docs]class MLP(torch.nn.Module):
    r"""A Multi-Layer Perceptron (MLP) model.

    There exist two ways to instantiate an :class:`MLP`:

    1. By specifying explicit channel sizes, *e.g.*,

       .. code-block:: python

          mlp = MLP([16, 32, 64, 128])

       creates a three-layer MLP with **differently** sized hidden layers.

    2. By specifying fixed hidden channel sizes over a number of layers,
       *e.g.*,

       .. code-block:: python

          mlp = MLP(in_channels=16, hidden_channels=32,
                    out_channels=128, num_layers=3)

       creates a three-layer MLP with **equally** sized hidden layers.

    Args:
        channel_list (List[int] or int, optional): List of input, intermediate
            and output channels such that :obj:`len(channel_list) - 1` denotes
            the number of layers of the MLP. (default: :obj:`None`)
        in_channels (int, optional): Size of each input sample.
            Will override :attr:`channel_list`. (default: :obj:`None`)
        hidden_channels (int, optional): Size of each hidden sample.
            Will override :attr:`channel_list`. (default: :obj:`None`)
        out_channels (int, optional): Size of each output sample.
            Will override :attr:`channel_list`. (default: :obj:`None`)
        num_layers (int, optional): The number of layers.
            Will override :attr:`channel_list`. (default: :obj:`None`)
        dropout (float or List[float], optional): Dropout probability of each
            hidden embedding. If a list is provided, sets the dropout value
            per layer. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to use.
            (default: :obj:`"batch_norm"`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        plain_last (bool, optional): If set to :obj:`False`, will apply
            non-linearity, batch normalization and dropout to the last layer
            as well. (default: :obj:`True`)
        bias (bool or List[bool], optional): If set to :obj:`False`, the
            module will not learn additive biases. If a list is provided,
            sets the bias per layer. (default: :obj:`True`)
        **kwargs (optional): Additional deprecated arguments of the MLP layer.
""" supports_norm_batch: Final[bool] def __init__( self, channel_list: Optional[Union[List[int], int]] = None, *, in_channels: Optional[int] = None, hidden_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: Optional[int] = None, dropout: Union[float, List[float]] = 0., act: Union[str, Callable, None] = "relu", act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Union[str, Callable, None] = "batch_norm", norm_kwargs: Optional[Dict[str, Any]] = None, plain_last: bool = True, bias: Union[bool, List[bool]] = True, **kwargs, ): super().__init__() # Backward compatibility: act_first = act_first or kwargs.get("relu_first", False) batch_norm = kwargs.get("batch_norm", None) if batch_norm is not None and isinstance(batch_norm, bool): warnings.warn("Argument `batch_norm` is deprecated, " "please use `norm` to specify normalization layer.") norm = 'batch_norm' if batch_norm else None batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None) norm_kwargs = batch_norm_kwargs or {} if isinstance(channel_list, int): in_channels = channel_list if in_channels is not None: if num_layers is None: raise ValueError("Argument `num_layers` must be given") if num_layers > 1 and hidden_channels is None: raise ValueError(f"Argument `hidden_channels` must be given " f"for `num_layers={num_layers}`") if out_channels is None: raise ValueError("Argument `out_channels` must be given") channel_list = [hidden_channels] * (num_layers - 1) channel_list = [in_channels] + channel_list + [out_channels] assert isinstance(channel_list, (tuple, list)) assert len(channel_list) >= 2 self.channel_list = channel_list self.act = activation_resolver(act, **(act_kwargs or {})) self.act_first = act_first self.plain_last = plain_last if isinstance(dropout, float): dropout = [dropout] * (len(channel_list) - 1) if plain_last: dropout[-1] = 0. if len(dropout) != len(channel_list) - 1: raise ValueError( f"Number of dropout values provided ({len(dropout)} does not " f"match the number of layers specified " f"({len(channel_list)-1})") self.dropout = dropout if isinstance(bias, bool): bias = [bias] * (len(channel_list) - 1) if len(bias) != len(channel_list) - 1: raise ValueError( f"Number of bias values provided ({len(bias)}) does not match " f"the number of layers specified ({len(channel_list)-1})") self.lins = torch.nn.ModuleList() iterator = zip(channel_list[:-1], channel_list[1:], bias) for in_channels, out_channels, _bias in iterator: self.lins.append(Linear(in_channels, out_channels, bias=_bias)) self.norms = torch.nn.ModuleList() iterator = channel_list[1:-1] if plain_last else channel_list[1:] for hidden_channels in iterator: if norm is not None: norm_layer = normalization_resolver( norm, hidden_channels, **(norm_kwargs or {}), ) else: norm_layer = Identity() self.norms.append(norm_layer) self.supports_norm_batch = False if len(self.norms) > 0 and hasattr(self.norms[0], 'forward'): norm_params = inspect.signature(self.norms[0].forward).parameters self.supports_norm_batch = 'batch' in norm_params self.reset_parameters() @property def in_channels(self) -> int: r"""Size of each input sample.""" return self.channel_list[0] @property def out_channels(self) -> int: r"""Size of each output sample.""" return self.channel_list[-1] @property def num_layers(self) -> int: r"""The number of layers.""" return len(self.channel_list) - 1
[docs]    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for lin in self.lins:
            lin.reset_parameters()
        for norm in self.norms:
            if hasattr(norm, 'reset_parameters'):
                norm.reset_parameters()
[docs]    def forward(
        self,
        x: Tensor,
        batch: Optional[Tensor] = None,
        batch_size: Optional[int] = None,
        return_emb: NoneType = None,
    ) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The source tensor.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
                Only needs to be passed in case the underlying normalization
                layers require the :obj:`batch` information.
                (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given.
                Only needs to be passed in case the underlying normalization
                layers require the :obj:`batch` information.
                (default: :obj:`None`)
            return_emb (bool, optional): If set to :obj:`True`, will
                additionally return the embeddings before execution of the
                final output layer. (default: :obj:`False`)
        """
        # `return_emb` is annotated here as `NoneType` to be compatible with
        # TorchScript, which does not support different return types based on
        # the value of an input argument.
        emb: Optional[Tensor] = None

        # If `plain_last=True`, then `len(norms) == len(lins) - 1`, thus
        # skipping the execution of the last layer inside the for-loop.
        for i, (lin, norm) in enumerate(zip(self.lins, self.norms)):
            x = lin(x)
            if self.act is not None and self.act_first:
                x = self.act(x)
            if self.supports_norm_batch:
                x = norm(x, batch, batch_size)
            else:
                x = norm(x)
            if self.act is not None and not self.act_first:
                x = self.act(x)
            x = F.dropout(x, p=self.dropout[i], training=self.training)
            if isinstance(return_emb, bool) and return_emb is True:
                emb = x

        if self.plain_last:
            x = self.lins[-1](x)
            x = F.dropout(x, p=self.dropout[-1], training=self.training)

        return (x, emb) if isinstance(return_emb, bool) else x
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})'
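
# A minimal usage sketch, assuming only the behavior documented in the class
# docstring (the two instantiation modes) and in `forward()` (`return_emb`).
# Guarded so it does not run on import.
if __name__ == "__main__":
    # 1. Explicit channel sizes: a three-layer MLP 16 -> 32 -> 64 -> 128 with
    #    differently sized hidden layers.
    mlp = MLP([16, 32, 64, 128])

    # 2. Fixed hidden channel size over `num_layers` layers:
    #    16 -> 32 -> 32 -> 128 with equally sized hidden layers.
    mlp = MLP(in_channels=16, hidden_channels=32, out_channels=128,
              num_layers=3)

    x = torch.randn(4, 16)  # four samples, so batch norm statistics are valid
    out = mlp(x)            # shape: [4, 128]

    # `return_emb=True` additionally returns the embedding right before the
    # final output layer (here of size `hidden_channels=32`).
    out, emb = mlp(x, return_emb=True)
    print(out.shape, emb.shape)  # torch.Size([4, 128]) torch.Size([4, 32])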
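
# A further sketch of the batch-aware normalization path, in comments only:
# assuming the resolved normalization layer (e.g. `norm='graph_norm'`) accepts
# `batch` and `batch_size` in its `forward()`, `supports_norm_batch` is set to
# `True` and both are forwarded to it, so statistics are computed per example:
#
#     mlp = MLP([16, 32, 32], norm='graph_norm')
#     x = torch.randn(4, 16)
#     batch = torch.tensor([0, 0, 1, 1])  # two examples with two nodes each
#     out = mlp(x, batch=batch)           # shape: [4, 32]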