class SetTransformerAggregation(channels: int, num_seed_points: int = 1, num_encoder_blocks: int = 1, num_decoder_blocks: int = 1, heads: int = 1, concat: bool = True, layer_norm: bool = False, dropout: float = 0.0)[source]

Bases: Aggregation

Performs “Set Transformer” aggregation in which the elements to aggregate are processed by multi-head attention blocks, as described in the “Graph Neural Networks with Adaptive Readouts” paper.

  • channels (int) – Size of each input sample.

  • num_seed_points (int, optional) – Number of seed points. (default: 1)

  • num_encoder_blocks (int, optional) – Number of Set Attention Blocks (SABs) in the encoder. (default: 1)

  • num_decoder_blocks (int, optional) – Number of Set Attention Blocks (SABs) in the decoder. (default: 1)

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the seed embeddings are averaged instead of concatenated. (default: True)

  • layer_norm (bool, optional) – If set to True, will apply layer normalization. (default: False)

  • dropout (float, optional) – Dropout probability of attention weights. (default: 0.0)


reset_parameters()[source]

Resets all learnable parameters of the module.

forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor[source]
  • x (torch.Tensor) – The source tensor.

  • index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

  • ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

  • dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

  • dim (int, optional) – The dimension in which to aggregate. (default: -2)