torch_geometric.nn.norm.MessageNorm
- class MessageNorm(learn_scale: bool = False)
Bases: torch.nn.Module
Applies message normalization over the aggregated messages, as described in the “DeeperGCN: All You Need to Train Deeper GCNs” paper.
\[\mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_{i} + s \cdot {\| \mathbf{x}_i \|}_2 \cdot \frac{\mathbf{m}_{i}}{{\|\mathbf{m}_i\|}_2} \right)\]
- Parameters:
learn_scale (bool, optional) – If set to True, will learn the scaling factor \(s\) of message normalization. (default: False)
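The update rule above can be traced by hand for a single node. The sketch below is a plain-Python illustration of the formula, not PyG's code: the helper names are hypothetical, the MLP is assumed to be the identity unless one is passed in, and a small `eps` (an assumption, in the spirit of common normalization utilities) guards against a zero-length message.

```python
import math


def l2_norm(v):
    # Euclidean norm ||v||_2 of a plain list of floats.
    return math.sqrt(sum(c * c for c in v))


def normalized_message_term(x_i, m_i, s=1.0, eps=1e-12):
    # Hypothetical helper: s * ||x_i||_2 * m_i / ||m_i||_2.
    x_norm = l2_norm(x_i)
    m_norm = max(l2_norm(m_i), eps)  # avoid dividing by zero for an all-zero message
    return [s * x_norm * c / m_norm for c in m_i]


def message_norm_update(x_i, m_i, s=1.0, mlp=lambda h: h):
    # Hypothetical helper: MLP(x_i + s * ||x_i||_2 * m_i / ||m_i||_2).
    # The MLP defaults to the identity purely for illustration.
    term = normalized_message_term(x_i, m_i, s)
    return mlp([a + b for a, b in zip(x_i, term)])


x_i = [3.0, 4.0]   # ||x_i||_2 = 5
m_i = [0.0, 10.0]  # ||m_i||_2 = 10, so m_i / ||m_i||_2 = [0, 1]
print(normalized_message_term(x_i, m_i))  # [0.0, 5.0]
print(message_norm_update(x_i, m_i))      # [3.0, 9.0]
```

The rescaled message always has norm \(s \cdot {\| \mathbf{x}_i \|}_2\), regardless of how large or small the raw aggregated message was; only its direction survives normalization.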
- forward(x: Tensor, msg: Tensor, p: float = 2.0) → Tensor
Forward pass.
- Parameters:
x (torch.Tensor) – The source tensor \(\mathbf{X}\), whose row-wise norms \({\| \mathbf{x}_i \|}_p\) rescale the normalized messages.
msg (torch.Tensor) – The message tensor \(\mathbf{M}\).
p (float, optional) – The norm \(p\) to use for normalization. (default: 2.0)
- Return type:
Tensor
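To make the forward contract concrete, here is a minimal PyTorch sketch that mirrors the documented signature. The class name `MessageNormSketch` is hypothetical, and we assume the module returns only the rescaled message term \(s \cdot {\| \mathbf{x}_i \|}_p \cdot \mathbf{m}_i / {\| \mathbf{m}_i \|}_p\), leaving the residual addition and the MLP from the formula to the surrounding layer; compare against the actual `torch_geometric` source before relying on this.

```python
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter


class MessageNormSketch(torch.nn.Module):
    """Hypothetical re-implementation of the documented interface (not PyG source)."""

    def __init__(self, learn_scale: bool = False):
        super().__init__()
        # Scaling factor s, initialised to 1; it only receives gradients when learn_scale=True.
        self.scale = Parameter(torch.ones(1), requires_grad=learn_scale)

    def forward(self, x: Tensor, msg: Tensor, p: float = 2.0) -> Tensor:
        # Divide each message row by its p-norm (F.normalize clamps tiny norms with eps).
        msg = F.normalize(msg, p=p, dim=-1)
        # Row-wise p-norm of the source features, kept as a column for broadcasting.
        x_norm = x.norm(p=p, dim=-1, keepdim=True)
        # s * ||x_i||_p * m_i / ||m_i||_p; residual and MLP are left to the caller (assumption).
        return msg * x_norm * self.scale


norm = MessageNormSketch(learn_scale=True)
x = torch.tensor([[3.0, 4.0], [1.0, 0.0]])
msg = torch.tensor([[0.0, 10.0], [2.0, 2.0]])
out = norm(x, msg)  # row 0 -> [0., 5.]; row 1 -> [0.7071, 0.7071]
h = x + out         # residual step from the formula; an MLP would follow in a full layer
print(out)
```

With `learn_scale=True`, `scale` is a trainable parameter that an optimizer can update alongside the rest of the model; with the default `False` it stays fixed at 1, so the message is simply rescaled to the norm of \(\mathbf{x}_i\).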