torch_geometric.nn.norm.HeteroBatchNorm

class HeteroBatchNorm(in_channels: int, num_types: int, eps: float = 1e-05, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True)

Bases: Module

Applies batch normalization over a batch of heterogeneous features as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper. Unlike BatchNorm, HeteroBatchNorm normalizes the features of each node or edge type separately, so every type is normalized with its own statistics.
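To make "normalization individually for each type" concrete, here is a minimal from-scratch sketch in plain PyTorch of training-mode behavior (batch statistics only, no running statistics). It is not the library's implementation. The helper name `per_type_batch_norm` and the assumption that the optional affine parameters have shape `[num_types, in_channels]` (one γ and β per type) are illustrative choices.

```python
import torch


def per_type_batch_norm(x, type_vec, num_types, weight=None, bias=None, eps=1e-5):
    """Hypothetical helper: normalize each row of x with the statistics of its type.

    x:        [N, C] input features.
    type_vec: [N] integer type ids in [0, num_types).
    weight, bias: optional [num_types, C] per-type affine parameters (assumed shape).
    """
    out = torch.empty_like(x)
    for t in range(num_types):
        mask = type_vec == t
        if not mask.any():
            continue  # this type does not occur in the batch
        x_t = x[mask]
        mean = x_t.mean(dim=0)
        var = ((x_t - mean) ** 2).mean(dim=0)  # biased variance, as in batch norm
        y = (x_t - mean) / torch.sqrt(var + eps)
        if weight is not None:
            y = y * weight[t] + bias[t]  # per-type gamma and beta
        out[mask] = y
    return out


torch.manual_seed(0)
x = torch.randn(8, 4)
type_vec = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2])
weight = torch.randn(3, 4)
bias = torch.randn(3, 4)
out = per_type_batch_norm(x, type_vec, num_types=3, weight=weight, bias=bias)
```

Each type's rows come out exactly as if ordinary batch normalization had been applied to that type's rows alone.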

Parameters:
  • in_channels (int) – Size of each input sample.

  • num_types (int) – The number of node or edge types.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
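The `momentum` and `track_running_stats` parameters govern per-type running statistics. The sketch below assumes the update rule of `torch.nn.BatchNorm1d`, `running = (1 - momentum) * running + momentum * batch_stat`, applied separately to each type that occurs in the batch, with initial values of 0 (mean) and 1 (variance). The class `PerTypeRunningStats` and its methods are hypothetical illustrations, not the module's API. Leaving types that are absent from a batch untouched is likewise an assumption, and `momentum=None` (cumulative averaging in PyTorch's BatchNorm) is not modeled.

```python
import torch


class PerTypeRunningStats:
    """Hypothetical helper: per-type running mean/variance updated with momentum."""

    def __init__(self, in_channels, num_types, momentum=0.1):
        self.momentum = momentum
        # Assumed initial values, mirroring PyTorch's BatchNorm: mean 0, variance 1.
        self.running_mean = torch.zeros(num_types, in_channels)
        self.running_var = torch.ones(num_types, in_channels)

    @torch.no_grad()
    def update(self, x, type_vec):
        # Only types present in this batch are updated (assumption).
        for t in torch.unique(type_vec).tolist():
            x_t = x[type_vec == t]
            n = x_t.size(0)
            mean = x_t.mean(dim=0)
            # Unbiased variance for the running estimate, as torch.nn.BatchNorm1d does;
            # max(n - 1, 1) guards against division by zero for single-sample types.
            var = ((x_t - mean) ** 2).sum(dim=0) / max(n - 1, 1)
            m = self.momentum
            self.running_mean[t] = (1 - m) * self.running_mean[t] + m * mean
            self.running_var[t] = (1 - m) * self.running_var[t] + m * var

    def normalize(self, x, type_vec, eps=1e-5):
        # Eval-mode path: look up each row's statistics by its type id.
        mean = self.running_mean[type_vec]
        var = self.running_var[type_vec]
        return (x - mean) / torch.sqrt(var + eps)


stats = PerTypeRunningStats(in_channels=2, num_types=3, momentum=0.5)
stats.update(torch.tensor([[1.0, 2.0], [3.0, 4.0], [10.0, 10.0]]), torch.tensor([0, 0, 1]))
```

Here type 2 does not appear in the batch, so its statistics keep their initial values. With `track_running_stats=False`, the module would skip this bookkeeping and always normalize with batch statistics.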

reset_running_stats()

Resets all running statistics of the module.

reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, type_vec: Tensor) → Tensor

Forward pass.

Parameters: