Source code for torch_geometric.nn.norm.graph_norm

from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter_mean

from ..inits import zeros, ones


class GraphNorm(torch.nn.Module):
    r"""Applies graph normalization over individual graphs as described in the
    `"GraphNorm: A Principled Approach to Accelerating Graph Neural Network
    Training" <https://arxiv.org/abs/2009.03294>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot
        \textrm{E}[\mathbf{x}]}
        {\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]]
        + \epsilon}} \odot \gamma + \beta

    where :math:`\alpha` denotes parameters that learn how much information
    to keep in the mean.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
    """
    def __init__(self, in_channels: int, eps: float = 1e-5):
        super().__init__()

        self.in_channels = in_channels
        self.eps = eps

        # Per-channel learnable parameters: `weight` is the output scale
        # (gamma), `bias` the output shift (beta) and `mean_scale` the factor
        # applied to the per-graph mean (alpha) in the formula above.
        # They are allocated uninitialized and filled in reset_parameters().
        self.weight = torch.nn.Parameter(torch.empty(in_channels))
        self.bias = torch.nn.Parameter(torch.empty(in_channels))
        self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Resets the parameters: scale and mean scale to one, shift to zero.
        """
        ones(self.weight)
        zeros(self.bias)
        ones(self.mean_scale)

    def forward(self, x: Tensor, batch: Optional[Tensor] = None,
                batch_size: Optional[int] = None) -> Tensor:
        r"""Normalizes :obj:`x` separately for every graph in the batch.

        Args:
            x (Tensor): The input features. The first dimension indexes the
                rows (nodes) that are grouped by :obj:`batch`; the learnable
                parameters of size :obj:`in_channels` are broadcast against
                it.
            batch (LongTensor, optional): For every row of :obj:`x`, the
                index of the graph it belongs to. If :obj:`None`, all rows
                are treated as one single graph. (default: :obj:`None`)
            batch_size (int, optional): The number of graphs, *i.e.*
                :obj:`batch.max() + 1`. Computed from :obj:`batch` if not
                given; ignored if :obj:`batch` is :obj:`None`.
                (default: :obj:`None`)

        Returns:
            Tensor: The normalized features, with the same shape as
            :obj:`x`.
        """
        if batch is None:
            # No assignment vector: every row belongs to graph 0, so the
            # number of graphs is known without inspecting any tensor.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
            batch_size = 1
        elif batch_size is None:
            # `max()` is undefined on an empty tensor, so an empty batch is
            # treated as containing zero graphs instead of raising.
            batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 0

        # Per-graph mean, gathered back to one entry per row.
        mean = scatter_mean(x, batch, dim=0, dim_size=batch_size)[batch]
        out = x - mean * self.mean_scale

        # Per-graph mean of the squared shifted features; `out` is not
        # re-centered before squaring.
        var = scatter_mean(out.pow(2), batch, dim=0, dim_size=batch_size)
        std = (var + self.eps).sqrt()[batch]

        return self.weight * out / std + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'