Source code for torch_geometric.nn.norm.mean_subtraction_norm

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.utils import scatter

class MeanSubtractionNorm(torch.nn.Module):
    r"""Applies layer normalization by subtracting the mean from the inputs
    as described in the `"Revisiting 'Over-smoothing' in Deep GCNs"
    <https://arxiv.org/abs/2003.13663>`_ paper.

    .. math::
        \mathbf{x}_i = \mathbf{x}_i - \frac{1}{|\mathcal{V}|}
        \sum_{j \in \mathcal{V}} \mathbf{x}_j

    When a :obj:`batch` vector is passed to :meth:`forward`, the mean is
    computed separately for every example, *i.e.* :math:`\mathcal{V}` is the
    set of elements assigned to the same example as element :math:`i`.
    """

    def forward(
        self,
        x: Tensor,
        batch: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
    ) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The source tensor. Its first dimension
                (dimension :obj:`0`) indexes the elements that are normalized.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example. (default: :obj:`None`)
            dim_size (int, optional): The number of examples :math:`B` in case
                :obj:`batch` is given. It is forwarded to :func:`scatter`;
                when :obj:`None`, :func:`scatter` determines it from
                :obj:`batch`. (default: :obj:`None`)

        Returns:
            torch.Tensor: A tensor with the same shape as :obj:`x`, in which
            each element has had the mean of its example subtracted.
        """
        if batch is None:
            # All rows form a single example: subtract the mean over dim 0.
            # keepdim=True keeps the result broadcastable against `x`.
            return x - x.mean(dim=0, keepdim=True)

        # One mean row per example, shape [B, *x.shape[1:]].
        mean = scatter(x, batch, dim=0, dim_size=dim_size, reduce='mean')
        # Gather each element's example mean back to the shape of `x`.
        return x - mean[batch]

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'