from typing import Optional
import torch
from torch import Tensor
from torch_geometric.nn.inits import ones, zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import scatter


class GraphNorm(torch.nn.Module):
    r"""Applies graph normalization over individual graphs as described in the
    `"GraphNorm: A Principled Approach to Accelerating Graph Neural Network
    Training" <https://arxiv.org/abs/2009.03294>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot
        \textrm{E}[\mathbf{x}]}
        {\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]]
        + \epsilon}} \odot \gamma + \beta

    where :math:`\alpha` denotes parameters that learn how much information
    to keep in the mean.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
    """

    def __init__(self, in_channels: int, eps: float = 1e-5):
        super().__init__()

        self.in_channels = in_channels
        self.eps = eps

        self.weight = torch.nn.Parameter(torch.empty(in_channels))
        self.bias = torch.nn.Parameter(torch.empty(in_channels))
        self.mean_scale = torch.nn.Parameter(torch.empty(in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        ones(self.weight)
        zeros(self.bias)
        ones(self.mean_scale)

    def forward(self, x: Tensor, batch: OptTensor = None,
                batch_size: Optional[int] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The source tensor.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example. (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given. (default: :obj:`None`)
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)
            batch_size = 1

        if batch_size is None:
            batch_size = int(batch.max()) + 1

        # Subtract the learnably scaled per-graph mean:
        mean = scatter(x, batch, 0, batch_size, reduce='mean')
        out = x - mean.index_select(0, batch) * self.mean_scale

        # Normalize by the per-graph standard deviation of the shifted
        # features, then apply the learnable affine transform:
        var = scatter(out.pow(2), batch, 0, batch_size, reduce='mean')
        std = (var + self.eps).sqrt().index_select(0, batch)
        return self.weight * out / std + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'
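

# A minimal usage sketch: it applies GraphNorm to a small batch of two graphs
# and checks the normalization against a by-hand computation for the first
# graph. It assumes the default initialization from `reset_parameters()`
# (weight and mean_scale are ones, bias is zeros); the node counts, feature
# size, and tolerance below are illustrative choices, not values prescribed by
# the paper or the library.
if __name__ == '__main__':
    torch.manual_seed(0)

    norm = GraphNorm(in_channels=16)
    x = torch.randn(7, 16)                       # 7 nodes, 16 features each
    batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # nodes 0-2 -> graph 0
    out = norm(x, batch)                         # shape: [7, 16]

    # Manual check for graph 0: its nodes are standardized independently of
    # graph 1, using the biased (mean-of-squares) variance as in forward().
    g0 = x[:3] - x[:3].mean(dim=0)
    g0 = g0 / (g0.pow(2).mean(dim=0) + norm.eps).sqrt()
    assert torch.allclose(out[:3], g0, atol=1e-6)
    print(out.shape)  # torch.Size([7, 16])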