torch_geometric.nn.norm.HeteroBatchNorm
- class HeteroBatchNorm(in_channels: int, num_types: int, eps: float = 1e-05, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True)
Bases: Module
Applies batch normalization over a batch of heterogeneous features, as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper. Compared to BatchNorm, which computes a single set of statistics over all entries, HeteroBatchNorm applies normalization individually for each node or edge type: the statistics used for an entry are computed only over the entries of the same type.
- Parameters:
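As a rough, dependency-free illustration of what per-type normalization computes, the sketch below normalizes each entry \(x_i\) as \((x_i - \mu_{t_i}) / \sqrt{\sigma^2_{t_i} + \epsilon}\), where \(t_i\) is the type of entry \(i\) and \(\mu_t\), \(\sigma^2_t\) are computed only over entries of type \(t\). It is not the library's implementation: it assumes training mode (batch statistics), omits the affine transform (\(\gamma = 1\), \(\beta = 0\)), uses plain Python lists instead of tensors, and uses the textbook denominator; the function name `per_type_batch_norm` is hypothetical.

```python
import math
from typing import List


def per_type_batch_norm(x: List[List[float]], type_vec: List[int],
                        num_types: int, eps: float = 1e-5) -> List[List[float]]:
    """Normalize each row of x with the batch statistics of its own type.

    Hypothetical reference sketch: training-mode statistics, no affine
    transform (gamma = 1, beta = 0), denominator sqrt(var + eps).
    """
    num_channels = len(x[0])
    counts = [0] * num_types
    sums = [[0.0] * num_channels for _ in range(num_types)]
    # Per-type, per-channel sums -> per-type means.
    for row, t in zip(x, type_vec):
        counts[t] += 1
        for c, v in enumerate(row):
            sums[t][c] += v
    means = [[s / counts[t] if counts[t] else 0.0 for s in sums[t]]
             for t in range(num_types)]
    # Per-type, per-channel biased (population) variance.
    sq_dev = [[0.0] * num_channels for _ in range(num_types)]
    for row, t in zip(x, type_vec):
        for c, v in enumerate(row):
            sq_dev[t][c] += (v - means[t][c]) ** 2
    variances = [[s / counts[t] if counts[t] else 0.0 for s in sq_dev[t]]
                 for t in range(num_types)]
    # Each entry is centered and scaled using only the statistics of its type.
    return [[(v - means[t][c]) / math.sqrt(variances[t][c] + eps)
             for c, v in enumerate(row)]
            for row, t in zip(x, type_vec)]


# Two types on very different scales: both end up at roughly -1 / +1.
x = [[1.0], [3.0], [10.0], [30.0]]
type_vec = [0, 0, 1, 1]
out = per_type_batch_norm(x, type_vec, num_types=2)
print([round(row[0], 3) for row in out])  # [-1.0, 1.0, -1.0, 1.0]
```

A single set of statistics over all four entries would instead mix both scales, so the small-valued type would be squeezed together near one end of the normalized range.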
in_channels (int) – Size of each input sample.
num_types (int) – The number of node or edge types.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)
affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance; when set to False, it does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
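To make the momentum and track_running_stats parameters more concrete, the sketch below shows a per-type running-statistics update. It assumes HeteroBatchNorm follows the same convention as torch.nn.BatchNorm1d, namely running = (1 - momentum) * running + momentum * batch_statistic, with momentum=None selecting a cumulative moving average. It also assumes that only types present in the current batch are updated. By analogy with BatchNorm, the tracked statistics would be used in eval mode when track_running_stats=True. The function name `update_running_stat` is hypothetical.

```python
from typing import List, Optional


def update_running_stat(running: List[List[float]], batch_stat: List[List[float]],
                        present_types: List[int], momentum: Optional[float],
                        num_batches_tracked: int) -> List[List[float]]:
    """Return updated per-type running statistics (mean or variance).

    num_batches_tracked counts batches seen so far, including the current one;
    it is only used when momentum is None (cumulative moving average).
    """
    factor = 1.0 / num_batches_tracked if momentum is None else momentum
    updated = [row[:] for row in running]
    # Assumption: only types that occur in the current batch are updated;
    # absent types keep their previous running statistics.
    for t in present_types:
        updated[t] = [(1.0 - factor) * r + factor * b
                      for r, b in zip(running[t], batch_stat[t])]
    return updated


running_mean = [[0.0], [0.0]]   # one row per type, one column per channel
batch_mean = [[2.0], [20.0]]    # per-type means of the current batch

# Both types present, momentum = 0.1 (the default): each moves 10% toward the batch.
print(update_running_stat(running_mean, batch_mean, [0, 1], 0.1, 1))
# Only type 0 present: type 1 keeps its old running mean.
print(update_running_stat(running_mean, batch_mean, [0], 0.1, 1))
```

Keeping one row of statistics per type is what lets a rarely occurring type retain its own estimates instead of being averaged with the dominant types.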
- forward(x: Tensor, type_vec: Tensor) → Tensor
Forward pass.
- Parameters:
x (torch.Tensor) – The input features.
type_vec (torch.Tensor) – A vector that maps each entry of x to its type.
- Return type:
Tensor
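In a heterogeneous setting, features usually arrive grouped by type, so x and type_vec have to be assembled before calling forward. The sketch below shows one way to do this with plain Python lists: it stacks the per-type rows into a single x, records each row's type index in type_vec, and routes the output rows back to their types afterwards. The helper names `stack_by_type` and `split_by_type` and the type names "author" and "paper" are hypothetical; num_types is assumed to equal the number of type names.

```python
from typing import Dict, List, Tuple


def stack_by_type(x_dict: Dict[str, List[List[float]]], type_names: List[str]
                  ) -> Tuple[List[List[float]], List[int]]:
    """Concatenate per-type feature rows and record each row's type index."""
    x: List[List[float]] = []
    type_vec: List[int] = []
    for type_idx, name in enumerate(type_names):
        rows = x_dict[name]
        x.extend(rows)
        type_vec.extend([type_idx] * len(rows))
    return x, type_vec


def split_by_type(out: List[List[float]], type_vec: List[int],
                  type_names: List[str]) -> Dict[str, List[List[float]]]:
    """Route output rows back to their types, preserving their order."""
    result: Dict[str, List[List[float]]] = {name: [] for name in type_names}
    for row, t in zip(out, type_vec):
        result[type_names[t]].append(row)
    return result


x_dict = {"author": [[1.0, 2.0]], "paper": [[3.0, 4.0], [5.0, 6.0]]}
x, type_vec = stack_by_type(x_dict, ["author", "paper"])
print(type_vec)  # [0, 1, 1]
# ... the per-type normalization would be applied to (x, type_vec) here ...
print(split_by_type(x, type_vec, ["author", "paper"]) == x_dict)  # True
```

With tensors, the analogous steps would be torch.cat over the per-type feature tensors for x and torch.repeat_interleave over the per-type row counts for type_vec; the module's output can then be split back by type in the same order.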