torch_geometric.Index

class Index(data: Any, *args: Any, dim_size: Optional[int] = None, is_sorted: bool = False, **kwargs: Any)

Bases: Tensor

A one-dimensional index tensor with additional (meta)data attached.

Index is a torch.Tensor that holds indices of shape [num_indices].

While Index sub-classes a general torch.Tensor, it can hold additional (meta)data, i.e.:

  • dim_size: The size of the underlying sparse vector, i.e., the size of the dimension that can be indexed via index. By default, it is inferred as dim_size=index.max() + 1.

  • is_sorted: Whether indices are sorted in ascending order.

Additionally, Index caches data via indptr for fast CSR conversion in case its representation is sorted. Caches are filled on demand (e.g., when calling Index.get_indptr()), or when explicitly requested via Index.fill_cache_(), and are maintained and adjusted over the lifespan of the Index.

This representation enables efficient computation in GNN message passing schemes, while preserving the ease-of-use of regular COO-based workflows.

import torch
from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
# >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
assert index.dim_size == 3
assert index.is_sorted

# Flipping order returns a new Index that is no longer sorted:
flipped = index.flip(0)
# >>> Index([2, 1, 1, 0], dim_size=3)
assert not flipped.is_sorted

# Filtering with a boolean mask keeps the sort order:
mask = torch.tensor([True, True, True, False])
filtered = index[mask]
# >>> Index([0, 1, 1], dim_size=3, is_sorted=True)
assert filtered.is_sorted
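
To make the CSR conversion mentioned above more concrete, the following plain-Python sketch shows one way a compressed pointer array can be derived from a sorted index. It only illustrates the idea and is not the library's implementation; the helper name build_indptr is hypothetical.

```python
from itertools import accumulate


def build_indptr(index, dim_size):
    """Hypothetical helper: compress a sorted index list into CSR pointers."""
    # Count how often each value in [0, dim_size) occurs.
    counts = [0] * dim_size
    for i in index:
        counts[i] += 1
    # Prefix sums give the start offset of every value's segment.
    return [0] + list(accumulate(counts))


indptr = build_indptr([0, 1, 1, 2], dim_size=3)
print(indptr)  # [0, 1, 3, 4]
```

The result has dim_size + 1 entries, and entries k and k + 1 delimit the positions at which the value k occurs in the sorted index.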

validate() → Index

Validates the Index representation.

In particular, it ensures that

  • it only holds valid indices.

  • the sort order is correctly set.

Return type:

Index
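
As a rough illustration of the two checks listed above, the plain-Python sketch below rejects out-of-range indices and an incorrect sort flag. The function validate_index and its error messages are hypothetical; the actual checks and exceptions in the library may differ.

```python
def validate_index(index, dim_size=None, is_sorted=False):
    """Hypothetical stand-in for the checks described above."""
    # Valid indices are non-negative ...
    if any(i < 0 for i in index):
        raise ValueError("index contains negative values")
    # ... and strictly smaller than dim_size, if one is given.
    if dim_size is not None and any(i >= dim_size for i in index):
        raise ValueError("index contains values >= dim_size")
    # A sorted flag must match the actual (ascending) order.
    if is_sorted and any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index is flagged as sorted but is not")
    return index


validate_index([0, 1, 1, 2], dim_size=3, is_sorted=True)  # passes silently
```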

property dim_size: Optional[int]

The size of the underlying sparse vector.

Return type:

Optional[int]

property is_sorted: bool

Returns whether indices are sorted in ascending order.

Return type:

bool

get_dim_size() → int

Returns the size of the underlying sparse vector. It is automatically computed and cached when not explicitly set.

Return type:

int
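
The compute-once-then-cache behavior can be pictured with the plain-Python sketch below. The class LazyDimSize is hypothetical and only mirrors the description above (inference as max + 1); treating an empty index as size 0 is an assumption of this sketch, not documented behavior.

```python
class LazyDimSize:
    """Hypothetical container illustrating lazy inference with caching."""

    def __init__(self, index, dim_size=None):
        self.index = list(index)
        self._dim_size = dim_size  # None means "not yet known"

    def get_dim_size(self):
        if self._dim_size is None:
            # Inferred as max + 1; an empty index yields 0 (assumption).
            self._dim_size = max(self.index) + 1 if self.index else 0
        return self._dim_size


lazy = LazyDimSize([0, 1, 1, 2])
print(lazy.get_dim_size())  # 3
```

An explicitly provided size is returned unchanged, so the scan over the indices only happens when no size was set.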

dim_resize_(dim_size: Optional[int]) → Index

Assigns or re-assigns the size of the underlying sparse vector.

Return type:

Index

get_indptr() → Tensor

Returns the compressed index representation if the Index is sorted.

Return type:

Tensor
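
One reason a compressed representation is useful is that it splits per-entry data into contiguous segments, one per index value. The plain-Python sketch below sums the values of each segment; it assumes the usual CSR convention of dim_size + 1 pointers, and the helper segment_sum is hypothetical rather than part of the library.

```python
def segment_sum(values, indptr):
    """Sum the values of each contiguous segment delimited by indptr."""
    return [sum(values[start:end]) for start, end in zip(indptr, indptr[1:])]


# One value per entry of the sorted index [0, 1, 1, 2]:
values = [10.0, 1.0, 2.0, 5.0]
indptr = [0, 1, 3, 4]
print(segment_sum(values, indptr))  # [10.0, 3.0, 5.0]
```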

fill_cache_() → Index

Fills the cache with (meta)data, e.g., indptr when the index is sorted.

Return type:

Index

as_tensor() → Tensor

Returns the Index as a plain torch.Tensor without copying the underlying data.

Return type:

Tensor