Source code for torch_geometric.nn.aggr.gmt

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.aggr.utils import (
    PoolingByMultiheadAttention,
    SetAttentionBlock,
)


class GraphMultisetTransformer(Aggregation):
    r"""The Graph Multiset Transformer pooling operator from the
    `"Accurate Learning of Graph Representations with Graph Multiset
    Pooling" <https://arxiv.org/abs/2102.11533>`_ paper.

    The :class:`GraphMultisetTransformer` aggregates elements into
    :math:`k` representative elements via attention-based pooling, computes
    the interaction among them via :obj:`num_encoder_blocks` self-attention
    blocks, and finally pools the representative elements via attention-based
    pooling into a single cluster.

    .. note::

        :class:`GraphMultisetTransformer` requires sorted indices :obj:`index`
        as input. Specifically, if you use this aggregation as part of
        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that
        :obj:`edge_index` is sorted by destination nodes, either by manually
        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`
        or by calling :meth:`torch_geometric.data.Data.sort`.

    Args:
        channels (int): Size of each input sample.
        k (int): Number of :math:`k` representative nodes after pooling.
        num_encoder_blocks (int, optional): Number of Set Attention Blocks
            (SABs) between the two pooling blocks. (default: :obj:`1`)
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        layer_norm (bool, optional): If set to :obj:`True`, will apply layer
            normalization. (default: :obj:`False`)
        dropout (float, optional): Dropout probability of attention weights.
            (default: :obj:`0`)
    """
    def __init__(
        self,
        channels: int,
        k: int,
        num_encoder_blocks: int = 1,
        heads: int = 1,
        layer_norm: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.k = k
        self.heads = heads
        self.layer_norm = layer_norm
        self.dropout = dropout

        self.pma1 = PoolingByMultiheadAttention(channels, k, heads, layer_norm,
                                                dropout)
        self.encoders = torch.nn.ModuleList([
            SetAttentionBlock(channels, heads, layer_norm, dropout)
            for _ in range(num_encoder_blocks)
        ])
        self.pma2 = PoolingByMultiheadAttention(channels, 1, heads, layer_norm,
                                                dropout)
    def reset_parameters(self):
        self.pma1.reset_parameters()
        for encoder in self.encoders:
            encoder.reset_parameters()
        self.pma2.reset_parameters()

    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                      max_num_elements=max_num_elements)

        x = self.pma1(x, mask)

        for encoder in self.encoders:
            x = encoder(x)

        x = self.pma2(x)

        return x.squeeze(1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'k={self.k}, heads={self.heads}, '
                f'layer_norm={self.layer_norm}, '
                f'dropout={self.dropout})')
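
A minimal usage sketch, not part of the original module: the channel size, number of seeds, and toy batch below are illustrative assumptions. It pools the node features of two graphs into one embedding per graph, with `index` already sorted as the note in the docstring requires.

    import torch
    from torch_geometric.nn.aggr import GraphMultisetTransformer

    aggr = GraphMultisetTransformer(channels=64, k=4, heads=2)

    # Eight nodes split across two graphs; `index` assigns each node to its
    # graph and is sorted in ascending order, as required above.
    x = torch.randn(8, 64)
    index = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])

    out = aggr(x, index)  # Shape: [2, 64], one pooled vector per graph.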