# Source code for torch_geometric.nn.glob.set2set

from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter_add

from torch_geometric.utils import softmax

class Set2Set(torch.nn.Module):
    r"""The global pooling operator based on iterative content-based
    attention from the `"Order Matters: Sequence to sequence for sets"
    <https://arxiv.org/abs/1511.06391>`_ paper

    .. math::
        \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1})

        \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)

        \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i

        \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t,

    where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice
    the dimensionality as the input.

    Args:
        in_channels (int): Size of each input sample.
        processing_steps (int): Number of iterations :math:`T`.
        num_layers (int, optional): Number of recurrent layers, *e.g.*,
            setting :obj:`num_layers=2` would mean stacking two LSTMs
            together to form a stacked LSTM, with the second LSTM taking in
            outputs of the first LSTM and computing the final results.
            (default: :obj:`1`)

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F)`,
          batch vector :math:`(|\mathcal{V}|)` *(optional)*
        - **output:**
          graph features :math:`(|\mathcal{G}|, 2 * F)` where
          :math:`|\mathcal{G}|` denotes the number of graphs in the batch
    """
    def __init__(self, in_channels: int, processing_steps: int,
                 num_layers: int = 1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = 2 * in_channels
        self.processing_steps = processing_steps
        self.num_layers = num_layers
        # The LSTM reads the previous q* (width 2 * in_channels) and emits
        # the next query q (width in_channels).
        self.lstm = torch.nn.LSTM(self.out_channels, self.in_channels,
                                  num_layers)
        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes the parameters of the internal LSTM."""
        self.lstm.reset_parameters()

    def forward(self, x: Tensor, batch: Optional[Tensor] = None,
                size: Optional[int] = None) -> Tensor:
        r"""
        Args:
            x (Tensor): The input node features.
            batch (LongTensor, optional): A vector that maps each node to its
                respective graph identifier. (default: :obj:`None`)
            size (int, optional): The number of graphs in the batch.
                (default: :obj:`None`)

        Returns:
            Tensor: The pooled graph features :math:`\mathbf{q}^{*}_T` of
            shape :obj:`[size, 2 * in_channels]`.
        """
        if batch is None:
            # Without a batch vector, every node belongs to graph 0.
            batch = x.new_zeros(x.size(0), dtype=torch.int64)
        size = int(batch.max()) + 1 if size is None else size

        # Zero-initialized (hidden, cell) state, one row per graph.
        h = (x.new_zeros((self.num_layers, size, self.in_channels)),
             x.new_zeros((self.num_layers, size, self.in_channels)))
        q_star = x.new_zeros(size, self.out_channels)

        for _ in range(self.processing_steps):
            # Feed a length-1 sequence: q_star.unsqueeze(0) is
            # [1, size, 2 * in_channels].
            q, h = self.lstm(q_star.unsqueeze(0), h)
            q = q.view(size, self.in_channels)
            # Attention logit per node: dot product of the node's features
            # with the query of the graph it belongs to.
            e = (x * q.index_select(0, batch)).sum(dim=-1, keepdim=True)
            # Softmax over the nodes of each graph (grouped by `batch`).
            a = softmax(e, batch, num_nodes=size)
            # Attention-weighted sum of node features per graph.
            r = scatter_add(a * x, batch, dim=0, dim_size=size)
            q_star = torch.cat([q, r], dim=-1)

        return q_star

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')