class SetTransformerAggregation(channels: int, num_seed_points: int = 1, num_encoder_blocks: int = 1, num_decoder_blocks: int = 1, heads: int = 1, concat: bool = True, layer_norm: bool = False, dropout: float = 0.0)

Bases: Aggregation

Performs “Set Transformer” aggregation in which the elements to aggregate are processed by multi-head attention blocks, as described in the “Graph Neural Networks with Adaptive Readouts” paper.


SetTransformerAggregation requires sorted indices index as input. Specifically, if you use this aggregation as part of MessagePassing, ensure that edge_index is sorted by destination nodes, either by manually sorting edge indices via sort_edge_index() or by calling

  • channels (int) – Size of each input sample.

  • num_seed_points (int, optional) – Number of seed points. (default: 1)

  • num_encoder_blocks (int, optional) – Number of Set Attention Blocks (SABs) in the encoder. (default: 1)

  • num_decoder_blocks (int, optional) – Number of Set Attention Blocks (SABs) in the decoder. (default: 1)

  • heads (int, optional) – Number of attention heads in each multi-head attention block. (default: 1)

  • concat (bool, optional) – If set to False, the seed embeddings are averaged instead of concatenated. (default: True)

  • layer_norm (bool, optional) – If set to True, will apply layer normalization. (default: False)

  • dropout (float, optional) – Dropout probability of attention weights. (default: 0.0)


reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None) → Tensor

Forward pass.

  • x (torch.Tensor) – The source tensor.

  • index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

  • ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

  • dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

  • dim (int, optional) – The dimension in which to aggregate. (default: -2)

  • max_num_elements (int, optional) – The maximum number of elements within a single aggregation group. (default: None)
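The relationship between index, ptr, dim_size and max_num_elements can be sketched with plain PyTorch. The example below is only an illustration under assumed toy data: it sorts an unsorted group assignment (permuting x the same way), then derives the CSR-style ptr and the largest group size from the sorted index. All variable names are hypothetical and not part of the library.

```python
import torch

# Unsorted group assignment for six elements; group 2 is empty.
index = torch.tensor([1, 0, 3, 1, 0, 1])
x = torch.arange(6, dtype=torch.float).view(6, 1)
dim_size = 4  # number of groups, including the empty one

# A stable sort keeps the original order of elements inside each group.
sorted_index, perm = torch.sort(index, stable=True)
x_sorted = x[perm]  # x must be permuted consistently with index

# Per-group element counts, then the CSR pointer: group i covers
# rows ptr[i]:ptr[i + 1] of x_sorted.
counts = torch.bincount(sorted_index, minlength=dim_size)
ptr = torch.cat([counts.new_zeros(1), counts.cumsum(0)])

# Size of the largest group, usable as max_num_elements.
max_num_elements = int(counts.max())

print(sorted_index.tolist(), ptr.tolist(), max_num_elements)
```

Either sorted_index or ptr describes the same grouping, so only one of them needs to be passed to forward together with x_sorted.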