torch_geometric.Index

class Index(data: Any, *args: Any, dim_size: Optional[int] = None, is_sorted: bool = False, **kwargs: Any)

Bases: Tensor

A one-dimensional index tensor with additional (meta)data attached.

Index is a torch.Tensor that holds indices of shape [num_indices].

While Index subclasses a general torch.Tensor, it can hold additional (meta)data, namely:

  • dim_size: The size of the underlying sparse vector, i.e., the size of the dimension that can be indexed via index. By default, it is inferred as dim_size=index.max() + 1.

  • is_sorted: Whether indices are sorted in ascending order.
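Both pieces of metadata can be derived from the raw indices alone. The plain-Python sketch below restates the two rules above; the helpers infer_dim_size and is_sorted_ascending are hypothetical and not part of PyG, and returning 0 for an empty index is an assumption of this sketch. Note that "ascending" means non-decreasing here, so repeated indices such as [0, 1, 1, 2] still count as sorted.

```python
from typing import Sequence


def infer_dim_size(indices: Sequence[int]) -> int:
    """Default dim_size: one more than the largest index (hypothetical helper)."""
    if len(indices) == 0:
        # Assumption of this sketch: an empty index spans a dimension of size 0.
        return 0
    return max(indices) + 1


def is_sorted_ascending(indices: Sequence[int]) -> bool:
    """True if the indices are non-decreasing; duplicates are allowed."""
    return all(a <= b for a, b in zip(indices, indices[1:]))


print(infer_dim_size([0, 1, 1, 2]))       # 3
print(is_sorted_ascending([0, 1, 1, 2]))  # True
print(is_sorted_ascending([2, 1, 1, 0]))  # False
```

An explicitly passed dim_size only needs to exceed the largest index, so it may be larger than this default, e.g., when trailing entries of the sparse vector are never indexed.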

Additionally, Index caches data via indptr for fast CSR conversion in case its representation is sorted. Caches are filled on demand (e.g., when calling Index.get_indptr()) or when explicitly requested via Index.fill_cache_(), and are maintained and adjusted over its lifespan.
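The on-demand caching described above can be sketched in plain Python. The class CachedSortedIndex, its num_builds counter, and the choice to reject unsorted input at construction are all hypothetical; they only mirror the documented behavior of get_indptr() and fill_cache_(), while PyG's actual implementation works on tensors and differs in detail.

```python
from typing import List, Optional


class CachedSortedIndex:
    """Hypothetical stand-in for a sorted Index with a lazily filled indptr cache."""

    def __init__(self, indices: List[int], dim_size: int) -> None:
        # This sketch only models the sorted case, where indptr is defined.
        if any(a > b for a, b in zip(indices, indices[1:])):
            raise ValueError("indptr is only defined for sorted indices")
        self.indices = indices
        self.dim_size = dim_size
        self._indptr: Optional[List[int]] = None  # empty until first requested
        self.num_builds = 0  # counts how often the cache was computed

    def _build_indptr(self) -> List[int]:
        # Count the occurrences of each position 0..dim_size-1,
        counts = [0] * self.dim_size
        for i in self.indices:
            counts[i] += 1
        # then prefix-sum them so that indptr[j] = number of indices smaller than j.
        indptr = [0]
        for c in counts:
            indptr.append(indptr[-1] + c)
        return indptr

    def get_indptr(self) -> List[int]:
        # Filled on demand: computed on the first call, reused afterwards.
        if self._indptr is None:
            self._indptr = self._build_indptr()
            self.num_builds += 1
        return self._indptr

    def fill_cache_(self) -> "CachedSortedIndex":
        # Explicit, eager filling; returns self to allow chaining.
        self.get_indptr()
        return self


index = CachedSortedIndex([0, 1, 1, 2], dim_size=3)
print(index.get_indptr())  # [0, 1, 3, 4]
index.get_indptr()
print(index.num_builds)    # 1
```

Caching pays off because message passing typically requests the compressed form repeatedly for the same index, so the counting pass runs only once.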

This representation ensures optimal computation in GNN message-passing schemes while preserving the ease of use of regular COO-based workflows.

import torch
from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
# Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
assert index.dim_size == 3
assert index.is_sorted

# Flipping order:
flipped = index.flip(0)
# Index([2, 1, 1, 0], dim_size=3)
assert not flipped.is_sorted

# Filtering:
mask = torch.tensor([True, True, True, False])
filtered = index[mask]
# Index([0, 1, 1], dim_size=3, is_sorted=True)
assert filtered.is_sorted

validate() Index

Validates the Index representation.

In particular, it ensures that

  • it only holds valid indices.

  • the sort order is correctly set.
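The two checks can be restated as a plain-Python sketch. The function validate_indices is hypothetical, raising ValueError is this sketch's own choice, and the exact checks and messages of Index.validate() may differ. The sketch also assumes that flagging sorted data as unsorted is acceptable, since that only forgoes optimizations; only a sorted flag that contradicts the data is rejected.

```python
from typing import Optional, Sequence


def validate_indices(
    indices: Sequence[int], dim_size: Optional[int], is_sorted: bool
) -> None:
    """Hypothetical restatement of the checks listed above; raises ValueError."""
    # 1. Valid indices: non-negative and, if a size is registered, below dim_size.
    if any(i < 0 for i in indices):
        raise ValueError("index contains negative entries")
    if dim_size is not None and any(i >= dim_size for i in indices):
        raise ValueError(f"index contains entries >= dim_size={dim_size}")
    # 2. Sort order: a sorted flag must agree with the actual data.
    if is_sorted and any(a > b for a, b in zip(indices, indices[1:])):
        raise ValueError("index is flagged as sorted but is not")


validate_indices([0, 1, 1, 2], dim_size=3, is_sorted=True)  # passes silently
```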

property dim_size: Optional[int]

The size of the underlying sparse vector.

property is_sorted: bool

Returns whether indices are sorted in ascending order.

get_dim_size() int

The size of the underlying sparse vector. Automatically computed and cached when not explicitly set.

dim_resize_(dim_size: Optional[int]) Index

Assigns or re-assigns the size of the underlying sparse vector.

get_indptr() Tensor

Returns the compressed (CSR-style indptr) representation of the index; only available in case the Index is sorted.
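For a sorted index x of length n with dim_size N, the compressed representation can be written as N + 1 offsets, where entry j counts the indices smaller than j. The notation x, n, N is introduced here for illustration; this is the standard CSR row-pointer definition rather than a quote from the PyG source, and the worked example reuses the index from the example above.

```latex
\mathrm{indptr}_j = \bigl|\{\, k : x_k < j \,\}\bigr|, \qquad j = 0, \dots, N,
\qquad \text{so } \mathrm{indptr}_0 = 0 \text{ and } \mathrm{indptr}_N = n.

x = [0, 1, 1, 2],\ N = 3 \;\Rightarrow\; \mathrm{indptr} = [0, 1, 3, 4].
```

All occurrences of value j then occupy positions indptr_j through indptr_{j+1} - 1; for j = 1 these are positions 1 and 2.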

fill_cache_() Index

Fills the cache with (meta)data information.

as_tensor() Tensor

Converts the Index back to a plain torch.Tensor without copying the underlying data (zero-copy).