Source code for torch_geometric.nn.norm.batch_norm

import torch
from torch import Tensor


class BatchNorm(torch.nn.Module):
    r"""Applies batch normalization over a batch of node features as described
    in the `"Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift" <https://arxiv.org/abs/1502.03167>`_
    paper

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
        \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
        \odot \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    nodes inside the mini-batch.

    This class is a thin wrapper: all work is delegated to an internal
    :class:`torch.nn.BatchNorm1d` stored in :obj:`self.module`.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        momentum (float, optional): The value used for the running mean and
            running variance computation. (default: :obj:`0.1`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        track_running_stats (bool, optional): If set to :obj:`True`, this
            module tracks the running mean and variance, and when set to
            :obj:`False`, this module does not track such statistics and
            always uses batch statistics in both training and eval modes.
            (default: :obj:`True`)
    """
    def __init__(self, in_channels: int, eps: float = 1e-5,
                 momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True):
        super().__init__()
        # Forward every option by keyword so the wrapper does not depend on
        # the positional order of ``BatchNorm1d``'s parameters.
        self.module = torch.nn.BatchNorm1d(
            num_features=in_channels,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def reset_parameters(self) -> None:
        """Resets the wrapped :class:`torch.nn.BatchNorm1d` module by calling
        its own :meth:`reset_parameters`."""
        self.module.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        """Normalizes :obj:`x` with the wrapped :class:`torch.nn.BatchNorm1d`.

        Args:
            x (Tensor): The node feature tensor. It is handed to the wrapped
                module unchanged; presumably of shape
                :obj:`[num_nodes, in_channels]` -- TODO confirm with callers.

        Returns:
            Tensor: The output of the wrapped module for :obj:`x`.
        """
        return self.module(x)

    def __repr__(self) -> str:
        # Report the feature size as stored by the wrapped module.
        return f'{self.__class__.__name__}({self.module.num_features})'