torch_geometric.nn.aggr.MultiAggregation
- class MultiAggregation(aggrs: List[Union[Aggregation, str]], aggrs_kwargs: Optional[List[Dict[str, Any]]] = None, mode: Optional[str] = 'cat', mode_kwargs: Optional[Dict[str, Any]] = None)
Bases: Aggregation
Performs aggregations with one or more aggregators and combines aggregated results, as described in the “Principal Neighbourhood Aggregation for Graph Nets” and “Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions” papers.
- Parameters:
aggrs (list) – The list of aggregation schemes to use.
aggrs_kwargs (list of dict, optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)
mode (str, optional) – The mode used to combine the aggregated results from multiple aggregations ("cat", "proj", "sum", "mean", "max", "min", "logsumexp", "std", "var", "attn"). (default: "cat")
mode_kwargs (dict, optional) – Arguments passed to the combine mode. When "proj" or "attn" is used as the combine mode, in_channels (int or tuple) and out_channels (int) must be specified: in_channels is the size of each input sample taken from the respective aggregation outputs, and out_channels is the size of each output sample after combination. When "attn" is used, num_heads (int) must also be specified as the number of parallel attention heads. (default: None)
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor
Forward pass.
- Parameters:
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)
ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)
dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)
dim (int, optional) – The dimension in which to aggregate. (default: -2)
max_num_elements (int, optional) – The maximum number of elements within a single aggregation group. (default: None)