Source code for torch_geometric.nn.norm.batch_norm

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.aggr.fused import FusedAggregation


[docs]class BatchNorm(torch.nn.Module): r"""Applies batch normalization over a batch of features as described in the `"Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" <https://arxiv.org/abs/1502.03167>`_ paper. .. math:: \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta The mean and standard-deviation are calculated per-dimension over all nodes inside the mini-batch. Args: in_channels (int): Size of each input sample. eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) momentum (float, optional): The value used for the running mean and running variance computation. (default: :obj:`0.1`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) track_running_stats (bool, optional): If set to :obj:`True`, this module tracks the running mean and variance, and when set to :obj:`False`, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: :obj:`True`) allow_single_element (bool, optional): If set to :obj:`True`, batches with only a single element will work as during in evaluation. That is the running mean and variance will be used. Requires :obj:`track_running_stats=True`. (default: :obj:`False`) """ def __init__( self, in_channels: int, eps: float = 1e-5, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True, allow_single_element: bool = False, ): super().__init__() if allow_single_element and not track_running_stats: raise ValueError("'allow_single_element' requires " "'track_running_stats' to be set to `True`") self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine, track_running_stats) self.in_channels = in_channels self.allow_single_element = allow_single_element
[docs] def reset_running_stats(self): r"""Resets all running statistics of the module.""" self.module.reset_running_stats()
[docs] def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.module.reset_parameters()
[docs] def forward(self, x: Tensor) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. """ if self.allow_single_element and x.size(0) <= 1: return torch.nn.functional.batch_norm( x, self.module.running_mean, self.module.running_var, self.module.weight, self.module.bias, False, # bn_training 0.0, # momentum self.module.eps, ) return self.module(x)
def __repr__(self): return f'{self.__class__.__name__}({self.module.num_features})'
[docs]class HeteroBatchNorm(torch.nn.Module): r"""Applies batch normalization over a batch of heterogeneous features as described in the `"Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" <https://arxiv.org/abs/1502.03167>`_ paper. Compared to :class:`BatchNorm`, :class:`HeteroBatchNorm` applies normalization individually for each node or edge type. Args: in_channels (int): Size of each input sample. num_types (int): The number of types. eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) momentum (float, optional): The value used for the running mean and running variance computation. (default: :obj:`0.1`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) track_running_stats (bool, optional): If set to :obj:`True`, this module tracks the running mean and variance, and when set to :obj:`False`, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: :obj:`True`) """ def __init__( self, in_channels: int, num_types: int, eps: float = 1e-5, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True, ): super().__init__() self.in_channels = in_channels self.num_types = num_types self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: self.weight = Parameter(torch.empty(num_types, in_channels)) self.bias = Parameter(torch.empty(num_types, in_channels)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) if self.track_running_stats: self.register_buffer('running_mean', torch.empty(num_types, in_channels)) self.register_buffer('running_var', torch.empty(num_types, in_channels)) self.register_buffer('num_batches_tracked', torch.tensor(0)) else: self.register_buffer('running_mean', None) self.register_buffer('running_var', None) self.register_buffer('num_batches_tracked', None) self.mean_var = FusedAggregation(['mean', 'var']) self.reset_parameters()
[docs] def reset_running_stats(self): r"""Resets all running statistics of the module.""" if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) self.num_batches_tracked.zero_()
[docs] def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.reset_running_stats() if self.affine: torch.nn.init.ones_(self.weight) torch.nn.init.zeros_(self.bias)
[docs] def forward(self, x: Tensor, type_vec: Tensor) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The input features. type_vec (torch.Tensor): A vector that maps each entry to a type. """ if not self.training and self.track_running_stats: mean, var = self.running_mean, self.running_var else: with torch.no_grad(): mean, var = self.mean_var(x, type_vec, dim_size=self.num_types) if self.training and self.track_running_stats: if self.momentum is None: self.num_batches_tracked.add_(1) exp_avg_factor = 1.0 / float(self.num_batches_tracked) else: exp_avg_factor = self.momentum with torch.no_grad(): # Update running mean and variance: type_index = torch.unique(type_vec) self.running_mean[type_index] = ( (1.0 - exp_avg_factor) * self.running_mean[type_index] + exp_avg_factor * mean[type_index]) self.running_var[type_index] = ( (1.0 - exp_avg_factor) * self.running_var[type_index] + exp_avg_factor * var[type_index]) out = (x - mean[type_vec]) / var.clamp(self.eps).sqrt()[type_vec] if self.affine: out = out * self.weight[type_vec] + self.bias[type_vec] return out
def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'num_types={self.num_types})')