torch_geometric.nn.conv.GatedGraphConv
- class GatedGraphConv(out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]
Bases:
MessagePassing
The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper.
\[ \begin{align}\begin{aligned}\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}\\\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}\\\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)})\end{aligned}\end{align} \]up to representation \(\mathbf{h}_i^{(L)}\). The number of input channels of \(\mathbf{x}_i\) needs to be less or equal than
out_channels
. \(e_{j,i}\) denotes the edge weight from source nodej
to target nodei
(default:1
)- Parameters:
out_channels (int) – Size of each output sample.
num_layers (int) – The sequence length \(L\).
aggr (str, optional) – The aggregation scheme to use (
"add"
,"mean"
,"max"
). (default:"add"
)bias (bool, optional) – If set to
False
, the layer will not learn an additive bias. (default:True
)**kwargs (optional) – Additional arguments of
torch_geometric.nn.conv.MessagePassing
.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\)
output: node features \((|\mathcal{V}|, F_{out})\)