Source code for torch_geometric.nn.norm.layer_norm

from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.inits import ones, zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import degree, scatter


class LayerNorm(torch.nn.Module):
    r"""Applies layer normalization over each individual example in a batch
    of features as described in the `"Layer Normalization"
    <https://arxiv.org/abs/1607.06450>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
        \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
        \odot \gamma + \beta

    The mean and standard-deviation are calculated across all nodes and all
    node channels separately for each object in a mini-batch.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        mode (str, optional): The normalization mode to use for layer
            normalization (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"`
            is used, each graph will be considered as an element to be
            normalized. If :obj:`"node"` is used, each node will be
            considered as an element to be normalized.
            (default: :obj:`"graph"`)
    """
    def __init__(
        self,
        in_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        mode: str = 'graph',
    ):
        super().__init__()

        self.in_channels = in_channels
        self.eps = eps
        self.affine = affine
        self.mode = mode

        if affine:
            self.weight = Parameter(torch.empty(in_channels))
            self.bias = Parameter(torch.empty(in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()
    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        ones(self.weight)
        zeros(self.bias)
    def forward(self, x: Tensor, batch: OptTensor = None,
                batch_size: Optional[int] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The source tensor.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example. (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given. (default: :obj:`None`)
        """
        if self.mode == 'graph':
            if batch is None:
                x = x - x.mean()
                out = x / (x.std(unbiased=False) + self.eps)

            else:
                if batch_size is None:
                    batch_size = int(batch.max()) + 1

                norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
                norm = norm.mul_(x.size(-1)).view(-1, 1)

                mean = scatter(x, batch, dim=0, dim_size=batch_size,
                               reduce='sum').sum(dim=-1, keepdim=True) / norm

                x = x - mean.index_select(0, batch)

                var = scatter(x * x, batch, dim=0, dim_size=batch_size,
                              reduce='sum').sum(dim=-1, keepdim=True)
                var = var / norm

                out = x / (var + self.eps).sqrt().index_select(0, batch)

            if self.weight is not None and self.bias is not None:
                out = out * self.weight + self.bias

            return out

        if self.mode == 'node':
            return F.layer_norm(x, (self.in_channels, ), self.weight,
                                self.bias, self.eps)

        raise ValueError(f"Unknown normalization mode: {self.mode}")
    def __repr__(self):
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'affine={self.affine}, mode={self.mode})')
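
# Example usage (a minimal sketch, not part of the upstream module). It shows
# `LayerNorm` in both normalization modes; the feature shapes and the `batch`
# vector below are illustrative assumptions:
#
#     import torch
#     from torch_geometric.nn import LayerNorm
#
#     x = torch.randn(6, 32)                    # 6 nodes with 32 channels
#     batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs of 3 nodes each
#
#     norm = LayerNorm(32, mode='graph')
#     out = norm(x, batch)  # statistics computed per graph, shape [6, 32]
#
#     norm = LayerNorm(32, mode='node')
#     out = norm(x)         # statistics computed per node, equivalent to
#                           # torch.nn.functional.layer_norm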
class HeteroLayerNorm(torch.nn.Module):
    r"""Applies layer normalization over each individual example in a batch
    of heterogeneous features as described in the `"Layer Normalization"
    <https://arxiv.org/abs/1607.06450>`_ paper.
    Compared to :class:`LayerNorm`, :class:`HeteroLayerNorm` applies
    normalization individually for each node or edge type.

    Args:
        in_channels (int): Size of each input sample.
        num_types (int): The number of types.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        mode (str, optional): The normalization mode to use for layer
            normalization (:obj:`"node"`). If :obj:`"node"` is used, each
            node will be considered as an element to be normalized.
            (default: :obj:`"node"`)
    """
    def __init__(
        self,
        in_channels: int,
        num_types: int,
        eps: float = 1e-5,
        affine: bool = True,
        mode: str = 'node',
    ):
        super().__init__()
        assert mode == 'node'

        self.in_channels = in_channels
        self.num_types = num_types
        self.eps = eps
        self.affine = affine

        if affine:
            self.weight = Parameter(torch.empty(num_types, in_channels))
            self.bias = Parameter(torch.empty(num_types, in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()
    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if self.affine:
            torch.nn.init.ones_(self.weight)
            torch.nn.init.zeros_(self.bias)
    def forward(
        self,
        x: Tensor,
        type_vec: OptTensor = None,
        type_ptr: Optional[Union[Tensor, List[int]]] = None,
    ) -> Tensor:
        r"""Forward pass.

        .. note::
            Either :obj:`type_vec` or :obj:`type_ptr` needs to be specified.
            In general, relying on :obj:`type_ptr` is more efficient in case
            the input tensor is sorted by types.

        Args:
            x (torch.Tensor): The input features.
            type_vec (torch.Tensor, optional): A vector that maps each entry
                to a type. (default: :obj:`None`)
            type_ptr (torch.Tensor or List[int], optional): A vector denoting
                the boundaries of types. (default: :obj:`None`)
        """
        if type_vec is None and type_ptr is None:
            raise ValueError("Either 'type_vec' or 'type_ptr' must be given")

        out = F.layer_norm(x, (self.in_channels, ), None, None, self.eps)

        if self.affine:
            # TODO: Revisit this logic completely, as it performs worse than
            # just operating on a dictionary of tensors (especially the
            # `type_vec` code path).
            if type_ptr is not None:
                h = torch.empty_like(out)
                for i, (s, e) in enumerate(zip(type_ptr[:-1], type_ptr[1:])):
                    h[s:e] = out[s:e] * self.weight[i] + self.bias[i]
                out = h
            else:
                out = out * self.weight[type_vec] + self.bias[type_vec]

        return out
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'num_types={self.num_types})')
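
# Example usage (a minimal sketch, not part of the upstream module). Here `x`
# stacks two node types; the shapes and the `type_vec`/`type_ptr` values are
# illustrative assumptions:
#
#     import torch
#     from torch_geometric.nn import HeteroLayerNorm
#
#     x = torch.randn(7, 16)  # 4 nodes of type 0 followed by 3 of type 1
#     type_vec = torch.tensor([0, 0, 0, 0, 1, 1, 1])
#
#     norm = HeteroLayerNorm(16, num_types=2)
#     out = norm(x, type_vec=type_vec)   # per-type affine transformation
#     out = norm(x, type_ptr=[0, 4, 7])  # equivalent for type-sorted input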