torch_geometric.nn.conv.GatedGraphConv

class GatedGraphConv(out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs)

Bases: MessagePassing

The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper.

\[
\begin{aligned}
\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0} \\
\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)} \\
\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)})
\end{aligned}
\]

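for \(l = 0, \ldots, L-1\), up to the representation \(\mathbf{h}_i^{(L)}\). Because \(\mathbf{h}_i^{(0)}\) is \(\mathbf{x}_i\) zero-padded to out_channels entries, the number of input channels of \(\mathbf{x}_i\) must be less than or equal to out_channels. \(e_{j,i}\) denotes the edge weight from source node \(j\) to target node \(i\) (default: 1).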
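The update from \(\mathbf{h}^{(l)}\) to \(\mathbf{h}^{(l+1)}\) can be traced with a small NumPy sketch. It is not the library's implementation: it assumes "add" aggregation, a single matrix `theta` as written in the formula, node states that already have out_channels columns, and a GRU that follows the gate equations documented for `torch.nn.GRUCell`. The names `gru_cell`, `propagation_step` and the weight keys in `p` are hypothetical.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_cell(m, h, p):
    """h' = GRU(m, h) using the torch.nn.GRUCell gate equations (hypothetical weight names)."""
    r = sigmoid(m @ p["W_ir"] + p["b_ir"] + h @ p["W_hr"] + p["b_hr"])  # reset gate
    z = sigmoid(m @ p["W_iz"] + p["b_iz"] + h @ p["W_hz"] + p["b_hz"])  # update gate
    n = np.tanh(m @ p["W_in"] + p["b_in"] + r * (h @ p["W_hn"] + p["b_hn"]))  # candidate state
    return (1.0 - z) * n + z * h


def propagation_step(h, edge_index, edge_weight, theta, p):
    """One step h^(l) -> h^(l+1) with sum aggregation."""
    src, dst = edge_index                             # edge j -> i stored as src=j, dst=i
    msg = edge_weight[:, None] * (h[src] @ theta.T)   # e_{j,i} * Theta h_j^(l), one row per edge
    m = np.zeros_like(h)
    np.add.at(m, dst, msg)                            # m_i^(l+1): sum over j in N(i)
    return gru_cell(m, h, p)


# Toy run: 3 nodes on a directed cycle, out_channels C = 4, L = 3 steps.
rng = np.random.default_rng(0)
C = 4
h = rng.normal(size=(3, C))
edge_index = np.array([[0, 1, 2],    # sources j
                       [1, 2, 0]])   # targets i
edge_weight = np.ones(3)
theta = rng.normal(size=(C, C))
p = {k: rng.normal(size=(C, C)) for k in ("W_ir", "W_iz", "W_in", "W_hr", "W_hz", "W_hn")}
p.update({k: np.zeros(C) for k in ("b_ir", "b_iz", "b_in", "b_hr", "b_hz", "b_hn")})
for _ in range(3):
    h = propagation_step(h, edge_index, edge_weight, theta, p)
print(h.shape)  # (3, 4)
```

The GRU keeps the state width fixed at out_channels, which is why the same step can be repeated \(L\) times.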

Parameters:
  • out_channels (int) – Size of each output sample.

  • num_layers (int) – The sequence length \(L\).

  • aggr (str, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
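The aggr choice only changes how the per-edge messages \(e_{j,i} \cdot \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}\) are combined at each target node; the formula above shows the default "add". A NumPy sketch of the three schemes follows. The `aggregate` helper is hypothetical, and giving zeros to nodes without incoming edges is an assumption of this sketch.

```python
import numpy as np


def aggregate(msg, dst, num_nodes, aggr="add"):
    """Combine per-edge messages msg (|E|, C) at their target nodes dst (|E|,)."""
    count = np.bincount(dst, minlength=num_nodes)       # in-degree of each node
    if aggr == "add":
        out = np.zeros((num_nodes, msg.shape[1]))
        np.add.at(out, dst, msg)
    elif aggr == "mean":
        out = np.zeros((num_nodes, msg.shape[1]))
        np.add.at(out, dst, msg)
        out = out / np.maximum(count, 1)[:, None]         # avoid 0/0 for isolated nodes
    elif aggr == "max":
        out = np.full((num_nodes, msg.shape[1]), -np.inf)
        np.maximum.at(out, dst, msg)
        out[count == 0] = 0.0                             # assumption: isolated nodes get 0
    else:
        raise ValueError(f"unknown aggr: {aggr}")
    return out


msg = np.array([[1.0], [3.0], [5.0]])
dst = np.array([0, 0, 1])            # node 2 has no incoming edges
for a in ("add", "mean", "max"):
    print(a, aggregate(msg, dst, 3, a).ravel())
# add [4. 5. 0.]
# mean [2. 5. 0.]
# max [3. 5. 0.]
```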

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), optional edge weights \((|\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\), where \(F_{out}\) equals out_channels
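Each column of edge_index is one edge. Under the MessagePassing default flow "source_to_target", row 0 holds the source node \(j\) and row 1 the target node \(i\), so \(\mathcal{N}(i)\) collects the sources of all edges pointing into \(i\). A small sketch of reading sizes and neighborhoods from these arrays; `in_neighbors` is a hypothetical helper.

```python
import numpy as np

x = np.zeros((4, 3))                      # |V| = 4 nodes, F_in = 3
edge_index = np.array([[0, 1, 3, 2],      # row 0: source nodes j
                       [2, 2, 2, 0]])     # row 1: target nodes i
num_nodes, num_edges = x.shape[0], edge_index.shape[1]


def in_neighbors(edge_index, i):
    """N(i): the sources of every edge whose target is node i."""
    src, dst = edge_index
    return src[dst == i]


print(num_nodes, num_edges)         # 4 4
print(in_neighbors(edge_index, 2))  # [0 1 3]
print(in_neighbors(edge_index, 1))  # []
```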

forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) → Tensor

Runs the forward pass of the module.
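Before the first propagation step, the formula requires \(\mathbf{h}^{(0)} = \mathbf{x} \, \Vert \, \mathbf{0}\) and, when edge_weight is None, \(e_{j,i} = 1\) for every edge. A minimal NumPy sketch of that input preparation; `prepare_inputs` is a hypothetical helper, and raising ValueError for too many input channels is this sketch's choice, not a claim about the library's exact error.

```python
import numpy as np


def prepare_inputs(x, edge_index, out_channels, edge_weight=None):
    """Build h^(0) = x || 0 and the default edge weights e_{j,i} = 1."""
    num_nodes, f_in = x.shape
    if f_in > out_channels:
        raise ValueError("input channels must be <= out_channels")
    pad = np.zeros((num_nodes, out_channels - f_in), dtype=x.dtype)
    h0 = np.concatenate([x, pad], axis=1)             # zero-pad, never truncate
    if edge_weight is None:
        edge_weight = np.ones(edge_index.shape[1])    # one weight per edge
    return h0, edge_weight


x = np.array([[1.0, 2.0],
              [3.0, 4.0]])                # 2 nodes, F_in = 2
edge_index = np.array([[0, 1],
                       [1, 0]])
h0, w = prepare_inputs(x, edge_index, out_channels=4)
print(h0)
# [[1. 2. 0. 0.]
#  [3. 4. 0. 0.]]
print(w)  # [1. 1.]
```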

reset_parameters()

Resets all learnable parameters of the module.