class MultiAggregation(aggrs: List[Union[Aggregation, str]], aggrs_kwargs: Optional[List[Dict[str, Any]]] = None, mode: Optional[str] = 'cat', mode_kwargs: Optional[Dict[str, Any]] = None)

Bases: Aggregation

Performs aggregations with one or more aggregators and combines aggregated results, as described in the “Principal Neighbourhood Aggregation for Graph Nets” and “Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions” papers.

  • aggrs (list) – The list of aggregation schemes to use.

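  • aggrs_kwargs (list, optional) – A list of keyword-argument dictionaries, one per entry in aggrs, passed to the respective aggregation function in case it gets automatically resolved (that is, when the entry is given as a string). (default: None)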

  • mode (str, optional) – The combine mode used to merge the aggregated results of the individual aggregations ("cat", "proj", "sum", "mean", "max", "min", "logsumexp", "std", "var", "attn"). (default: "cat")

  • mode_kwargs (dict, optional) – Arguments passed to the combine mode. When "proj" or "attn" is used as the combine mode, in_channels (int or tuple) and out_channels (int) must be specified: in_channels is the size of each input sample to combine from the respective aggregation outputs, and out_channels is the size of each output sample after combination. When "attn" is used, num_heads (int) must also be specified as the number of parallel attention heads. (default: None)


reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor
  • x (torch.Tensor) – The source tensor.

  • index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

  • ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

  • dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

  • dim (int, optional) – The dimension in which to aggregate. (default: -2)
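To make the relationship between index and ptr concrete, the sketch below reproduces a two-aggregator "cat" combination (sum and max) in plain PyTorch, once from an index vector and once from the equivalent CSR pointer. The helper names aggregate_by_index and aggregate_by_ptr are hypothetical and written only for this illustration; this is a reference computation under the assumption that every group is non-empty, not the library's implementation.

```python
import torch


def aggregate_by_index(x: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Sum and max per group, concatenated; groups are selected via `index`."""
    rows = []
    for group in range(dim_size):
        segment = x[index == group]  # rows belonging to this group
        rows.append(torch.cat([segment.sum(dim=0), segment.amax(dim=0)]))
    return torch.stack(rows)


def aggregate_by_ptr(x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
    """Same computation; groups are contiguous slices x[ptr[i]:ptr[i + 1]]."""
    rows = []
    for start, end in zip(ptr[:-1].tolist(), ptr[1:].tolist()):
        segment = x[start:end]
        rows.append(torch.cat([segment.sum(dim=0), segment.amax(dim=0)]))
    return torch.stack(rows)


x = torch.randn(6, 4)
index = torch.tensor([0, 0, 1, 1, 1, 2])  # sorted group assignment
ptr = torch.tensor([0, 2, 5, 6])          # CSR boundaries of the same groups

out_index = aggregate_by_index(x, index, dim_size=3)
out_ptr = aggregate_by_ptr(x, ptr)
```

Both calls aggregate over the first dimension of the two-dimensional x, which is the dimension that the default dim=-2 refers to, and both produce a tensor with dim_size rows.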