Source code for torch_geometric.index

import functools
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Type,
    Union,
)

import numpy as np
import torch
import torch.utils._pytree as pytree
from torch import Tensor

from torch_geometric.typing import INDEX_DTYPES

aten = torch.ops.aten

HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}


def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
    index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
    return index.repeat_interleave(ptr.diff(), output_size=output_size)


def index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor:
    if size is None:
        size = int(index.max()) + 1 if index.numel() > 0 else 0

    return torch._convert_indices_from_coo_to_csr(
        index, size, out_int32=index.dtype != torch.int64)
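
# Example (editor's illustration, not part of the module): `ptr2index` expands
# a CSR pointer vector into a COO index vector, and `index2ptr` performs the
# inverse compression for a sorted index:
#
#     ptr = torch.tensor([0, 2, 3, 5])
#     index = ptr2index(ptr)  # tensor([0, 0, 1, 2, 2])
#     assert index2ptr(index, size=3).tolist() == [0, 2, 3, 5]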


class CatMetadata(NamedTuple):
    nnz: List[int]
    dim_size: List[Optional[int]]
    is_sorted: List[bool]


def implements(torch_function: Callable) -> Callable:
    r"""Registers a :pytorch:`PyTorch` function override."""
    @functools.wraps(torch_function)
    def decorator(my_function: Callable) -> Callable:
        HANDLED_FUNCTIONS[torch_function] = my_function
        return my_function

    return decorator
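
# Note (editor's sketch): handlers registered via `implements` are looked up in
# `Index.__torch_dispatch__` below. A hypothetical override would look like:
#
#     @implements(aten.neg.default)
#     def _neg(tensor: 'Index') -> 'Index':
#         return Index(aten.neg.default(tensor._data))
#
# The actual overrides (e.g., `_clone`, `_cat`, `_sort`) follow further below.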


def assert_valid_dtype(tensor: Tensor) -> None:
    if tensor.dtype not in INDEX_DTYPES:
        raise ValueError(f"'Index' holds an unsupported data type "
                         f"(got '{tensor.dtype}', but expected one of "
                         f"{INDEX_DTYPES})")


def assert_one_dimensional(tensor: Tensor) -> None:
    if tensor.dim() != 1:
        raise ValueError(f"'Index' needs to be one-dimensional "
                         f"(got {tensor.dim()} dimensions)")


def assert_contiguous(tensor: Tensor) -> None:
    if not tensor.is_contiguous():
        raise ValueError("'Index' needs to be contiguous. Please call "
                         "`index.contiguous()` before proceeding.")


def assert_sorted(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:
        if not self.is_sorted:
            cls_name = self.__class__.__name__
            raise ValueError(
                f"Cannot call '{func.__name__}' since '{cls_name}' is not "
                f"sorted. Please call `{cls_name}.sort()` first.")
        return func(self, *args, **kwargs)

    return wrapper


class Index(Tensor):
    r"""A one-dimensional :obj:`index` tensor with additional (meta)data
    attached.

    :class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
    indices of shape :obj:`[num_indices]`.

    While :class:`Index` sub-classes a general :pytorch:`null`
    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:

    * :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*, the
      size of a dimension that can be indexed via :obj:`index`.
      By default, it is inferred as :obj:`dim_size=index.max() + 1`.
    * :obj:`is_sorted`: Whether indices are sorted in ascending order.

    Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
    conversion in case its representation is sorted.
    Caches are filled based on demand (*e.g.*, when calling
    :meth:`Index.get_indptr`), or when explicitly requested via
    :meth:`Index.fill_cache_`, and are maintained and adjusted over its
    lifespan.

    This representation ensures optimal computation in GNN message passing
    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
    workflows.

    .. code-block:: python

        from torch_geometric import Index

        index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
        >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
        assert index.dim_size == 3
        assert index.is_sorted

        # Flipping order:
        index.flip(0)
        >>> Index([2, 1, 1, 0], dim_size=3)
        assert not index.is_sorted

        # Filtering:
        mask = torch.tensor([True, True, True, False])
        index[mask]
        >>> Index([0, 1, 1], dim_size=3, is_sorted=True)
        assert index.is_sorted
    """
    # See "https://pytorch.org/docs/stable/notes/extending.html"
    # for a basic tutorial on how to subclass `torch.Tensor`.

    # The underlying tensor representation:
    _data: Tensor

    # The size of the underlying sparse vector, e.g. `_data.max() + 1`:
    _dim_size: Optional[int] = None

    # Whether the `index` representation is sorted:
    _is_sorted: bool = False

    # A cache for its compressed representation:
    _indptr: Optional[Tensor] = None

    # Whenever we perform a concatenation of indices, we cache the original
    # metadata to be able to reconstruct individual indices:
    _cat_metadata: Optional[CatMetadata] = None

    @staticmethod
    def __new__(
        cls: Type,
        data: Any,
        *args: Any,
        dim_size: Optional[int] = None,
        is_sorted: bool = False,
        **kwargs: Any,
    ) -> 'Index':
        if not isinstance(data, Tensor):
            data = torch.tensor(data, *args, **kwargs)
        elif len(args) > 0:
            raise TypeError(
                f"new() received an invalid combination of arguments - got "
                f"(Tensor, {', '.join(str(type(arg)) for arg in args)})")
        elif len(kwargs) > 0:
            raise TypeError(f"new() received invalid keyword arguments - got "
                            f"{set(kwargs.keys())})")

        assert isinstance(data, Tensor)

        indptr: Optional[Tensor] = None

        if isinstance(data, cls):  # If passed `Index`, inherit metadata:
            indptr = data._indptr
            dim_size = dim_size or data.dim_size
            is_sorted = is_sorted or data.is_sorted

        assert_valid_dtype(data)
        assert_one_dimensional(data)
        assert_contiguous(data)

        out = Tensor._make_wrapper_subclass(  # type: ignore
            cls,
            size=data.size(),
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            layout=data.layout,
            requires_grad=False,
        )
        assert isinstance(out, Index)

        # Attach metadata:
        out._data = data
        out._dim_size = dim_size
        out._is_sorted = is_sorted
        out._indptr = indptr

        if isinstance(data, cls):
            out._data = data._data

            # Reset metadata if cache is invalidated:
            if dim_size is not None and dim_size != data.dim_size:
                out._indptr = None

        return out
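
    # Example (editor's illustration): wrapping an existing `Index` inherits
    # its metadata unless explicitly overridden:
    #
    #     base = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
    #     copy = Index(base)
    #     assert copy.dim_size == 3 and copy.is_sorted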

    # Validation ##############################################################

    def validate(self) -> 'Index':
        r"""Validates the :class:`Index` representation.

        In particular, it ensures that

        * it only holds valid indices.
        * the sort order is correctly set.
        """
        assert_valid_dtype(self._data)
        assert_one_dimensional(self._data)
        assert_contiguous(self._data)

        if self.numel() > 0 and self._data.min() < 0:
            raise ValueError(f"'{self.__class__.__name__}' contains negative "
                             f"indices (got {int(self.min())})")

        if (self.numel() > 0 and self.dim_size is not None
                and self._data.max() >= self.dim_size):
            raise ValueError(f"'{self.__class__.__name__}' contains larger "
                             f"indices than its registered size "
                             f"(got {int(self._data.max())}, but expected "
                             f"values smaller than {self.dim_size})")

        if self.is_sorted and (self._data.diff() < 0).any():
            raise ValueError(f"'{self.__class__.__name__}' is not sorted")

        return self
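
    # Example (editor's illustration): `validate` flags out-of-range indices:
    #
    #     Index([0, 1, 2], dim_size=3).validate()  # OK
    #     Index([0, 1, 3], dim_size=3).validate()  # raises ValueError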

    # Properties ##############################################################

    @property
    def dim_size(self) -> Optional[int]:
        r"""The size of the underlying sparse vector."""
        return self._dim_size

    @property
    def is_sorted(self) -> bool:
        r"""Returns whether indices are sorted in ascending order."""
        return self._is_sorted

    @property
    def dtype(self) -> torch.dtype:  # type: ignore
        # TODO Remove once PyTorch does not override `dtype` in `DataLoader`.
        return self._data.dtype

    # Cache Interface #########################################################

    def get_dim_size(self) -> int:
        r"""The size of the underlying sparse vector.
        Automatically computed and cached when not explicitly set.
        """
        if self._dim_size is None:
            dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
            self._dim_size = dim_size

        assert isinstance(self._dim_size, int)
        return self._dim_size
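
    # Example (editor's illustration): when `dim_size` is not given, it is
    # lazily inferred as `index.max() + 1` and then cached:
    #
    #     index = Index([0, 1, 1, 2])
    #     assert index.dim_size is None
    #     assert index.get_dim_size() == 3
    #     assert index.dim_size == 3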

    def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
        r"""Assigns or re-assigns the size of the underlying sparse vector."""
        if self.is_sorted and self._indptr is not None:
            if dim_size is None:
                self._indptr = None

            elif self._indptr.numel() - 1 >= dim_size:
                self._indptr = self._indptr[:dim_size + 1]

            else:
                fill_value = self._indptr.new_full(
                    (dim_size - self._indptr.numel() + 1, ),
                    fill_value=self._indptr[-1],  # type: ignore
                )
                self._indptr = torch.cat([self._indptr, fill_value], dim=0)

        self._dim_size = dim_size

        return self
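
    # Example (editor's illustration): resizing keeps a cached `indptr` in
    # sync by truncating it or padding it with its last value:
    #
    #     index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
    #     index.get_indptr()   # tensor([0, 1, 3, 4])
    #     index.dim_resize_(5)
    #     index.get_indptr()   # tensor([0, 1, 3, 4, 4, 4])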

    @assert_sorted
    def get_indptr(self) -> Tensor:
        r"""Returns the compressed index representation in case :class:`Index`
        is sorted.
        """
        if self._indptr is None:
            self._indptr = index2ptr(self._data, self.get_dim_size())

        assert isinstance(self._indptr, Tensor)
        return self._indptr
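
    # Example (editor's illustration): the compressed representation is only
    # available for sorted indices:
    #
    #     index = Index([0, 1, 1, 2], is_sorted=True)
    #     index.get_indptr()                # tensor([0, 1, 3, 4])
    #     Index([1, 0, 2, 1]).get_indptr()  # raises ValueError (not sorted)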

    def fill_cache_(self) -> 'Index':
        r"""Fills the cache with (meta)data information."""
        self.get_dim_size()

        if self.is_sorted:
            self.get_indptr()

        return self

    # Methods #################################################################

    def share_memory_(self) -> 'Index':
        """"""  # noqa: D419
        self._data.share_memory_()
        if self._indptr is not None:
            self._indptr.share_memory_()
        return self

    def is_shared(self) -> bool:
        """"""  # noqa: D419
        return self._data.is_shared()

    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`Index` representation back to a
        :class:`torch.Tensor` representation.
        """
        return self._data

    # PyTorch/Python builtins #################################################

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
        attrs = ['_data']
        if self._indptr is not None:
            attrs.append('_indptr')

        ctx = (
            self._dim_size,
            self._is_sorted,
            self._cat_metadata,
        )

        return attrs, ctx

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: Dict[str, Any],
        ctx: Tuple[Any, ...],
        outer_size: Tuple[int, ...],
        outer_stride: Tuple[int, ...],
    ) -> 'Index':
        index = Index(
            inner_tensors['_data'],
            dim_size=ctx[0],
            is_sorted=ctx[1],
        )

        index._indptr = inner_tensors.get('_indptr', None)
        index._cat_metadata = ctx[2]

        return index

    # Prevent auto-wrapping outputs back into the proper subclass type:
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(
        cls: Type,
        func: Callable[..., Any],
        types: Iterable[Type[Any]],
        args: Iterable[Tuple[Any, ...]] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        # `Index` should be treated as a regular PyTorch tensor for all
        # standard PyTorch functionalities. However,
        # * some of its metadata can be transferred to new functions, e.g.,
        #   `torch.narrow()` can inherit the `is_sorted` property.
        # * not all operations lead to valid `Index` tensors again, e.g.,
        #   `torch.sum()` does not yield an `Index` as its output, or
        #   `torch.stack()` violates the `[*]` shape assumption.

        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that
        # implement specific functions for valid `Index` routines.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

        # For all other PyTorch functions, we treat them as vanilla tensors.
        args = pytree.tree_map_only(Index, lambda x: x._data, args)
        if kwargs is not None:
            kwargs = pytree.tree_map_only(Index, lambda x: x._data, kwargs)
        return func(*args, **(kwargs or {}))

    def __repr__(self) -> str:  # type: ignore
        prefix = f'{self.__class__.__name__}('
        indent = len(prefix)
        tensor_str = torch._tensor_str._tensor_str(self._data, indent)

        suffixes = []
        if self.dim_size is not None:
            suffixes.append(f'dim_size={self.dim_size}')
        if (self.device.type != torch._C._get_default_device()
                or (self.device.type == 'cuda'
                    and torch.cuda.current_device() != self.device.index)
                or (self.device.type == 'mps')):
            suffixes.append(f"device='{self.device}'")
        if self.dtype != torch.int64:
            suffixes.append(f'dtype={self.dtype}')
        if self.is_sorted:
            suffixes.append('is_sorted=True')

        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
                                               indent, force_newline=False)

    def tolist(self) -> List[Any]:
        """"""  # noqa: D419
        return self._data.tolist()

    def numpy(self, *, force: bool = False) -> np.ndarray:
        """"""  # noqa: D419
        return self._data.numpy(force=force)

    # Helpers #################################################################

    def _shallow_copy(self) -> 'Index':
        out = Index(self._data)
        out._dim_size = self._dim_size
        out._is_sorted = self._is_sorted
        out._indptr = self._indptr
        out._cat_metadata = self._cat_metadata
        return out

    def _clear_metadata(self) -> 'Index':
        self._dim_size = None
        self._is_sorted = False
        self._indptr = None
        self._cat_metadata = None
        return self
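
# Example (editor's illustration): operations with a registered handler return
# an `Index` and propagate metadata, while all other operations fall back to
# plain tensors:
#
#     index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
#     out = index.flip(0)  # Index([2, 1, 1, 0], dim_size=3), no longer sorted
#     tmp = index.sum()    # plain `torch.Tensor` (no handler registered)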

def apply_(
    tensor: Index,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> Union[Index, Tensor]:

    data = fn(tensor._data, *args, **kwargs)

    if data.dtype not in INDEX_DTYPES:
        return data

    if tensor._data.data_ptr() != data.data_ptr():
        out = Index(data)
    else:  # In-place:
        tensor._data = data
        out = tensor

    # Copy metadata:
    out._dim_size = tensor._dim_size
    out._is_sorted = tensor._is_sorted
    out._cat_metadata = tensor._cat_metadata

    # Convert cache:
    if tensor._indptr is not None:
        out._indptr = fn(tensor._indptr, *args, **kwargs)

    return out


@implements(aten.clone.default)
def _clone(
    tensor: Index,
    *,
    memory_format: torch.memory_format = torch.preserve_format,
) -> Index:
    out = apply_(tensor, aten.clone.default, memory_format=memory_format)
    assert isinstance(out, Index)
    return out


@implements(aten._to_copy.default)
def _to_copy(
    tensor: Index,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    non_blocking: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> Union[Index, Tensor]:
    return apply_(
        tensor,
        aten._to_copy.default,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        non_blocking=non_blocking,
        memory_format=memory_format,
    )


@implements(aten.alias.default)
def _alias(tensor: Index) -> Index:
    return tensor._shallow_copy()


@implements(aten._pin_memory.default)
def _pin_memory(tensor: Index) -> Index:
    out = apply_(tensor, aten._pin_memory.default)
    assert isinstance(out, Index)
    return out


@implements(aten.sort.default)
def _sort(
    tensor: Index,
    dim: int = -1,
    descending: bool = False,
) -> Tuple[Index, Tensor]:

    if tensor.is_sorted and not descending:
        return tensor, torch.arange(tensor._data.numel(),
                                    device=tensor._data.device)

    data, perm = aten.sort.default(tensor._data, dim, descending)

    out = Index(data)
    out._dim_size = tensor._dim_size

    if not descending:
        out._is_sorted = True

    return out, perm


@implements(aten.sort.stable)
def _sort_stable(
    tensor: Index,
    *,
    stable: bool = False,
    dim: int = -1,
    descending: bool = False,
) -> Tuple[Index, Tensor]:

    if tensor.is_sorted and not descending:
        return tensor, torch.arange(tensor._data.numel(),
                                    device=tensor._data.device)

    data, perm = aten.sort.stable(tensor._data, stable=stable, dim=dim,
                                  descending=descending)

    out = Index(data)
    out._dim_size = tensor._dim_size

    if not descending:
        out._is_sorted = True

    return out, perm


@implements(aten.cat.default)
def _cat(
    tensors: List[Union[Index, Tensor]],
    dim: int = 0,
) -> Union[Index, Tensor]:

    data_list = pytree.tree_map_only(Index, lambda x: x._data, tensors)
    data = aten.cat.default(data_list, dim=dim)

    if any([not isinstance(tensor, Index) for tensor in tensors]):
        return data

    out = Index(data)

    nnz_list = [t.numel() for t in tensors]
    dim_size_list = [t.dim_size for t in tensors]  # type: ignore
    is_sorted_list = [t.is_sorted for t in tensors]  # type: ignore

    # Post-process `dim_size`:
    total_dim_size: Optional[int] = 0
    for dim_size in dim_size_list:
        if dim_size is None:
            total_dim_size = None
            break
        assert isinstance(total_dim_size, int)
        total_dim_size = max(dim_size, total_dim_size)

    out._dim_size = total_dim_size

    out._cat_metadata = CatMetadata(
        nnz=nnz_list,
        dim_size=dim_size_list,
        is_sorted=is_sorted_list,
    )

    return out


@implements(aten.flip.default)
def _flip(
    input: Index,
    dims: Union[List[int], Tuple[int, ...]],
) -> Index:

    data = aten.flip.default(input._data, dims)

    out = Index(data)
    out._dim_size = input.dim_size

    return out
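
# Example (editor's illustration): `torch.sort` and `torch.cat` keep `Index`
# metadata alive, and concatenation additionally records per-tensor metadata
# in `_cat_metadata`:
#
#     a = Index([1, 0, 1], dim_size=2)
#     out, perm = a.sort()          # out.is_sorted is True
#     b = Index([0, 2], dim_size=3, is_sorted=True)
#     cat = torch.cat([out, b])     # cat.dim_size == 3
#     assert cat._cat_metadata.nnz == [3, 2]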

@implements(aten.index_select.default)
def _index_select(
    input: Union[Index, Tensor],
    dim: int,
    index: Union[Index, Tensor],
) -> Union[Index, Tensor]:

    out = aten.index_select.default(
        input._data if isinstance(input, Index) else input,
        dim,
        index._data if isinstance(index, Index) else index,
    )

    if isinstance(input, Index):
        out = Index(out)
        out._dim_size = input.dim_size

    return out


@implements(aten.slice.Tensor)
def _slice(
    input: Index,
    dim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
) -> Index:

    if ((start is None or start <= 0 or start <= -input.size(dim))
            and (end is None or end > input.size(dim)) and step == 1):
        return input._shallow_copy()  # No-op.

    data = aten.slice.Tensor(input._data, dim, start, end, step)

    if step != 1:
        data = data.contiguous()

    out = Index(data)
    out._dim_size = input.dim_size
    # NOTE We could potentially maintain the `indptr` attribute here,
    # but it is not really clear if this is worth it. The most important
    # information `is_sorted` needs to be maintained though:
    if step >= 0:
        out._is_sorted = input.is_sorted

    return out


@implements(aten.index.Tensor)
def _index(
    input: Union[Index, Tensor],
    indices: List[Optional[Union[Tensor, Index]]],
) -> Union[Index, Tensor]:

    if not isinstance(input, Index):
        indices = pytree.tree_map_only(Index, lambda x: x._data, indices)
        return aten.index.Tensor(input, indices)

    data = aten.index.Tensor(input._data, indices)

    if data.dim() != 1:
        return data

    assert len(indices) == 1
    index = indices[0]
    assert index is not None

    out = Index(data)

    if index.dtype in (torch.bool, torch.uint8):  # 1. `index[mask]`.
        out._dim_size = input.dim_size
        out._is_sorted = input.is_sorted

    else:  # 2. `index[index]`.
        out._dim_size = input.dim_size

    return out


@implements(aten.add.Tensor)
def _add(
    input: Union[int, Tensor, Index],
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:

    data = aten.add.Tensor(
        input._data if isinstance(input, Index) else input,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    if isinstance(input, Tensor) and input.numel() <= 1:
        input = int(input)

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        assert isinstance(input, Index)
        if input.dim_size is not None:
            out._dim_size = input.dim_size + alpha * other
        out._is_sorted = input.is_sorted

    elif isinstance(input, int):
        assert isinstance(other, Index)
        if other.dim_size is not None:
            out._dim_size = input + alpha * other.dim_size
        out._is_sorted = other.is_sorted

    elif isinstance(input, Index) and isinstance(other, Index):
        if input.dim_size is not None and other.dim_size is not None:
            out._dim_size = input.dim_size + alpha * other.dim_size

    return out


@implements(aten.add_.Tensor)
def add_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:

    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.add_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size + alpha * other
        input._is_sorted = is_sorted

    elif isinstance(other, Index):
        if dim_size is not None and other.dim_size is not None:
            input._dim_size = dim_size + alpha * other.dim_size

    return input
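
# Example (editor's illustration): slicing and boolean masking preserve the
# sort order, while adding an integer offset also shifts `dim_size`:
#
#     index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
#     index[1:3]    # Index([1, 1], dim_size=3, is_sorted=True)
#     mask = torch.tensor([True, True, True, False])
#     index[mask]   # Index([0, 1, 1], dim_size=3, is_sorted=True)
#     index + 2     # Index([2, 3, 3, 4], dim_size=5, is_sorted=True)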

@implements(aten.sub.Tensor)
def _sub(
    input: Union[int, Tensor, Index],
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Union[Index, Tensor]:

    data = aten.sub.Tensor(
        input._data if isinstance(input, Index) else input,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if data.dtype not in INDEX_DTYPES:
        return data
    if data.dim() != 1:
        return data

    out = Index(data)

    if not isinstance(input, Index):
        return out

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if input.dim_size is not None:
            out._dim_size = input.dim_size - alpha * other
        out._is_sorted = input.is_sorted

    return out


@implements(aten.sub_.Tensor)
def sub_(
    input: Index,
    other: Union[int, Tensor, Index],
    *,
    alpha: int = 1,
) -> Index:

    dim_size = input.dim_size
    is_sorted = input.is_sorted
    input._clear_metadata()

    aten.sub_.Tensor(
        input._data,
        other._data if isinstance(other, Index) else other,
        alpha=alpha,
    )

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        if dim_size is not None:
            input._dim_size = dim_size - alpha * other
        input._is_sorted = is_sorted

    return input
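
# Example (editor's illustration): subtracting an integer offset shifts
# `dim_size` in the opposite direction and keeps the sort order:
#
#     index = Index([2, 3, 3, 4], dim_size=5, is_sorted=True)
#     index - 2      # Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
#     index.sub_(2)  # in-place variant with the same metadata handling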