# torch_geometric.nn.aggr.Aggregation

class Aggregation[source]

Bases: Module

An abstract base class for implementing custom aggregations.

Aggregation can either be performed via an index vector, which defines the mapping from input elements to their location in the output. Notably, index does not have to be sorted (for most aggregation operators):

```python
import torch
from torch_geometric.nn.aggr import MeanAggregation

# Any concrete aggregation operator works here; MeanAggregation is one example:
aggr = MeanAggregation()

# Feature matrix holding 10 elements with 64 features each:
x = torch.randn(10, 64)

# Assign each element to one of three sets:
index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

output = aggr(x, index)  # Output shape: [3, 64]
```

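As a sketch of these semantics (not the library's implementation), index-based sum aggregation can be reproduced with plain torch.Tensor.index_add_; sum_aggregate below is a hypothetical helper:

```python
import torch

def sum_aggregate(x: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    # Accumulate each row of x into the output row selected by index:
    out = torch.zeros(dim_size, x.size(-1), dtype=x.dtype)
    return out.index_add_(0, index, x)

x = torch.randn(10, 64)
index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

out = sum_aggregate(x, index, dim_size=3)  # shape: [3, 64]
```

Note that index need not be sorted here either: index_add_ accumulates rows in whatever order they appear.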

Alternatively, aggregation can be achieved via a “compressed” index vector called ptr. Here, elements within the same set need to be grouped together in the input, and ptr defines their boundaries:

```python
import torch
from torch_geometric.nn.aggr import MeanAggregation

aggr = MeanAggregation()

# Feature matrix holding 10 elements with 64 features each:
x = torch.randn(10, 64)

# Define the boundary indices for three sets:
ptr = torch.tensor([0, 4, 7, 10])

output = aggr(x, ptr=ptr)  # Output shape: [3, 64]
```

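Both interfaces describe the same grouping: a ptr vector can be expanded into its equivalent (sorted) index vector. A minimal sketch of that conversion, where ptr_to_index is a hypothetical helper and not part of the API:

```python
import torch

def ptr_to_index(ptr: torch.Tensor) -> torch.Tensor:
    # Each set i spans rows ptr[i]:ptr[i + 1], so set i contributes
    # (ptr[i + 1] - ptr[i]) copies of the value i to the index vector:
    counts = ptr[1:] - ptr[:-1]  # elements per set, e.g. [4, 3, 3]
    return torch.repeat_interleave(torch.arange(counts.numel()), counts)

ptr = torch.tensor([0, 4, 7, 10])
index = ptr_to_index(ptr)  # tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])
```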

Note that at least one of index or ptr must be defined.

Shapes:
• input: node features $$(*, |\mathcal{V}|, F_{in})$$ or edge features $$(*, |\mathcal{E}|, F_{in})$$, index vector $$(|\mathcal{V}|)$$ or $$(|\mathcal{E}|)$$,

• output: graph features $$(*, |\mathcal{G}|, F_{out})$$ or node features $$(*, |\mathcal{V}|, F_{out})$$

forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor [source]
Parameters
• x (torch.Tensor) – The source tensor.

• index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

• ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

• dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

• dim (int, optional) – The dimension in which to aggregate. (default: -2)
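
If dim_size is omitted, it can be inferred from the data (e.g., as index.max() + 1 for an index vector), while a larger dim_size pads the output with empty groups. A sketch using a plain scatter-sum as a stand-in for an aggregation operator:

```python
import torch

x = torch.randn(10, 64)
index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

# A natural default for dim_size is the number of referenced sets:
dim_size = int(index.max()) + 1  # 3

# Requesting a larger output leaves the extra groups empty
# (zero rows, in the case of a sum):
out = torch.zeros(5, x.size(-1)).index_add_(0, index, x)  # shape: [5, 64]
```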

reset_parameters()[source]

Resets all learnable parameters of the module.