class DiffGroupNorm(in_channels: int, groups: int, lamda: float = 0.01, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)

Bases: Module

The differentiable group normalization layer from the “Towards Deeper Graph Neural Networks with Differentiable Group Normalization” paper, which normalizes node features group-wise via a learnable soft cluster assignment:

\[\mathbf{S} = \text{softmax} (\mathbf{X} \mathbf{W})\]

where \(\mathbf{W} \in \mathbb{R}^{F \times G}\) denotes a trainable weight matrix mapping each node into one of \(G\) clusters. Normalization is then performed group-wise via:

\[\mathbf{X}^{\prime} = \mathbf{X} + \lambda \sum_{i = 1}^G \text{BatchNorm}(\mathbf{S}[:, i] \odot \mathbf{X})\]
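As a rough illustration of the two equations above, the following NumPy sketch computes \(\mathbf{X}^{\prime}\) in training mode. It assumes that the BatchNorm in the equation normalizes each feature column with batch statistics only (no affine parameters, no running statistics). `softmax` and `diff_group_norm_train` are hypothetical helpers, not part of the library, and this is not the layer's actual implementation.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    # Row-wise softmax; subtracting the row max avoids overflow in exp.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def diff_group_norm_train(x: np.ndarray, w: np.ndarray,
                          lamda: float = 0.01, eps: float = 1e-5) -> np.ndarray:
    """Training-mode sketch of X' = X + lamda * sum_i BatchNorm(S[:, i] * X).

    x: [N, F] node features; w: [F, G] assignment weights.
    Hypothetical simplification: BatchNorm uses batch statistics only,
    with no affine gamma/beta and no running statistics.
    """
    s = softmax(x @ w)                        # [N, G] soft cluster assignment S
    out = np.zeros_like(x)
    for i in range(w.shape[1]):
        h = s[:, i:i + 1] * x                 # [N, F]: column S[:, i] broadcast over X
        mean = h.mean(axis=0)                 # per-feature batch mean
        var = h.var(axis=0)                   # per-feature biased batch variance
        out += (h - mean) / np.sqrt(var + eps)
    return x + lamda * out


# Usage: N = 8 nodes, F = 4 features, G = 3 groups.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
w = rng.normal(size=(4, 3))
print(diff_group_norm_train(x, w).shape)  # (8, 4)
```

Because each row of \(\mathbf{S}\) sums to one, the group-weighted copies \(\mathbf{S}[:, i] \odot \mathbf{X}\) add back up to \(\mathbf{X}\); the normalized sum is then blended into the input with weight \(\lambda\), so a small \(\lambda\) keeps the output close to the input.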

Parameters:

  • in_channels (int) – Size of each input sample \(F\).

  • groups (int) – The number of groups \(G\).

  • lamda (float, optional) – The balancing factor \(\lambda\) between input embeddings and normalized embeddings. (default: 0.01)

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
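The eps, momentum, affine and track_running_stats arguments carry the usual BatchNorm meaning. The sketch below is a hypothetical per-feature BatchNorm (`RunningBatchNorm1d` is not a PyG or PyTorch class), written under the assumption that it follows PyTorch's BatchNorm convention: running estimates are updated as (1 - momentum) · running + momentum · batch, normalization uses the biased batch variance, and the running variance stores the unbiased estimate.

```python
import numpy as np


class RunningBatchNorm1d:
    """Hypothetical per-feature BatchNorm mirroring PyTorch's conventions.

    Illustration only; not a PyG or PyTorch class.
    """

    def __init__(self, num_features: int, eps: float = 1e-5,
                 momentum: float = 0.1, affine: bool = True,
                 track_running_stats: bool = True):
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        # affine=True: learnable gamma/beta, initialised to the identity map.
        self.gamma = np.ones(num_features) if affine else None
        self.beta = np.zeros(num_features) if affine else None
        # Running estimates; only updated when track_running_stats=True.
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.training = True

    def __call__(self, h: np.ndarray) -> np.ndarray:
        # Batch statistics in training, and always when stats are not tracked.
        if self.training or not self.track_running_stats:
            mean = h.mean(axis=0)
            var = h.var(axis=0)               # biased variance for normalization
            if self.training and self.track_running_stats:
                n = h.shape[0]
                # Unbiased variance for the running estimate (n = 1 guarded in this sketch).
                unbiased = var * n / max(n - 1, 1)
                m = self.momentum
                self.running_mean = (1 - m) * self.running_mean + m * mean
                self.running_var = (1 - m) * self.running_var + m * unbiased
        else:
            # Eval mode with tracked statistics: use the running estimates.
            mean, var = self.running_mean, self.running_var
        out = (h - mean) / np.sqrt(var + self.eps)
        if self.gamma is not None:
            out = self.gamma * out + self.beta
        return out


bn = RunningBatchNorm1d(2)
h = np.array([[1.0, 2.0], [3.0, 6.0]])
bn(h)                 # training step: updates running_mean / running_var
bn.training = False
h_eval = bn(h)        # eval step: normalizes with the running estimates
```

With track_running_stats=False the running estimates are never touched, so batch statistics are used in both training and eval modes, matching the parameter description above.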


reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor) → Tensor

Parameters:

x (torch.Tensor) – The source tensor.

static group_distance_ratio(x: Tensor, y: Tensor, eps: float = 1e-05) → float

Measures the ratio of inter-group distance over intra-group distance:

\[R_{\text{Group}} = \frac{\frac{1}{(C-1)^2} \sum_{i \neq j} \frac{1}{|\mathbf{X}_i||\mathbf{X}_j|} \sum_{\mathbf{x}_{iv} \in \mathbf{X}_i } \sum_{\mathbf{x}_{jv^{\prime}} \in \mathbf{X}_j} {\| \mathbf{x}_{iv} - \mathbf{x}_{jv^{\prime}} \|}_2 }{ \frac{1}{C} \sum_{i} \frac{1}{{|\mathbf{X}_i|}^2} \sum_{\mathbf{x}_{iv}, \mathbf{x}_{iv^{\prime}} \in \mathbf{X}_i } {\| \mathbf{x}_{iv} - \mathbf{x}_{iv^{\prime}} \|}_2 }\]

where \(\mathbf{X}_i\) denotes the set of all nodes that belong to class \(i\), and \(C\) denotes the total number of classes in y.
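To make the formula concrete, here is a minimal NumPy sketch that evaluates \(R_{\text{Group}}\) literally, with `x` as an [N, F] feature matrix and `y` as integer class labels. It is not the library's implementation; in particular, how `eps` enters is an assumption (it is added to the denominator here), since the formula above does not show it.

```python
import numpy as np


def mean_pairwise_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Average Euclidean distance over all |a| * |b| pairs (self-pairs included).
    diff = a[:, None, :] - b[None, :, :]
    return float(np.linalg.norm(diff, axis=-1).mean())


def group_distance_ratio(x: np.ndarray, y: np.ndarray, eps: float = 1e-5) -> float:
    """Literal evaluation of R_Group; y holds one integer class label per node."""
    groups = [x[y == c] for c in np.unique(y)]
    C = len(groups)
    if C < 2:
        raise ValueError("R_Group needs at least two classes")
    # Numerator: mean distance between every ordered pair of distinct classes.
    inter = sum(mean_pairwise_distance(groups[i], groups[j])
                for i in range(C) for j in range(C) if i != j)
    # Denominator: mean distance within each class.
    intra = sum(mean_pairwise_distance(g, g) for g in groups)
    numerator = inter / (C - 1) ** 2
    denominator = intra / C
    # Assumption: eps is added to the denominator to avoid division by zero.
    return numerator / (denominator + eps)


# Two tight, well-separated classes give a large ratio.
x = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
y = np.array([0, 0, 1, 1])
print(group_distance_ratio(x, y))
```

Higher values indicate that the classes are far apart relative to their internal spread; swapping the features so that each class straddles both clusters drives the ratio down.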