- class GraphMultisetTransformer(channels: int, k: int, num_encoder_blocks: int = 1, heads: int = 1, layer_norm: bool = False, dropout: float = 0.0)
The Graph Multiset Transformer pooling operator from the “Accurate Learning of Graph Representations with Graph Multiset Pooling” paper.
GraphMultisetTransformer aggregates elements into \(k\) representative elements via attention-based pooling, computes the interaction among them via num_encoder_blocks self-attention blocks, and finally pools the representative elements via attention-based pooling into a single cluster.
channels (int) – Size of each input sample.
k (int) – Number of representative nodes \(k\) after pooling.
num_encoder_blocks (int, optional) – Number of Set Attention Blocks (SABs) between the two pooling blocks. (default: 1)
heads (int, optional) – Number of attention heads. (default: 1)
layer_norm (bool, optional) – If set to True, will apply layer normalization. (default: False)
dropout (float, optional) – Dropout probability of attention weights. (default: 0.0)
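A minimal usage sketch (the tensor shapes, the two-graph index assignment, and the hyperparameter values below are illustrative, not prescribed by the API):

```python
import torch
from torch_geometric.nn.aggr import GraphMultisetTransformer

# Node features for a mini-batch of two graphs: 10 nodes, 64 channels each.
x = torch.randn(10, 64)
# Sorted assignment of each node to its graph (nodes 0-5 -> graph 0, 6-9 -> graph 1).
index = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])

# Pool each graph's node multiset into one 64-dimensional vector via
# k=4 representative elements and a single self-attention block.
aggr = GraphMultisetTransformer(channels=64, k=4, num_encoder_blocks=1, heads=4)

out = aggr(x, index)  # shape: [2, 64], one row per graph
```

Note that channels must be divisible by heads for the internal multi-head attention, and since the operator converts the batch to a dense representation internally, index is expected to be sorted by graph.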
- reset_parameters()
Resets all learnable parameters of the module.
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)
ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)
dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)
dim (int, optional) – The dimension in which to aggregate. (default: -2)
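A short forward-call sketch (shapes and values are again illustrative), showing dim_size making the batch dimension of the output explicit:

```python
import torch
from torch_geometric.nn.aggr import GraphMultisetTransformer

aggr = GraphMultisetTransformer(channels=32, k=2, heads=2)

x = torch.randn(6, 32)                    # six node feature vectors
index = torch.tensor([0, 0, 0, 1, 1, 1])  # sorted graph assignment

# dim_size pins the size of the output's batch dimension; here it matches
# the two graphs present in index, so the result has shape [2, 32].
out = aggr(x, index, dim_size=2)
```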