Source code for torch_geometric.nn.norm.graph_size_norm

import torch
import torch.nn as nn

from torch_geometric.utils import degree

class GraphSizeNorm(nn.Module):
    r"""Applies Graph Size Normalization over each individual graph in a batch
    of node features as described in the
    `"Benchmarking Graph Neural Networks" <https://arxiv.org/abs/2003.00982>`_
    paper

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i}{\sqrt{|\mathcal{V}|}}

    where :math:`|\mathcal{V}|` is the number of nodes of the graph that node
    :math:`i` is assigned to.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, batch=None):
        r"""Scale every node's features by the inverse square root of the
        number of nodes in its graph.

        Args:
            x (Tensor): Node features whose first dimension indexes nodes,
                e.g. shape :obj:`[num_nodes, num_features]`. Any number of
                trailing feature dimensions is accepted.
            batch (LongTensor, optional): Vector assigning each node (each
                entry along the first dimension of :obj:`x`) to a graph
                index. If :obj:`None`, all nodes are treated as belonging to
                a single graph. (default: :obj:`None`)

        Returns:
            Tensor: The normalized node features, same shape as :obj:`x`.
        """
        if batch is None:
            # No assignment given: put every node into graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # `degree` counts how often each graph index occurs in `batch`, i.e.
        # the number of nodes per graph; raise it to -1/2 to get the scale.
        inv_sqrt_deg = degree(batch, dtype=x.dtype).pow(-0.5)

        # Gather the scale of each node's own graph -> shape [num_nodes].
        scale = inv_sqrt_deg.index_select(0, batch)

        # Reshape to [num_nodes, 1, ..., 1] so the scale broadcasts along the
        # node dimension only. For 2-D `x` this is exactly `view(-1, 1)`; for
        # 1-D `x` it avoids an accidental [N, N] outer-product broadcast, and
        # for higher-rank `x` it keeps the node axis aligned.
        return x * scale.view(-1, *([1] * (x.dim() - 1)))