Source code for torch_geometric.nn.aggr.set_transformer

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.aggr.utils import (
    PoolingByMultiheadAttention,
    SetAttentionBlock,
)


class SetTransformerAggregation(Aggregation):
    r"""Performs "Set Transformer" aggregation in which the elements to
    aggregate are processed by multi-head attention blocks, as described in
    the `"Graph Neural Networks with Adaptive Readouts"
    <https://arxiv.org/abs/2211.04952>`_ paper.

    .. note::

        :class:`SetTransformerAggregation` requires sorted indices
        :obj:`index` as input. Specifically, if you use this aggregation as
        part of :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that
        :obj:`edge_index` is sorted by destination nodes, either by manually
        sorting edge indices via
        :meth:`~torch_geometric.utils.sort_edge_index` or by calling
        :meth:`torch_geometric.data.Data.sort`.

    Args:
        channels (int): Size of each input sample.
        num_seed_points (int, optional): Number of seed points.
            (default: :obj:`1`)
        num_encoder_blocks (int, optional): Number of Set Attention Blocks
            (SABs) in the encoder. (default: :obj:`1`)
        num_decoder_blocks (int, optional): Number of Set Attention Blocks
            (SABs) in the decoder. (default: :obj:`1`)
        heads (int, optional): Number of attention heads in each multi-head
            attention block. (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the seed embeddings
            are averaged instead of concatenated. (default: :obj:`True`)
        layer_norm (bool, optional): If set to :obj:`True`, will apply layer
            normalization. (default: :obj:`False`)
        dropout (float, optional): Dropout probability of attention weights.
            (default: :obj:`0`)
    """
    def __init__(
        self,
        channels: int,
        num_seed_points: int = 1,
        num_encoder_blocks: int = 1,
        num_decoder_blocks: int = 1,
        heads: int = 1,
        concat: bool = True,
        layer_norm: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.num_seed_points = num_seed_points
        self.heads = heads
        self.concat = concat
        self.layer_norm = layer_norm
        self.dropout = dropout

        self.encoders = torch.nn.ModuleList([
            SetAttentionBlock(channels, heads, layer_norm, dropout)
            for _ in range(num_encoder_blocks)
        ])

        self.pma = PoolingByMultiheadAttention(channels, num_seed_points,
                                               heads, layer_norm, dropout)

        self.decoders = torch.nn.ModuleList([
            SetAttentionBlock(channels, heads, layer_norm, dropout)
            for _ in range(num_decoder_blocks)
        ])
    def reset_parameters(self):
        for encoder in self.encoders:
            encoder.reset_parameters()
        self.pma.reset_parameters()
        for decoder in self.decoders:
            decoder.reset_parameters()
    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        # Convert the sparse element representation into a dense batch of
        # shape [batch_size, max_num_elements, channels], together with a
        # mask that marks the padded entries:
        x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                      max_num_elements=max_num_elements)

        for encoder in self.encoders:
            x = encoder(x, mask)

        # Pool each set onto `num_seed_points` seed embeddings:
        x = self.pma(x, mask)

        for decoder in self.decoders:
            x = decoder(x)

        # Attention over fully-masked (empty) sets yields NaNs;
        # replace them with zeros:
        x = x.nan_to_num()

        return x.flatten(1, 2) if self.concat else x.mean(dim=1)
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'num_seed_points={self.num_seed_points}, '
                f'heads={self.heads}, '
                f'layer_norm={self.layer_norm}, '
                f'dropout={self.dropout})')
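
A minimal usage sketch, not part of the original source. The tensor sizes and index values below are illustrative assumptions; the only requirement carried over from the class above is that :obj:`index` is sorted.

import torch
from torch_geometric.nn.aggr import SetTransformerAggregation

# Toy input: eight 64-dimensional elements distributed over three sets.
# `index` assigns each element to a set and must be sorted.
x = torch.randn(8, 64)
index = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])

aggr = SetTransformerAggregation(channels=64, num_seed_points=2, heads=4)
out = aggr(x, index)

# With `concat=True` (the default), the two seed embeddings per set are
# concatenated, so `out` has shape [3, 2 * 64].
print(out.shape)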