torch_geometric.nn.norm.DiffGroupNorm
- class DiffGroupNorm(in_channels: int, groups: int, lamda: float = 0.01, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)
Bases: Module
The differentiable group normalization layer from the “Towards Deeper Graph Neural Networks with Differentiable Group Normalization” paper, which normalizes node features group-wise via a learnable soft cluster assignment.
\[\mathbf{S} = \text{softmax} (\mathbf{X} \mathbf{W})\]
where \(\mathbf{W} \in \mathbb{R}^{F \times G}\) denotes a trainable weight matrix that softly assigns each node to \(G\) clusters, so that row \(v\) of \(\mathbf{S}\) holds the assignment probabilities of node \(v\). Normalization is then performed group-wise via:
\[\mathbf{X}^{\prime} = \mathbf{X} + \lambda \sum_{i = 1}^G \text{BatchNorm}(\mathbf{S}[:, i] \odot \mathbf{X})\]
- Parameters:
in_channels (int) – Size of each input sample \(F\).
groups (int) – The number of groups \(G\).
lamda (float, optional) – The balancing factor \(\lambda\) between input embeddings and normalized embeddings. (default: 0.01)
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)
affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance. If set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
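The two formulas above can be sketched in plain PyTorch. This is only an illustration under stated assumptions, not the library's implementation: the function name `diff_group_norm_sketch` and the explicit `weight` argument are hypothetical, batch statistics are always used (as in training mode), and the affine parameters \(\gamma\) and \(\beta\) are omitted.

```python
import torch
from torch.nn.functional import batch_norm


def diff_group_norm_sketch(x, weight, lamda=0.01, eps=1e-5):
    """Sketch of X' = X + lambda * sum_i BatchNorm(S[:, i] * X).

    Assumptions (not taken from the library): batch statistics only,
    no affine parameters, and `weight` passed in explicitly as W.
    """
    s = torch.softmax(x @ weight, dim=-1)  # soft assignment S, shape [N, G]
    out = torch.zeros_like(x)
    for i in range(weight.size(1)):
        xi = s[:, i:i + 1] * x  # features weighted by group i, shape [N, F]
        # Normalize every feature column over the N nodes.
        out = out + batch_norm(xi, None, None, training=True, eps=eps)
    return x + lamda * out


torch.manual_seed(0)
x = torch.randn(8, 4)       # N = 8 nodes, F = 4 features
weight = torch.randn(4, 3)  # W with F = 4 rows and G = 3 groups
out = diff_group_norm_sketch(x, weight)
```

With `lamda=0` the sketch returns the input unchanged, which matches the role of \(\lambda\) as a balancing factor between the input embeddings and the normalized embeddings.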
- forward(x: Tensor) → Tensor
Forward pass.
- Parameters:
x (torch.Tensor) – The source tensor of node features \(\mathbf{X}\), with in_channels features per node.
- static group_distance_ratio(x: Tensor, y: Tensor, eps: float = 1e-05) → float
Measures the ratio of inter-group distance over intra-group distance.
\[R_{\text{Group}} = \frac{\frac{1}{(C-1)^2} \sum_{i \neq j} \frac{1}{|\mathbf{X}_i||\mathbf{X}_j|} \sum_{\mathbf{x}_{iv} \in \mathbf{X}_i } \sum_{\mathbf{x}_{jv^{\prime}} \in \mathbf{X}_j} {\| \mathbf{x}_{iv} - \mathbf{x}_{jv^{\prime}} \|}_2 }{ \frac{1}{C} \sum_{i} \frac{1}{{|\mathbf{X}_i|}^2} \sum_{\mathbf{x}_{iv}, \mathbf{x}_{iv^{\prime}} \in \mathbf{X}_i } {\| \mathbf{x}_{iv} - \mathbf{x}_{iv^{\prime}} \|}_2 }\]
where \(\mathbf{X}_i\) denotes the set of all nodes that belong to class \(i\), and \(C\) denotes the total number of classes in y.
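The ratio can be evaluated term by term as in the sketch below. It follows the formula as displayed and is not claimed to reproduce the library's implementation exactly. The helper names are hypothetical, and adding `eps` to the denominator is an assumption about how that argument guards against division by zero.

```python
import torch


def pairwise_mean_distance(a, b):
    # Mean Euclidean distance over all pairs (one row of a, one row of b).
    return (a[:, None, :] - b[None, :, :]).norm(dim=-1).mean().item()


def group_distance_ratio_sketch(x, y, eps=1e-5):
    """Inter-group over intra-group distance; needs at least two classes."""
    groups = [x[y == c] for c in y.unique()]
    C = len(groups)

    # Numerator: sum over ordered pairs i != j, scaled by 1 / (C - 1)^2.
    inter = sum(
        pairwise_mean_distance(groups[i], groups[j])
        for i in range(C) for j in range(C) if i != j
    ) / (C - 1) ** 2

    # Denominator: average intra-group distance over the C classes.
    intra = sum(pairwise_mean_distance(g, g) for g in groups) / C

    # Assumption: eps is added to the denominator for stability.
    return inter / (intra + eps)


torch.manual_seed(0)
# Two tight clusters that lie far apart from each other.
x = torch.cat([0.1 * torch.randn(5, 2), 0.1 * torch.randn(5, 2) + 10.0])
y = torch.tensor([0] * 5 + [1] * 5)
ratio = group_distance_ratio_sketch(x, y)
```

For well-separated classes such as these, the inter-group distances dominate and the ratio is well above 1. When all node features are identical, the numerator vanishes and the sketch returns 0.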