from typing import Optional
import torch
from torch import Tensor
from torch_scatter import scatter_mean
from ..inits import zeros, ones
class GraphNorm(torch.nn.Module):
    r"""Applies graph normalization over individual graphs as described in the
    `"GraphNorm: A Principled Approach to Accelerating Graph Neural Network
    Training" <https://arxiv.org/abs/2009.03294>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot
        \textrm{E}[\mathbf{x}]}
        {\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]]
        + \epsilon}} \odot \gamma + \beta

    where :math:`\alpha` denotes parameters that learn how much information
    to keep in the mean.

    In this implementation :math:`\gamma` is stored as :attr:`weight`,
    :math:`\beta` as :attr:`bias` and :math:`\alpha` as :attr:`mean_scale`;
    each is a learnable vector of length :obj:`in_channels`.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
    """
    def __init__(self, in_channels: int, eps: float = 1e-5):
        super().__init__()

        self.in_channels = in_channels
        self.eps = eps

        # Allocated uninitialized; reset_parameters() assigns the real values.
        self.weight = torch.nn.Parameter(torch.empty(in_channels))
        self.bias = torch.nn.Parameter(torch.empty(in_channels))
        self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the learnable parameters: :attr:`weight` and
        :attr:`mean_scale` are filled with ones and :attr:`bias` with zeros.
        """
        ones(self.weight)
        zeros(self.bias)
        ones(self.mean_scale)

    def forward(self, x: Tensor, batch: Optional[Tensor] = None,
                batch_size: Optional[int] = None) -> Tensor:
        r"""Normalizes :obj:`x` separately for every graph in the batch.

        Args:
            x (Tensor): The input features. Entries are grouped along
                dimension :obj:`0`, and the remaining dimensions must
                broadcast against the parameters of size
                :obj:`in_channels`.
            batch (LongTensor, optional): For every entry along dimension
                :obj:`0` of :obj:`x`, the index of the graph it belongs to.
                If :obj:`None`, all entries are treated as one single graph.
                (default: :obj:`None`)
            batch_size (int, optional): The number of graphs, *i.e.* the
                largest index in :obj:`batch` plus one. If :obj:`None`, it
                is computed from :obj:`batch`. Ignored when :obj:`batch` is
                :obj:`None`. (default: :obj:`None`)

        Returns:
            Tensor: The normalized features, with the same shape as
            :obj:`x`.
        """
        if batch is None:
            # A single graph: the graph count is known without inspecting
            # `batch`, which also keeps empty inputs from failing in `max()`.
            batch = x.new_zeros(x.size(0), dtype=torch.long)
            batch_size = 1
        elif batch_size is None:
            # NOTE: converting the maximum to a Python int synchronizes with
            # the device; callers can pass `batch_size` to avoid that.
            batch_size = int(batch.max()) + 1

        # Per-graph mean, gathered back to one row per input entry.
        mean = scatter_mean(x, batch, dim=0, dim_size=batch_size)[batch]

        # Subtract only a learned, per-channel fraction of the mean.
        out = x - mean * self.mean_scale

        # Per-graph mean of the squared shifted values, used as the variance.
        var = scatter_mean(out.pow(2), batch, dim=0, dim_size=batch_size)
        std = (var + self.eps).sqrt()[batch]

        return self.weight * out / std + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'