from typing import Optional
from torch import Tensor
from torch.nn import GRU
from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
[docs]class GRUAggregation(Aggregation):
r"""Performs GRU aggregation in which the elements to aggregate are
interpreted as a sequence, as described in the `"Graph Neural Networks
with Adaptive Readouts" <https://arxiv.org/abs/2211.04952>`_ paper.

.. note::

    :class:`GRUAggregation` requires sorted indices :obj:`index` as input.
    Specifically, if you use this aggregation as part of
    :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that
    :obj:`edge_index` is sorted by destination nodes, either by manually
    sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`
    or by calling :meth:`torch_geometric.data.Data.sort`.

.. warning::

    :class:`GRUAggregation` is not a permutation-invariant operator: since
    the elements of each group are fed to the GRU as a sequence, its output
    depends on the order in which those elements appear.

Args:
    in_channels (int): Size of each input sample.
    out_channels (int): Size of each output sample.
    **kwargs (optional): Additional arguments of :class:`torch.nn.GRU`.
"""
def __init__(self, in_channels: int, out_channels: int, **kwargs):
    r"""Create the GRU that consumes each group of elements as a sequence.

    Args:
        in_channels (int): Size of each input sample (GRU input size).
        out_channels (int): Size of each output sample (GRU hidden size).
        **kwargs (optional): Additional arguments of
            :class:`torch.nn.GRU`.
    """
    # The base class must be initialised before any submodule is assigned:
    # torch.nn.Module.__setattr__ raises AttributeError when a Module such
    # as ``self.gru`` is set before Module.__init__ has run.
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    # batch_first=True because forward() reads the final step with
    # ``[:, -1]``, i.e. it expects groups along dim 0 and the elements of a
    # group (the sequence) along dim 1.
    self.gru = GRU(in_channels, out_channels, batch_first=True, **kwargs)
def reset_parameters(self):
    r"""Re-initialise the learnable parameters of the wrapped GRU.

    ``self.gru`` is the only submodule this aggregation creates, so
    resetting it resets all learnable state of the operator.
    """
    self.gru.reset_parameters()
@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
def forward(
    self,
    x: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
    dim: int = -2,
    max_num_elements: Optional[int] = None,
) -> Tensor:
    r"""Aggregate the elements of :obj:`x` by running a GRU over each group.

    The elements are first packed into a dense batch by the base-class
    helper :meth:`to_dense_batch`. That batch is fed through
    :obj:`self.gru` (built with ``batch_first=True``). The GRU output at
    the last sequence position is returned for every group.

    Args:
        x (torch.Tensor): The source tensor.
        index (torch.Tensor, optional): Group assignment of each element
            of :obj:`x`; must be sorted. (default: :obj:`None`)
        ptr (torch.Tensor, optional): Presumably CSR-style group offsets,
            usable instead of :obj:`index`; forwarded unchanged to
            :meth:`to_dense_batch`. (default: :obj:`None`)
        dim_size (int, optional): The number of output groups; forwarded
            to :meth:`to_dense_batch`. (default: :obj:`None`)
        dim (int, optional): The dimension of :obj:`x` to aggregate over.
            (default: :obj:`-2`)
        max_num_elements (int, optional): Upper bound on the number of
            elements per group, i.e. the padded sequence length; forwarded
            to :meth:`to_dense_batch`. (default: :obj:`None`)

    Returns:
        torch.Tensor: One row per group of the dense batch. Each row has
        :obj:`out_channels` features for a unidirectional GRU, or twice
        that if ``bidirectional=True`` was passed via ``**kwargs``.
    """
    # NOTE(review): the decorator presumably insists on dim_size and
    # max_num_elements when dynamic shapes are disabled (e.g. under
    # compilation); both are therefore passed through explicitly below.
    # The padding mask returned alongside the dense batch is not needed:
    # the GRU reads every position and only the last one is kept.
    x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                               max_num_elements=max_num_elements)

    # ``self.gru(x)`` returns ``(output, h_n)``; ``output[:, -1]`` is the
    # output at the final sequence position of every group.
    # NOTE(review): for groups shorter than the padded length this final
    # position is presumably a padding slot, so the GRU has also consumed
    # padding entries. Confirm this is the intended readout.
    return self.gru(x)[0][:, -1]
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '