torch_geometric.nn.aggr.Aggregation
- class Aggregation
Bases: torch.nn.Module
An abstract base class for implementing custom aggregations.
Aggregation can be performed either via an index vector, which defines the mapping from input elements to their location in the output. Notably, index does not have to be sorted (for most aggregation operators):

    import torch

    # Feature matrix holding 10 elements with 64 features each:
    x = torch.randn(10, 64)

    # Assign each element to one of three sets:
    index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

    # aggr is an instance of a concrete Aggregation subclass (e.g., SumAggregation):
    output = aggr(x, index)  # Output shape: [3, 64]
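To make the index semantics concrete, here is a minimal plain-Python sketch (no PyTorch dependency) of what a sum aggregation computes for the example above, using lists instead of tensors and 2 features instead of 64. The helper scatter_sum is hypothetical and not part of PyG; in practice you would call a concrete subclass such as torch_geometric.nn.aggr.SumAggregation.

```python
from typing import List


def scatter_sum(x: List[List[float]], index: List[int], dim_size: int) -> List[List[float]]:
    """Hypothetical helper: sum the rows of x into dim_size output rows.

    Row i of x is added to output row index[i]; index does not need to be sorted.
    """
    num_features = len(x[0]) if x else 0
    out = [[0.0] * num_features for _ in range(dim_size)]
    for row, target in zip(x, index):
        for f, value in enumerate(row):
            out[target][f] += value
    return out


# 10 elements with 2 features each: element i holds [i, 1].
x = [[float(i), 1.0] for i in range(10)]
# The same unsorted assignment to three sets as in the example above.
index = [0, 0, 1, 0, 2, 0, 2, 1, 0, 2]

out = scatter_sum(x, index, dim_size=max(index) + 1)
print(out)  # [[17.0, 5.0], [9.0, 2.0], [19.0, 3.0]]
```

Because the second feature is 1 for every element, its sum is the size of each set (5, 2 and 3), which shows that every element lands in exactly one output row regardless of the order of index.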
Alternatively, aggregation can be achieved via a “compressed” index vector called ptr. Here, elements within the same set need to be grouped together in the input, and ptr defines their boundaries:

    # Feature matrix holding 10 elements with 64 features each:
    x = torch.randn(10, 64)

    # Define the boundary indices for three sets (rows 0-3, 4-6 and 7-9):
    ptr = torch.tensor([0, 4, 7, 10])

    output = aggr(x, ptr=ptr)  # Output shape: [3, 64]
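The ptr form can be sketched the same way. In this plain-Python sketch (the helpers segment_sum and ptr_to_index are hypothetical, not PyG APIs), group g covers rows ptr[g] up to but excluding ptr[g + 1], so a ptr with four entries describes three sets. Expanding ptr into an equivalent sorted index shows that both forms describe the same grouping.

```python
from typing import List


def segment_sum(x: List[List[float]], ptr: List[int]) -> List[List[float]]:
    """Hypothetical helper: sum each contiguous block of rows [ptr[g], ptr[g + 1])."""
    num_features = len(x[0]) if x else 0
    out = []
    for start, end in zip(ptr[:-1], ptr[1:]):
        acc = [0.0] * num_features
        for row in x[start:end]:
            for f, value in enumerate(row):
                acc[f] += value
        out.append(acc)
    return out


def ptr_to_index(ptr: List[int]) -> List[int]:
    """Hypothetical helper: expand CSR boundaries into a sorted per-element group index."""
    return [
        g
        for g, (start, end) in enumerate(zip(ptr[:-1], ptr[1:]))
        for _ in range(end - start)
    ]


x = [[float(i), 1.0] for i in range(10)]
ptr = [0, 4, 7, 10]  # boundaries of three sets: rows 0-3, 4-6 and 7-9

out = segment_sum(x, ptr)
print(len(out))           # 3, i.e. len(ptr) - 1
print(ptr_to_index(ptr))  # [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
print(out)                # [[6.0, 4.0], [15.0, 3.0], [24.0, 3.0]]
```

The segment form never needs to look up a target row per element, which is why it requires the input to be grouped (sorted) beforehand.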
Note that at least one of index or ptr must be defined.
- Shapes:
input: node features \((*, |\mathcal{V}|, F_{in})\) or edge features \((*, |\mathcal{E}|, F_{in})\), index vector \((|\mathcal{V}|)\) or \((|\mathcal{E}|)\),
output: graph features \((*, |\mathcal{G}|, F_{out})\) or node features \((*, |\mathcal{V}|, F_{out})\)
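The shape rule above amounts to replacing the size of the aggregated dimension with the number of output groups while keeping the leading \(*\) dimensions and the feature dimension in place. The sketch below assumes a feature-preserving reduction such as sum or mean, where \(F_{out} = F_{in}\); other aggregations may change the feature size. The helper aggregated_shape is hypothetical.

```python
from typing import Tuple


def aggregated_shape(in_shape: Tuple[int, ...], dim_size: int, dim: int = -2) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of a feature-preserving aggregation.

    The size at dim (default -2, the element dimension) becomes dim_size;
    all other dimensions, including leading batch dimensions, are unchanged.
    """
    ndim = len(in_shape)
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim={dim} is out of range for a {ndim}-D input")
    dim = dim % ndim  # normalise a negative dim to its positive position
    return in_shape[:dim] + (dim_size,) + in_shape[dim + 1:]


print(aggregated_shape((10, 64), dim_size=3))     # (3, 64)
print(aggregated_shape((8, 10, 64), dim_size=3))  # (8, 3, 64): leading dim kept
```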
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None) → Tensor
Forward pass.
- Parameters:
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The index vector that assigns each element of x (along dim) to its output group. One of index or ptr must be defined. (default: None)
ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation, i.e., elements of the same group are contiguous and ptr holds the group boundaries. One of index or ptr must be defined. (default: None)
dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation, i.e., the number of output groups. (default: None)
dim (int, optional) – The dimension in which to aggregate. (default: -2)
max_num_elements (int, optional) – The maximum number of elements within a single aggregation group. (default: None)
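The parameter contract can be illustrated with a small plain-Python imitation of the base-class pattern: the base class resolves index or ptr and dim_size, and subclasses only define how one group is reduced. Everything here (ToyAggregation, ToySum, ToyMax, group_sizes) is hypothetical and works on lists of rows, so it only covers the row dimension rather than an arbitrary dim. Two behaviours are assumptions chosen for illustration, not a statement of PyG's exact rules: an omitted dim_size is inferred as max(index) + 1 (or len(ptr) - 1), and empty groups reduce to zeros. In PyG itself, a custom aggregation subclasses Aggregation and implements forward on tensors.

```python
from abc import ABC, abstractmethod
from collections import Counter
from typing import List, Optional

Rows = List[List[float]]


class ToyAggregation(ABC):
    """Hypothetical base class mimicking the index/ptr contract on lists of rows."""

    def __call__(self, x: Rows, index: Optional[List[int]] = None,
                 ptr: Optional[List[int]] = None,
                 dim_size: Optional[int] = None) -> Rows:
        if index is None and ptr is None:
            raise ValueError("One of 'index' or 'ptr' must be defined")
        if index is None:  # this sketch prefers index when both are given
            # Expand CSR boundaries into a sorted per-row group index.
            index = [g for g in range(len(ptr) - 1)
                     for _ in range(ptr[g + 1] - ptr[g])]
            if dim_size is None:
                dim_size = len(ptr) - 1
        if dim_size is None:
            # Assumption for this sketch: infer the number of groups from index.
            dim_size = max(index) + 1 if index else 0
        if any(g < 0 or g >= dim_size for g in index):
            raise ValueError("index contains a group outside [0, dim_size)")
        groups: List[Rows] = [[] for _ in range(dim_size)]
        for row, g in zip(x, index):
            groups[g].append(row)
        num_features = len(x[0]) if x else 0
        # Assumption for this sketch: empty groups reduce to all-zero rows.
        return [self.reduce_group(rows) if rows else [0.0] * num_features
                for rows in groups]

    @abstractmethod
    def reduce_group(self, rows: Rows) -> List[float]:
        """Reduce the non-empty list of rows that belong to one group."""


class ToySum(ToyAggregation):
    def reduce_group(self, rows: Rows) -> List[float]:
        return [sum(col) for col in zip(*rows)]


class ToyMax(ToyAggregation):
    def reduce_group(self, rows: Rows) -> List[float]:
        return [max(col) for col in zip(*rows)]


def group_sizes(index: List[int]) -> Counter:
    """Elements per group; the largest count is what max_num_elements describes."""
    return Counter(index)


x = [[float(i), 1.0] for i in range(10)]
index = [0, 0, 1, 0, 2, 0, 2, 1, 0, 2]

print(ToySum()(x, index))              # [[17.0, 5.0], [9.0, 2.0], [19.0, 3.0]]
print(ToyMax()(x, ptr=[0, 4, 7, 10]))  # [[3.0, 1.0], [6.0, 1.0], [9.0, 1.0]]
print(ToySum()(x, index, dim_size=5))  # groups 3 and 4 are empty -> zero rows
print(max(group_sizes(index).values()))  # 5
```

Keeping the index/ptr handling in the base class means each subclass only writes its per-group reduction; a tensor implementation would vectorise the grouping instead of looping over rows.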