Source code for torch_geometric.nn.norm.diff_group_norm

import torch
from torch import Tensor
from torch.nn import BatchNorm1d, Linear

[docs]class DiffGroupNorm(torch.nn.Module): r"""The differentiable group normalization layer from the `"Towards Deeper Graph Neural Networks with Differentiable Group Normalization" <>`_ paper, which normalizes node features group-wise via a learnable soft cluster assignment. .. math:: \mathbf{S} = \text{softmax} (\mathbf{X} \mathbf{W}) where :math:`\mathbf{W} \in \mathbb{R}^{F \times G}` denotes a trainable weight matrix mapping each node into one of :math:`G` clusters. Normalization is then performed group-wise via: .. math:: \mathbf{X}^{\prime} = \mathbf{X} + \lambda \sum_{i = 1}^G \text{BatchNorm}(\mathbf{S}[:, i] \odot \mathbf{X}) Args: in_channels (int): Size of each input sample :math:`F`. groups (int): The number of groups :math:`G`. lamda (float, optional): The balancing factor :math:`\lambda` between input embeddings and normalized embeddings. (default: :obj:`0.01`) eps (float, optional): A value added to the denominator for numerical stability. (default: :obj:`1e-5`) momentum (float, optional): The value used for the running mean and running variance computation. (default: :obj:`0.1`) affine (bool, optional): If set to :obj:`True`, this module has learnable affine parameters :math:`\gamma` and :math:`\beta`. (default: :obj:`True`) track_running_stats (bool, optional): If set to :obj:`True`, this module tracks the running mean and variance, and when set to :obj:`False`, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: :obj:`True`) """ def __init__( self, in_channels: int, groups: int, lamda: float = 0.01, eps: float = 1e-5, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, ): super().__init__() self.in_channels = in_channels self.groups = groups self.lamda = lamda self.lin = Linear(in_channels, groups, bias=False) self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine, track_running_stats) self.reset_parameters()
[docs] def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin.reset_parameters() self.norm.reset_parameters()
[docs] def forward(self, x: Tensor) -> Tensor: r"""Forward pass. Args: x (torch.Tensor): The source tensor. """ F, G = self.in_channels, self.groups s = self.lin(x).softmax(dim=-1) # [N, G] out = s.unsqueeze(-1) * x.unsqueeze(-2) # [N, G, F] out = self.norm(out.view(-1, G * F)).view(-1, G, F).sum(-2) # [N, F] return x + self.lamda * out
[docs] @staticmethod def group_distance_ratio(x: Tensor, y: Tensor, eps: float = 1e-5) -> float: r"""Measures the ratio of inter-group distance over intra-group distance. .. math:: R_{\text{Group}} = \frac{\frac{1}{(C-1)^2} \sum_{i!=j} \frac{1}{|\mathbf{X}_i||\mathbf{X}_j|} \sum_{\mathbf{x}_{iv} \in \mathbf{X}_i } \sum_{\mathbf{x}_{jv^{\prime}} \in \mathbf{X}_j} {\| \mathbf{x}_{iv} - \mathbf{x}_{jv^{\prime}} \|}_2 }{ \frac{1}{C} \sum_{i} \frac{1}{{|\mathbf{X}_i|}^2} \sum_{\mathbf{x}_{iv}, \mathbf{x}_{iv^{\prime}} \in \mathbf{X}_i } {\| \mathbf{x}_{iv} - \mathbf{x}_{iv^{\prime}} \|}_2 } where :math:`\mathbf{X}_i` denotes the set of all nodes that belong to class :math:`i`, and :math:`C` denotes the total number of classes in :obj:`y`. """ num_classes = int(y.max()) + 1 numerator = 0. for i in range(num_classes): mask = y == i dist = torch.cdist(x[mask].unsqueeze(0), x[~mask].unsqueeze(0)) numerator += (1 / dist.numel()) * float(dist.sum()) numerator *= 1 / (num_classes - 1)**2 denominator = 0. for i in range(num_classes): mask = y == i dist = torch.cdist(x[mask].unsqueeze(0), x[mask].unsqueeze(0)) denominator += (1 / dist.numel()) * float(dist.sum()) denominator *= 1 / num_classes return numerator / (denominator + eps)
def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'groups={self.groups})')