Source code for torch_geometric.nn.norm.graph_norm

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn.inits import ones, zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import scatter

class GraphNorm(torch.nn.Module):
    r"""Applies graph normalization over individual graphs as described in the
    `"GraphNorm: A Principled Approach to Accelerating Graph Neural Network
    Training" <https://arxiv.org/abs/2009.03294>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot
        \textrm{E}[\mathbf{x}]}
        {\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]]
        + \epsilon}} \odot \gamma + \beta

    where :math:`\alpha` denotes parameters that learn how much information
    to keep in the mean.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
    """
    def __init__(self, in_channels: int, eps: float = 1e-5):
        super().__init__()

        self.in_channels = in_channels
        self.eps = eps

        # One learnable value per channel: ``weight`` (gamma), ``bias``
        # (beta) and ``mean_scale`` (alpha) in the formula above.
        self.weight = torch.nn.Parameter(torch.empty(in_channels))
        self.bias = torch.nn.Parameter(torch.empty(in_channels))
        self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        ones(self.weight)
        zeros(self.bias)
        ones(self.mean_scale)

    def forward(self, x: Tensor, batch: OptTensor = None,
                batch_size: Optional[int] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The source tensor.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example. (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given. (default: :obj:`None`)

        Returns:
            torch.Tensor: The normalized tensor, computed element-wise from
            :obj:`x` and the per-example statistics.
        """
        if batch is None:
            # Without a batch vector, every element belongs to one single
            # example; any user-supplied ``batch_size`` is overridden.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
            batch_size = 1

        if batch_size is None:
            # ``batch.max()`` raises on an empty tensor, so an input without
            # any elements is treated as containing zero examples.
            batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 0

        # Per-example mean, broadcast back to the elements and scaled by the
        # learnable ``mean_scale`` before being subtracted.
        mean = scatter(x, batch, 0, batch_size, reduce='mean')
        out = x - mean.index_select(0, batch) * self.mean_scale

        # The variance is taken over the shifted values ``out`` (not around
        # the true mean), matching the formula in the class docstring.
        var = scatter(out.pow(2), batch, 0, batch_size, reduce='mean')
        std = (var + self.eps).sqrt().index_select(0, batch)

        return self.weight * out / std + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'