torch_geometric.Index
- class Index(data: Any, *args: Any, dim_size: Optional[int] = None, is_sorted: bool = False, **kwargs: Any)
Bases: torch.Tensor
A one-dimensional index tensor with additional (meta)data attached.

Index is a torch.Tensor that holds indices of shape [num_indices]. While Index sub-classes a general torch.Tensor, it can hold additional (meta)data, i.e.:
- dim_size: The size of the underlying sparse vector, i.e., the size of the dimension that can be indexed via index. By default, it is inferred as dim_size=index.max() + 1.
- is_sorted: Whether indices are sorted in ascending order.
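The two metadata fields can be modelled in plain Python. This is an illustration under stated assumptions, not PyG code. infer_dim_size and is_sorted_ascending are hypothetical helpers, and lists stand in for tensors. Treating an empty index as having size 0 is also an assumption.

```python
from typing import List, Optional


def infer_dim_size(indices: List[int], dim_size: Optional[int] = None) -> int:
    """Return dim_size, falling back to max(indices) + 1 when it is not given."""
    if dim_size is not None:
        # An explicit size wins; it may exceed max + 1 (trailing empty slots).
        return dim_size
    # Assumption for this sketch: an empty index has size 0.
    return max(indices) + 1 if indices else 0


def is_sorted_ascending(indices: List[int]) -> bool:
    """True if every index is <= its successor (repeated values are allowed)."""
    return all(a <= b for a, b in zip(indices, indices[1:]))


indices = [0, 1, 1, 2]
print(infer_dim_size(indices))       # 3, inferred as 2 + 1
print(infer_dim_size(indices, 5))    # 5, the explicit value is kept
print(is_sorted_ascending(indices))  # True
```

Note that "sorted" here allows repeated values. That matters because an index such as [0, 1, 1, 2] typically contains duplicates.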
Additionally, Index caches data via indptr for fast CSR conversion when its representation is sorted. Caches are filled on demand (e.g., when calling Index.get_indptr()) or when explicitly requested via Index.fill_cache_(), and are maintained and adjusted over its lifespan.

This representation enables efficient computation in GNN message-passing schemes (e.g., via the cached CSR pointers), while preserving the ease of use of regular COO-based PyG workflows.
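Sortedness is what makes the CSR conversion cheap. The pure-Python sketch below shows the compression an indptr represents, plus a segment-wise sum over it, as a message-passing aggregation might perform. index_to_indptr and segment_sum are hypothetical names. This is not how PyG computes or stores its cache, and PyG operates on tensors rather than lists.

```python
from typing import List


def index_to_indptr(indices: List[int], dim_size: int) -> List[int]:
    """Compress sorted indices into a CSR-style pointer list of length dim_size + 1.

    For each position i, indices[indptr[i]:indptr[i + 1]] are exactly the
    entries equal to i.
    """
    if any(a > b for a, b in zip(indices, indices[1:])):
        raise ValueError("indptr is only meaningful for sorted indices")
    counts = [0] * dim_size
    for i in indices:
        if not 0 <= i < dim_size:
            raise ValueError(f"index {i} out of range for dim_size={dim_size}")
        counts[i] += 1
    indptr = [0]
    for c in counts:  # running sum of per-position counts
        indptr.append(indptr[-1] + c)
    return indptr


def segment_sum(values: List[float], indptr: List[int]) -> List[float]:
    """Sum each contiguous segment values[indptr[i]:indptr[i + 1]]."""
    return [sum(values[indptr[i]:indptr[i + 1]]) for i in range(len(indptr) - 1)]


indptr = index_to_indptr([0, 1, 1, 2], dim_size=3)
print(indptr)                                     # [0, 1, 3, 4]
print(segment_sum([1.0, 2.0, 3.0, 4.0], indptr))  # [1.0, 5.0, 4.0]
```

The pointer list depends only on the indices and dim_size. It can therefore be computed once and reused until either of them changes, which is the rationale for caching it.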
    import torch
    from torch_geometric import Index

    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
    # Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
    assert index.dim_size == 3
    assert index.is_sorted

    # Flipping order:
    flipped = index.flip(0)
    # Index([2, 1, 1, 0], dim_size=3)
    assert not flipped.is_sorted

    # Filtering:
    mask = torch.tensor([True, True, True, False])
    filtered = index[mask]
    # Index([0, 1, 1], dim_size=3, is_sorted=True)
    assert filtered.is_sorted
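The metadata rules exercised in this example can be modelled with a small dataclass. SketchIndex is hypothetical and only tracks the list and its flags. It mirrors the behaviour shown in the example (flipping drops is_sorted, boolean masking keeps it and dim_size) rather than PyG's internals.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchIndex:
    """Toy stand-in for Index metadata propagation (hypothetical, not PyG)."""
    data: List[int]
    dim_size: Optional[int] = None
    is_sorted: bool = False

    def flip(self) -> "SketchIndex":
        # Reversing an ascending sequence generally breaks ascending order,
        # so the flag is conservatively dropped; dim_size is unaffected.
        return SketchIndex(self.data[::-1], self.dim_size, is_sorted=False)

    def masked(self, mask: List[bool]) -> "SketchIndex":
        if len(mask) != len(self.data):
            raise ValueError("mask must match the number of indices")
        # Dropping elements keeps the relative order of the survivors,
        # so a sorted input stays sorted.
        kept = [x for x, keep in zip(self.data, mask) if keep]
        return SketchIndex(kept, self.dim_size, self.is_sorted)


index = SketchIndex([0, 1, 1, 2], dim_size=3, is_sorted=True)
print(index.flip())
print(index.masked([True, True, True, False]))
```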
- validate() → Index
Validates the Index representation. In particular, it ensures that (1) it only holds valid indices, and (2) the sort order is correctly set.
- Return type: Index
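Here is a minimal sketch of the two checks described for validate(), on plain lists. The helper validate_index is hypothetical. So is the exact notion of a valid index used here: non-negative, and smaller than dim_size when dim_size is known. PyG's own checks and error messages may differ.

```python
from typing import List, Optional


def validate_index(indices: List[int], dim_size: Optional[int] = None,
                   is_sorted: bool = False) -> None:
    """Raise ValueError if the index data contradicts its metadata."""
    # Check 1 (assumed meaning of "valid indices"): range of every entry.
    if any(i < 0 for i in indices):
        raise ValueError("indices must be non-negative")
    if dim_size is not None and any(i >= dim_size for i in indices):
        raise ValueError(f"indices must be smaller than dim_size={dim_size}")
    # Check 2: a claimed sort order must actually hold.
    if is_sorted and any(a > b for a, b in zip(indices, indices[1:])):
        raise ValueError("is_sorted=True but indices are not in ascending order")


validate_index([0, 1, 1, 2], dim_size=3, is_sorted=True)  # passes silently
```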
- get_dim_size() → int
The size of the underlying sparse vector. Automatically computed and cached when not explicitly set.
- Return type: int
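The computed-and-cached behaviour can be sketched with a class that scans the data only on the first call. LazyDimSize and its num_scans counter are hypothetical. The counter exists only to make the caching observable.

```python
from typing import List, Optional


class LazyDimSize:
    """Hypothetical sketch: dim_size computed on first request, then cached."""

    def __init__(self, indices: List[int], dim_size: Optional[int] = None):
        self.indices = list(indices)
        self._dim_size = dim_size  # None means "not set explicitly"
        self.num_scans = 0         # counts full scans of the data

    def get_dim_size(self) -> int:
        if self._dim_size is None:
            # Scan the data only once; later calls return the cached value.
            self.num_scans += 1
            self._dim_size = max(self.indices) + 1 if self.indices else 0
        return self._dim_size


lazy = LazyDimSize([0, 1, 1, 2])
print(lazy.get_dim_size(), lazy.get_dim_size())  # 3 3
print(lazy.num_scans)                            # 1
```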
- dim_resize_(dim_size: Optional[int]) → Index
Assigns or re-assigns the size of the underlying sparse vector.
- Return type: Index
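When an indptr cache already exists, re-assigning dim_size must keep that cache consistent. The sketch below shows one plausible strategy, assuming the indices are sorted:

- Growing pads the pointer list with its last value, because the new trailing positions are empty.
- Shrinking truncates the pointer list.
- Passing None drops the cache.

ResizableIndex and fill_indptr are hypothetical, as is the choice to reject sizes below max + 1. PyG's actual dim_resize_ may behave differently.

```python
from typing import List, Optional


class ResizableIndex:
    """Hypothetical sketch: sorted indices with an on-demand indptr cache."""

    def __init__(self, indices: List[int], dim_size: int):
        if any(a > b for a, b in zip(indices, indices[1:])):
            raise ValueError("this sketch requires sorted indices")
        if indices and indices[0] < 0:  # sorted, so the first entry is the min
            raise ValueError("indices must be non-negative")
        self.indices = list(indices)
        self.dim_size: Optional[int] = None
        self.indptr: Optional[List[int]] = None  # not filled until requested
        self.dim_resize_(dim_size)  # reuses the size check below

    def fill_indptr(self) -> List[int]:
        """Build the CSR pointer list once and cache it."""
        if self.indptr is None:
            if self.dim_size is None:
                raise ValueError("dim_size is unknown")
            counts = [0] * self.dim_size
            for i in self.indices:
                counts[i] += 1
            indptr = [0]
            for c in counts:
                indptr.append(indptr[-1] + c)
            self.indptr = indptr
        return self.indptr

    def dim_resize_(self, dim_size: Optional[int]) -> "ResizableIndex":
        """Assign or re-assign dim_size in place, adjusting any cached indptr."""
        if dim_size is not None and self.indices and dim_size <= max(self.indices):
            raise ValueError(f"dim_size={dim_size} is too small for the indices")
        if self.indptr is not None:
            if dim_size is None:
                self.indptr = None  # size unknown: drop the cache
            elif len(self.indptr) - 1 >= dim_size:
                # Shrinking (or same size): keep dim_size + 1 pointers.
                self.indptr = self.indptr[:dim_size + 1]
            else:
                # Growing: new trailing positions are empty, so they
                # repeat the last pointer.
                extra = dim_size - (len(self.indptr) - 1)
                self.indptr = self.indptr + [self.indptr[-1]] * extra
        self.dim_size = dim_size
        return self  # returns self, so calls can be chained


idx = ResizableIndex([0, 1, 1, 2], dim_size=3)
print(idx.fill_indptr())          # [0, 1, 3, 4]
print(idx.dim_resize_(5).indptr)  # [0, 1, 3, 4, 4, 4]
```

Adjusting the cache in place avoids recomputing it from scratch. The adjustment is only valid because the indices themselves do not change when dim_size does.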