Source code for torch_geometric.nn.norm.graph_size_norm

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.typing import OptTensor
from torch_geometric.utils import degree


class GraphSizeNorm(torch.nn.Module):
    r"""Applies Graph Size Normalization over each individual graph in a
    batch of node features as described in the
    `"Benchmarking Graph Neural Networks" <https://arxiv.org/abs/2003.00982>`_
    paper.

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i}{\sqrt{|\mathcal{V}|}}
    """
    def __init__(self):
        super().__init__()
    def forward(self, x: Tensor, batch: OptTensor = None,
                batch_size: Optional[int] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The source tensor.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example. (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given. (default: :obj:`None`)
        """
        if batch is None:
            # Treat all nodes as belonging to a single graph.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
            batch_size = 1

        # Number of nodes per graph, raised to the power -1/2, then broadcast
        # back to each node via its batch assignment.
        inv_sqrt_deg = degree(batch, batch_size, dtype=x.dtype).pow(-0.5)
        return x * inv_sqrt_deg.index_select(0, batch).view(-1, 1)
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
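

# A minimal usage sketch, assuming the module above is importable as-is.
# The node features and batch vector are made-up illustrative values and are
# not part of the original source.
if __name__ == '__main__':
    norm = GraphSizeNorm()

    # Two graphs in one batch: graph 0 has 3 nodes, graph 1 has 2 nodes.
    x = torch.randn(5, 16)
    batch = torch.tensor([0, 0, 0, 1, 1])

    out = norm(x, batch)
    # Nodes of graph 0 are scaled by 1/sqrt(3), nodes of graph 1 by 1/sqrt(2);
    # the feature dimension is left unchanged.
    assert out.shape == (5, 16)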