torch_geometric.nn.aggr.SortAggregation
- class SortAggregation(k: int)
Bases: Aggregation
The pooling operator from the “An End-to-End Deep Learning Architecture for Graph Classification” paper. Node features are sorted in descending order based on their last feature channel, and the first \(k\) nodes of each graph form the output of the layer.
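The pooling rule above can be illustrated with a dependency-free sketch. It is not the library's implementation. The helper name `sort_pool`, the list-of-lists input, and the choices to zero-pad graphs with fewer than k nodes and to flatten each graph's k rows into one vector of length k * D are assumptions made for illustration.

```python
from typing import List, Sequence


def sort_pool(x: Sequence[Sequence[float]], index: Sequence[int], k: int) -> List[List[float]]:
    """Hypothetical reference sketch of sort pooling (not the PyG code).

    x:     node features, one row of D floats per node
    index: graph id of each node, in the range 0 .. B-1
    k:     number of nodes kept per graph
    Returns one flattened row of length k * D per graph.
    """
    num_graphs = max(index) + 1 if index else 0
    dim = len(x[0]) if x else 0
    out = []
    for g in range(num_graphs):
        # Collect the feature rows of all nodes that belong to graph g.
        rows = [list(x[i]) for i in range(len(x)) if index[i] == g]
        # Sort nodes in descending order of their last feature channel.
        rows.sort(key=lambda r: r[-1], reverse=True)
        # Keep the first k nodes.
        rows = rows[:k]
        # Zero-pad graphs with fewer than k nodes (assumed padding value).
        rows += [[0.0] * dim for _ in range(k - len(rows))]
        # Flatten the k x D block into a single vector (assumed output layout).
        out.append([v for r in rows for v in r])
    return out
```

For example, with x = [[1.0, 0.2], [2.0, 0.9], [3.0, 0.5], [4.0, 0.1]], index = [0, 0, 0, 1] and k = 2, the call returns [[2.0, 0.9, 3.0, 0.5], [4.0, 0.1, 0.0, 0.0]]: the first graph keeps its two nodes with the largest last channel, and the single-node second graph is zero-padded.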
Note
SortAggregation requires the indices in index to be sorted. Specifically, if you use this aggregation as part of MessagePassing, ensure that edge_index is sorted by destination nodes, either by manually sorting edge indices via sort_edge_index() or by calling torch_geometric.data.Data.sort().
- Parameters:
k (int) – The number of nodes to keep for each graph.
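The sorted-index requirement in the Note and the ptr argument of forward describe the same grouping. Once items are ordered by their group index, each group occupies a contiguous slice, and a CSR pointer array records where each slice starts. The plain-Python sketch below sorts a toy edge list by destination node and derives such a pointer array. The helpers `sort_by_destination` and `index_to_ptr` are hypothetical, and the sketch assumes the default source-to-target flow, where row 1 of edge_index holds destination nodes. In PyG itself, use sort_edge_index() or Data.sort() as the Note describes.

```python
from typing import List, Tuple


def sort_by_destination(edge_index: Tuple[List[int], List[int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: reorder edges so destination ids (row 1) ascend."""
    src, dst = edge_index
    # Stable sort of edge positions by destination node id.
    order = sorted(range(len(dst)), key=lambda e: dst[e])
    return [src[e] for e in order], [dst[e] for e in order]


def index_to_ptr(index: List[int], num_groups: int) -> List[int]:
    """Hypothetical helper: CSR pointer array for an already-sorted index.

    Group g occupies positions ptr[g] .. ptr[g + 1] - 1.
    """
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted in ascending order")
    ptr = [0] * (num_groups + 1)
    for i in index:
        ptr[i + 1] += 1  # count the members of each group
    for g in range(num_groups):
        ptr[g + 1] += ptr[g]  # prefix sum turns counts into start offsets
    return ptr
```

With edge_index = ([0, 1, 2, 3], [2, 0, 2, 1]), sorting yields destinations [0, 1, 2, 2], and index_to_ptr on that list with 3 groups returns [0, 1, 2, 4], so the two edges pointing into node 2 occupy positions 2 and 3.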
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None) → Tensor
Forward pass.
- Parameters:
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)
ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)
dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)
dim (int, optional) – The dimension in which to aggregate. (default: -2)
max_num_elements (int, optional) – The maximum number of elements within a single aggregation group. (default: None)
- Return type: