torch_geometric.HashTensor

class HashTensor(key: Any, value: Optional[Any] = None, *, dtype: Optional[dtype] = None, device: Optional[device] = None)

Bases: Tensor

A torch.Tensor that can be referenced by arbitrary keys rather than indices in the first dimension.

HashTensor subclasses torch.Tensor and extends it with CPU- and GPU-accelerated mapping routines. This allows fast and efficient access to non-contiguous indices/keys while the underlying data is stored in a compact format.
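To make the idea of mapping non-contiguous keys to compact row positions concrete, here is a minimal sketch in plain PyTorch. It deliberately swaps in a different technique (sort the keys once, then binary-search with torch.searchsorted) instead of the hash maps HashTensor relies on, and the helper name map_keys is hypothetical. It assumes a non-empty 1-D key tensor without duplicates.

```python
import torch


def map_keys(key: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """Return the position of each entry of `query` in `key`, or -1 if absent.

    Hypothetical helper: a sort + binary-search sketch, NOT the hash-based
    implementation used by HashTensor. Assumes a non-empty 1-D `key`
    without duplicates.
    """
    # Sort once; `perm` remembers where each sorted key originally lived.
    sorted_key, perm = torch.sort(key)
    # Binary search: candidate slot of every query in the sorted keys.
    idx = torch.searchsorted(sorted_key, query)
    # Queries larger than every key land one past the end; clamp them.
    idx = idx.clamp(max=sorted_key.numel() - 1)
    # A query is found only if the candidate slot holds exactly that key.
    found = sorted_key[idx] == query
    # Translate sorted slots back to original positions, -1 for misses.
    return torch.where(found, perm[idx], torch.full_like(idx, -1))


key = torch.tensor([1000, 100, 10000])
query = torch.tensor([10000, 1000, 0])
print(map_keys(key, query))  # tensor([ 2,  0, -1])
```

This variant costs one O(n log n) sort plus an O(log n) search per query; a hash-based map avoids both the sort and the logarithmic lookup.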

This representation is well suited to scenarios that need a fast mapping routine without relying on external, CPU-only packages. It can be used, e.g., to map global indices to local indices during subgraph creation, or in data-processing pipelines to map non-contiguous input data into a contiguous space, such as

  • mapping of hashed node IDs to the range [0, num_nodes - 1]

  • mapping of raw input data, e.g., categorical data, to the range [0, num_categories - 1]

In particular, HashTensor supports keys of any type, e.g., strings or timestamps.

import torch
from torch_geometric import HashTensor

key = torch.tensor([1000, 100, 10000])
value = torch.randn(3, 4)

tensor = HashTensor(key, value)
assert tensor.size() == (3, 4)

# Filtering: rows are returned in the order of `query`.
query = torch.tensor([10000, 1000])
out = tensor[query]
assert out.equal(value[[2, 0]])

# Accessing non-existing keys: rows of missing keys are filled with NaN.
out = tensor[[10000, 0]]
print(out.isnan())
# tensor([[False, False, False, False],
#         [ True,  True,  True,  True]])

# If `value` is not given, indexing returns the position of each query in
# `key`, and `-1` for queries that are not contained in `key`:
key = ['Animation', 'Comedy', 'Fantasy']
tensor = HashTensor(key)

out = tensor[['Comedy', 'Romance']]
print(out)
# tensor([ 1, -1])

Parameters:
  • key – The keys in the first dimension.

  • value – The values to hold.

  • dtype – The desired data type of the values of the returned tensor.

  • device – The device of the returned tensor.

as_tensor() → Tensor

Converts the HashTensor back to a plain torch.Tensor representation without copying the underlying data (zero-copy).

Return type:

Tensor