Source code for torch_geometric.nn.aggr.lstm

from typing import Optional

from torch import Tensor
from torch.nn import LSTM

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation


class LSTMAggregation(Aggregation):
    r"""Performs LSTM-style aggregation in which the elements to aggregate
    are interpreted as a sequence, as described in the `"Inductive
    Representation Learning on Large Graphs"
    <https://arxiv.org/abs/1706.02216>`_ paper.

    .. note::

        :class:`LSTMAggregation` requires sorted indices :obj:`index` as
        input.
        Specifically, if you use this aggregation as part of
        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that
        :obj:`edge_index` is sorted by destination nodes, either by
        manually sorting edge indices via
        :meth:`~torch_geometric.utils.sort_edge_index` or by calling
        :meth:`torch_geometric.data.Data.sort`.

    .. warning::

        :class:`LSTMAggregation` is not a permutation-invariant operator.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        **kwargs (optional): Additional arguments of
            :class:`torch.nn.LSTM`.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # The LSTM is always built with ``batch_first=True`` so that it
        # consumes the dense tensor produced in :meth:`forward` with the
        # batch on dim 0 and the sequence on dim 1.  Any extra ``kwargs``
        # are forwarded to :class:`torch.nn.LSTM` untouched.
        self.lstm = LSTM(in_channels, out_channels, batch_first=True,
                         **kwargs)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the parameters of the wrapped :class:`torch.nn.LSTM`."""
        self.lstm.reset_parameters()

    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:
        r"""Aggregates :obj:`x` by running an LSTM over each group of
        elements and returning the LSTM output at the last sequence
        position.

        All arguments are handed to the inherited :meth:`to_dense_batch`
        helper, which produces the dense input of the LSTM.
        """
        # Only the first element returned by ``to_dense_batch`` is used;
        # the second one is discarded.
        dense, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                       max_num_elements=max_num_elements)

        # ``torch.nn.LSTM`` returns ``(output, (h_n, c_n))``; only the
        # per-step ``output`` is needed here.
        outputs, _ = self.lstm(dense)

        # NOTE(review): the result is read at the last position along
        # dim 1.  If ``to_dense_batch`` pads shorter groups, that position
        # is reached after the padding has been fed through the LSTM --
        # confirm this is the intended behavior.
        return outputs[:, -1]

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f'{cls_name}({self.in_channels}, {self.out_channels})'