torch_geometric.nn.models.JumpingKnowledge

class JumpingKnowledge(mode: str, channels: Optional[int] = None, num_layers: Optional[int] = None)

Bases: Module

The Jumping Knowledge layer aggregation module from the “Representation Learning on Graphs with Jumping Knowledge Networks” paper.

Jumping knowledge aggregates the layer-wise representations of each node \(v\) by either concatenation ("cat")

\[\mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)},\]

max pooling ("max")

\[\max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right),\]

or weighted summation

\[\sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)}\]

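with attention scores \(\alpha_v^{(t)}\), normalized over the \(T\) layers, obtained from a bi-directional LSTM ("lstm"). Here, \(\mathbf{x}_v^{(t)}\) denotes the representation of node \(v\) after layer \(t\), and \(T\) is the number of aggregated layers.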
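The three aggregation formulas above can be written directly in plain PyTorch. The sketch below is not the library's implementation: it assumes each layer-wise representation is a tensor of shape `[num_nodes, channels]`, and for the weighted sum it takes the attention scores `alpha` as given instead of computing them with an LSTM. The function names `jk_cat`, `jk_max` and `jk_weighted_sum` are hypothetical.

```python
import torch


def jk_cat(xs):
    # Concatenation: x_v^(1) || ... || x_v^(T) along the feature dimension.
    # Output shape: [num_nodes, T * channels].
    return torch.cat(xs, dim=-1)


def jk_max(xs):
    # Element-wise max over layers: stack to [num_nodes, channels, T],
    # then reduce over the last (layer) dimension.
    return torch.stack(xs, dim=-1).max(dim=-1).values


def jk_weighted_sum(xs, alpha):
    # Weighted sum: sum_t alpha_v^(t) * x_v^(t).
    # alpha has shape [num_nodes, T]; each row is assumed to sum to 1.
    x = torch.stack(xs, dim=1)  # [num_nodes, T, channels]
    return (x * alpha.unsqueeze(-1)).sum(dim=1)


# Two nodes, three layers, four channels per layer; layer t is filled with t.
xs = [torch.full((2, 4), float(t)) for t in range(1, 4)]
alpha = torch.tensor([[1.0, 0.0, 0.0], [0.2, 0.3, 0.5]])

print(jk_cat(xs).shape)   # torch.Size([2, 12])
print(jk_max(xs)[0])      # tensor([3., 3., 3., 3.])
print(jk_weighted_sum(xs, alpha))
```

Only "cat" changes the feature width (by a factor of \(T\)); "max" and the weighted sum keep it at `channels`.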

Parameters:
  • mode (str) – The aggregation scheme to use ("cat", "max" or "lstm").

  • channels (int, optional) – The number of channels per representation. Only needs to be set for LSTM-style aggregation. (default: None)

  • num_layers (int, optional) – The number of layers to aggregate. Only needs to be set for LSTM-style aggregation. (default: None)
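Only the "lstm" mode has learnable parameters, which is why it alone needs `channels` and `num_layers` up front. The sketch below shows one way to compute the attention scores \(\alpha_v^{(t)}\): treat each node's \(T\) representations as a sequence, run a bi-directional LSTM over it, project every step's output to a scalar score, and apply a softmax over layers. The class name `LSTMAttentionJK`, the LSTM hidden size `(num_layers * channels) // 2`, and returning `alpha` alongside the output are assumptions made for illustration, not the library's API.

```python
import torch
from torch import nn


class LSTMAttentionJK(nn.Module):
    """Hypothetical sketch of LSTM-attention jumping knowledge."""

    def __init__(self, channels, num_layers):
        super().__init__()
        # Assumed hidden size; it is the only place num_layers is used here.
        hidden = (num_layers * channels) // 2
        self.lstm = nn.LSTM(channels, hidden, bidirectional=True, batch_first=True)
        # Maps each bi-directional LSTM output (2 * hidden) to one score.
        self.att = nn.Linear(2 * hidden, 1)

    def forward(self, xs):
        # Each node becomes a sequence of length T over the layers.
        x = torch.stack(xs, dim=1)              # [num_nodes, T, channels]
        out, _ = self.lstm(x)                   # [num_nodes, T, 2 * hidden]
        scores = self.att(out).squeeze(-1)      # [num_nodes, T]
        alpha = torch.softmax(scores, dim=-1)   # attention over layers
        # Weighted sum over layers -> [num_nodes, channels].
        return (x * alpha.unsqueeze(-1)).sum(dim=1), alpha


torch.manual_seed(0)
num_nodes, channels, num_layers = 5, 8, 3
xs = [torch.randn(num_nodes, channels) for _ in range(num_layers)]

jk = LSTMAttentionJK(channels, num_layers)
out, alpha = jk(xs)
print(out.shape)       # torch.Size([5, 8])
print(alpha.sum(-1))   # each row sums to 1 (up to float rounding)
```

Because the scores come from a softmax, each node gets its own convex combination of layers, so the output width stays at `channels`.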

forward(xs: List[Tensor]) → Tensor

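Forward pass. Aggregates the layer-wise representations in xs according to mode and returns a single representation per node. For "cat", the output has \(T\) times as many channels as each input representation; for "max" and "lstm", it has the same number of channels.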

Parameters:
  • xs (List[torch.Tensor]) – List of the layer-wise node representations \(\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(T)}\), one tensor per layer.
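In a typical model, `xs` is built by keeping the output of every message-passing layer. So that the sketch runs with plain PyTorch, it uses a hypothetical dense, GCN-style layer stack `DenseGCNStack` (propagation \(\mathrm{ReLU}(\hat{A} X W)\) with \(\hat{A} = D^{-1/2}(A + I)D^{-1/2}\)) in place of a real PyG convolution, then shows the output shapes that the "cat" and "max" aggregations produce.

```python
import torch
from torch import nn


class DenseGCNStack(nn.Module):
    """Hypothetical stack of dense GCN-style layers that records every layer's output."""

    def __init__(self, in_channels, channels, num_layers):
        super().__init__()
        dims = [in_channels] + [channels] * num_layers
        self.lins = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers)])

    def forward(self, x, adj):
        # Symmetrically normalized adjacency with self-loops: D^-1/2 (A + I) D^-1/2.
        a = adj + torch.eye(adj.size(0))
        d_inv_sqrt = a.sum(dim=1).pow(-0.5)
        a_hat = d_inv_sqrt.unsqueeze(1) * a * d_inv_sqrt.unsqueeze(0)
        xs = []
        for lin in self.lins:
            x = torch.relu(a_hat @ lin(x))
            xs.append(x)  # keep each layer-wise representation
        return xs


torch.manual_seed(0)
# A 4-node path graph 0-1-2-3 with 3 input features per node.
adj = torch.tensor([[0., 1., 0., 0.],
                    [1., 0., 1., 0.],
                    [0., 1., 0., 1.],
                    [0., 0., 1., 0.]])
x = torch.randn(4, 3)

xs = DenseGCNStack(in_channels=3, channels=16, num_layers=3)(x, adj)
cat_out = torch.cat(xs, dim=-1)                       # "cat": [4, 3 * 16]
max_out = torch.stack(xs, dim=-1).max(dim=-1).values  # "max": [4, 16]
print(len(xs), cat_out.shape, max_out.shape)
# 3 torch.Size([4, 48]) torch.Size([4, 16])
```

A downstream classifier therefore needs an input width of `num_layers * channels` after "cat", but only `channels` after "max" or "lstm".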

reset_parameters()

Resets all learnable parameters of the module.
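When a model is trained several times (for example, across random seeds), a common PyTorch pattern is to call `reset_parameters()` on every submodule that defines it; a module like this one, which defines `reset_parameters()`, would be picked up by the same loop. The sketch applies the pattern to a small hypothetical model built from `nn.LSTM` and `nn.Linear`, both of which provide `reset_parameters()`. The helper name `reset_all` is an assumption.

```python
import torch
from torch import nn


def reset_all(model: nn.Module) -> None:
    # Re-initialize every submodule that exposes a callable reset_parameters().
    for module in model.modules():
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()


torch.manual_seed(0)
model = nn.ModuleDict({
    "lstm": nn.LSTM(8, 12, bidirectional=True, batch_first=True),
    "att": nn.Linear(24, 1),
})

# Simulate a trained state by zeroing all parameters in place.
with torch.no_grad():
    for p in model.parameters():
        p.zero_()

reset_all(model)
# Every parameter tensor is re-drawn from its default initializer.
print(all(p.abs().sum() > 0 for p in model.parameters()))  # True
```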