torch_geometric.nn.aggr.SetTransformerAggregation
- class SetTransformerAggregation(channels: int, num_seed_points: int = 1, num_encoder_blocks: int = 1, num_decoder_blocks: int = 1, heads: int = 1, concat: bool = True, layer_norm: bool = False, dropout: float = 0.0)
Bases: Aggregation
Performs “Set Transformer” aggregation in which the elements to aggregate are processed by multi-head attention blocks, as described in the “Graph Neural Networks with Adaptive Readouts” paper.
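As a rough illustration of the pooling step, the plain-Python sketch below lets each of a few seed vectors attend over the elements of one set and returns one summary vector per seed, which are then either concatenated or averaged. It is a simplification under stated assumptions: a single attention head, no learned query/key/value projections, no encoder or decoder Set Attention Blocks, no layer normalization, and fixed seed vectors (in the Set Transformer design the seeds are learned). The helper names `softmax`, `attend` and `pool_set` are hypothetical and not part of PyG.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max score for numerical stability before exponentiating.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query: Vector, keys: List[Vector]) -> Vector:
    """Single-head scaled dot-product attention; the keys double as values."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    # Weighted average of the elements, one output coordinate at a time.
    return [sum(w * key[j] for w, key in zip(weights, keys)) for j in range(d)]


def pool_set(elements: List[Vector], seeds: List[Vector], concat: bool = True) -> Vector:
    """Hypothetical sketch: pool one set of elements with a list of seed vectors."""
    pooled = [attend(seed, elements) for seed in seeds]  # one vector per seed point
    if concat:
        # Concatenate the seed outputs: length num_seeds * d.
        return [v for vec in pooled for v in vec]
    # Average the seed outputs instead: length d.
    d = len(seeds[0])
    return [sum(vec[j] for vec in pooled) / len(pooled) for j in range(d)]


elements = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
seeds = [[1.0, 0.0], [0.0, 1.0]]  # two illustrative seed vectors
concatenated = pool_set(elements, seeds, concat=True)  # 4 values
averaged = pool_set(elements, seeds, concat=False)     # 2 values
```

Because the attention weights are a softmax over the elements, each summary is a weighted average of the set's elements, so reordering the elements does not change the result.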
- Parameters
channels (int) – Size of each input sample.
num_seed_points (int, optional) – Number of seed points. (default: 1)
num_encoder_blocks (int, optional) – Number of Set Attention Blocks (SABs) in the encoder. (default: 1)
num_decoder_blocks (int, optional) – Number of Set Attention Blocks (SABs) in the decoder. (default: 1)
heads (int, optional) – Number of attention heads. (default: 1)
concat (bool, optional) – If set to False, the seed embeddings are averaged instead of concatenated, so each output has channels features rather than num_seed_points * channels. (default: True)
layer_norm (bool, optional) – If set to True, will apply layer normalization. (default: False)
dropout (float, optional) – Dropout probability of attention weights. (default: 0.0)
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor
- Parameters
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The index of the output group each element is aggregated into. One of index or ptr must be defined. (default: None)
ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)
dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)
dim (int, optional) – The dimension in which to aggregate. (default: -2)
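The relationship between `index` and `ptr` can be sketched in plain Python: a sorted group index is converted into CSR offsets, where group `g` owns elements `ptr[g]` up to (but excluding) `ptr[g + 1]`. Here `dim_size` fixes how many groups exist, so a group with no elements simply repeats an offset. `index_to_ptr` and `groups_from_ptr` are hypothetical helpers for illustration, not PyG functions.

```python
from typing import List


def index_to_ptr(index: List[int], dim_size: int) -> List[int]:
    """Hypothetical helper: convert a sorted group index into CSR offsets."""
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted to build a CSR pointer")
    if index and (min(index) < 0 or max(index) >= dim_size):
        raise ValueError("index values must lie in [0, dim_size)")
    # Count how many elements fall into each group.
    counts = [0] * dim_size
    for i in index:
        counts[i] += 1
    # Prefix sums of the counts give the start offset of every group.
    ptr = [0]
    for c in counts:
        ptr.append(ptr[-1] + c)
    return ptr


def groups_from_ptr(rows: List[List[float]], ptr: List[int]) -> List[List[List[float]]]:
    """Hypothetical helper: slice the rows of each group using CSR offsets."""
    return [rows[ptr[g]:ptr[g + 1]] for g in range(len(ptr) - 1)]


index = [0, 0, 1, 1, 1, 3]
rows = [[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]]
ptr = index_to_ptr(index, dim_size=4)  # [0, 2, 5, 5, 6]
groups = groups_from_ptr(rows, ptr)    # group 2 has no elements
```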