torch_geometric.edge_index.EdgeIndex

class EdgeIndex(data: Any, *args: Any, sparse_size: Optional[Tuple[Optional[int], Optional[int]]] = None, sort_order: Optional[Union[str, SortOrder]] = None, is_undirected: bool = False, **kwargs: Any)

Bases: Tensor

A COO edge_index tensor with additional (meta)data attached.

EdgeIndex is a torch.Tensor that holds an edge_index representation of shape [2, num_edges]. Edges are given as pairwise source and destination node indices in sparse COO format.

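While EdgeIndex sub-classes a general torch.Tensor, it can hold additional (meta)data, i.e.:

  • sparse_size: The size of the underlying sparse matrix.

  • sort_order: The sort order of indices (if present), either by row or by column.

  • is_undirected: Whether edges are bidirectional.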

Additionally, EdgeIndex caches data for fast CSR or CSC conversion in case its representation is sorted, such as its rowptr or colptr, or the permutation vector for going from CSR to CSC or vice versa. Caches are filled on demand (e.g., when calling EdgeIndex.sort_by()) or when explicitly requested via EdgeIndex.fill_cache_(), and are maintained and adjusted over its lifespan (e.g., when calling EdgeIndex.flip()).
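Why a cache can survive flip() is easiest to see on a toy model. The following plain-Python sketch uses a hypothetical ToyEdgeIndex class (not part of PyG) that stores only the COO indices, the sort order and one cached indptr; the real EdgeIndex caches more, e.g., the CSR/CSC permutation vector.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyEdgeIndex:
    """Hypothetical stand-in for EdgeIndex, for illustration only."""
    row: List[int]
    col: List[int]
    sort_order: Optional[str] = None    # "row", "col" or None
    indptr: Optional[List[int]] = None  # rowptr if sorted by row, colptr if sorted by col

    def flip(self) -> "ToyEdgeIndex":
        # Swapping the two index rows transposes the sparse matrix. A row-sorted
        # index becomes column-sorted (and vice versa), and the cached rowptr of
        # the matrix is exactly the colptr of its transpose, so it is reused as-is.
        new_order = {"row": "col", "col": "row"}.get(self.sort_order)
        return ToyEdgeIndex(self.col, self.row, new_order, self.indptr)
```

For example, flipping ToyEdgeIndex([0, 1, 1, 2], [1, 0, 2, 1], "row", [0, 1, 3, 4]) yields a column-sorted index whose colptr is still [0, 1, 3, 4], with no recomputation.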

This representation ensures optimal computation in GNN message passing schemes, while preserving the ease of use of regular COO-based workflows.

```python
import torch
from torch_geometric import EdgeIndex

edge_index = EdgeIndex(
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    sparse_size=(3, 3),
    sort_order='row',
    is_undirected=True,
    device='cpu',
)
# EdgeIndex([[0, 1, 1, 2],
#            [1, 0, 2, 1]])
assert edge_index.is_sorted_by_row
assert edge_index.is_undirected

# Flipping order:
edge_index = edge_index.flip(0)
# EdgeIndex([[1, 0, 2, 1],
#            [0, 1, 1, 2]])
assert edge_index.is_sorted_by_col
assert edge_index.is_undirected

# Filtering:
mask = torch.tensor([True, True, True, False])
edge_index = edge_index[:, mask]
# EdgeIndex([[1, 0, 2],
#            [0, 1, 1]])
assert edge_index.is_sorted_by_col
assert not edge_index.is_undirected

# Sparse-Dense Matrix Multiplication:
out = edge_index.flip(0) @ torch.randn(3, 16)
assert out.size() == (3, 16)
```

validate() → EdgeIndex

Validates the EdgeIndex representation.

In particular, it ensures that

  • it only holds valid indices.

  • the sort order is correctly set.

  • indices are bidirectional in case it is specified as undirected.
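The three checks can be sketched in plain Python as follows. The function validate_sketch and its ValueError messages are hypothetical and only mirror the list above; the library's own validate() operates on tensors and may perform additional checks.

```python
from typing import List, Optional


def validate_sketch(row: List[int], col: List[int], num_rows: int, num_cols: int,
                    sort_order: Optional[str] = None,
                    is_undirected: bool = False) -> None:
    """Hypothetical re-implementation of the checks listed above; raises ValueError."""
    if len(row) != len(col):
        raise ValueError("row and col must have the same length")
    # 1. Only valid indices: 0 <= row < num_rows and 0 <= col < num_cols.
    if any(r < 0 or r >= num_rows for r in row) or any(c < 0 or c >= num_cols for c in col):
        raise ValueError("index out of range")
    # 2. The declared sort order must match the data (non-decreasing keys).
    if sort_order is not None:
        key = row if sort_order == "row" else col
        if any(a > b for a, b in zip(key, key[1:])):
            raise ValueError(f"indices are not sorted by {sort_order}")
    # 3. Undirected: every edge (i, j) needs its reverse edge (j, i).
    if is_undirected:
        edges = set(zip(row, col))
        if any((j, i) not in edges for i, j in edges):
            raise ValueError("indices are not bidirectional")
```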

sparse_size() → Tuple[Optional[int], Optional[int]]
sparse_size(dim: int) → Optional[int]

The size of the underlying sparse matrix. If dim is specified, returns an integer holding the size of that sparse dimension.

Parameters:

dim (int, optional) – The dimension for which to retrieve the size. (default: None)

property num_rows: Optional[int]

The number of rows of the underlying sparse matrix.

property num_cols: Optional[int]

The number of columns of the underlying sparse matrix.

property sort_order: Optional[str]

The sort order of indices, either "row", "col" or None.

property is_sorted: bool

Returns whether indices are either sorted by rows or columns.

property is_sorted_by_row: bool

Returns whether indices are sorted by rows.

property is_sorted_by_col: bool

Returns whether indices are sorted by columns.

property is_undirected: bool

Returns whether indices are bidirectional.

get_sparse_size() → Size
get_sparse_size(dim: int) → int

The size of the underlying sparse matrix. Automatically computed and cached when not explicitly set. If dim is specified, returns an integer holding the size of that sparse dimension.

Parameters:

dim (int, optional) – The dimension for which to retrieve the size. (default: None)
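The difference between sparse_size() (which may report None for a dimension that was never set) and get_sparse_size() (which always returns an integer) can be sketched in plain Python. This assumes that a missing size is inferred as the largest index plus one, and 0 for an empty index; the hypothetical helper below does not model the caching.

```python
from typing import List, Optional, Tuple


def get_sparse_size_sketch(row: List[int], col: List[int],
                           sparse_size: Tuple[Optional[int], Optional[int]] = (None, None),
                           ) -> Tuple[int, int]:
    """Fill in unknown sizes, assuming size = max index + 1 (0 if there are no edges)."""
    def infer(indices: List[int], size: Optional[int]) -> int:
        if size is not None:  # explicitly set: keep it unchanged
            return size
        return max(indices) + 1 if indices else 0

    return infer(row, sparse_size[0]), infer(col, sparse_size[1])
```

Here get_sparse_size_sketch([0, 1], [1, 0], (5, None)) keeps the explicit 5 rows and infers 2 columns.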

sparse_resize_(num_rows: Optional[int], num_cols: Optional[int]) → EdgeIndex

Assigns or re-assigns the size of the underlying sparse matrix.

Parameters:
  • num_rows (int, optional) – The number of rows.

  • num_cols (int, optional) – The number of columns.

get_num_rows() → int

The number of rows of the underlying sparse matrix. Automatically computed and cached when not explicitly set.

get_num_cols() → int

The number of columns of the underlying sparse matrix. Automatically computed and cached when not explicitly set.

get_indptr() → Tensor

Returns the compressed index representation in case EdgeIndex is sorted.
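A compressed index (indptr) is a prefix sum over per-node edge counts. Presumably this is the rowptr for a row-sorted EdgeIndex and the colptr for a column-sorted one, matching get_csr() and get_csc(). The plain-Python helper indptr_sketch is hypothetical and illustrates the idea, not the library's tensor implementation.

```python
from itertools import accumulate
from typing import List


def indptr_sketch(sorted_index: List[int], size: int) -> List[int]:
    """Compress a sorted index vector into an indptr of length size + 1.

    Because the index is sorted, the edges of node i occupy the slice
    indptr[i]:indptr[i + 1] of the edge list.
    """
    counts = [0] * size
    for i in sorted_index:
        counts[i] += 1
    # Exclusive prefix sum: a leading 0 followed by the running totals.
    return [0] + list(accumulate(counts))
```

For the row indices [0, 1, 1, 2] of a 3-row matrix this gives [0, 1, 3, 4]: row 1 owns edges 1 and 2.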

get_csr() → Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]

Returns the compressed CSR representation (rowptr, col), perm in case EdgeIndex is sorted.

get_csc() → Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]

Returns the compressed CSC representation (colptr, row), perm in case EdgeIndex is sorted.
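Going from a row-sorted index to CSC requires re-sorting edges by column, and the resulting permutation is the kind of vector that gets cached. A minimal plain-Python sketch, assuming perm[k] is the position in the original edge list of the k-th edge in CSC order (the exact convention used by the library is an assumption here); csc_from_row_sorted is a hypothetical helper.

```python
from itertools import accumulate
from typing import List, Tuple


def csc_from_row_sorted(row: List[int], col: List[int], num_cols: int,
                        ) -> Tuple[Tuple[List[int], List[int]], List[int]]:
    """Return ((colptr, row), perm) for a COO index that is sorted by row."""
    # Stable argsort by column: edges sharing a column keep their row order.
    perm = sorted(range(len(col)), key=lambda e: col[e])
    counts = [0] * num_cols
    for c in col:
        counts[c] += 1
    colptr = [0] + list(accumulate(counts))
    return (colptr, [row[e] for e in perm]), perm
```

For row = [0, 1, 1, 2] and col = [1, 0, 2, 1], this returns colptr [0, 1, 3, 4], permuted rows [1, 0, 2, 1] and perm [1, 0, 3, 2].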

fill_cache_(no_transpose: bool = False) → EdgeIndex

Fills the cache with (meta)data information.

Parameters:

no_transpose (bool, optional) – If set to True, will not fill the cache with information about the transposed EdgeIndex. (default: False)

share_memory_() → EdgeIndex

Moves the underlying storage to shared memory.

This is a no-op if the underlying storage is already in shared memory and for CUDA tensors. Tensors in shared memory cannot be resized.

is_shared() → bool

Checks if the tensor is in shared memory.

This is always True for CUDA tensors.

as_tensor() → Tensor

Zero-copies the EdgeIndex representation back to a torch.Tensor representation.

sort_by(sort_order: Union[str, SortOrder], stable: bool = False) → SortReturnType

Sorts the elements by row or column indices.

Parameters:
  • sort_order (str) – The sort order, either "row" or "col".

  • stable (bool, optional) – Makes the sorting routine stable, which guarantees that the order of equivalent elements is preserved. (default: False)
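The effect of sorting, and of the stable flag, can be sketched in plain Python. This assumes, as with torch.sort(), that the result pairs the sorted values with the permutation indices; the SortResult tuple and sort_by_sketch function are hypothetical stand-ins for SortReturnType and sort_by().

```python
from typing import List, NamedTuple


class SortResult(NamedTuple):
    """Hypothetical (values, indices) pair, modeled after torch.sort()."""
    values: List[List[int]]  # the sorted [row, col] index
    indices: List[int]       # position k holds the original position of edge k


def sort_by_sketch(row: List[int], col: List[int], sort_order: str) -> SortResult:
    key = row if sort_order == "row" else col
    # Python's sorted() is always stable, i.e., it behaves like stable=True:
    # edges with equal keys keep their original relative order.
    perm = sorted(range(len(key)), key=key.__getitem__)
    return SortResult([[row[e] for e in perm], [col[e] for e in perm]], perm)
```

Sorting row = [2, 0, 1, 0], col = [0, 1, 2, 3] by row keeps edge (0, 1) ahead of edge (0, 3), since both share row 0. With stable=False, the library gives no such guarantee for equal keys.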

to_dense(value: Optional[Tensor] = None, fill_value: float = 0.0, dtype: Optional[dtype] = None) → Tensor

Converts EdgeIndex into a dense torch.Tensor.

Warning

In case of duplicated edges, the behavior is non-deterministic (one of the values from value will be picked arbitrarily). For deterministic behavior, consider calling coalesce() beforehand.

Parameters:
  • value (torch.Tensor, optional) – The values for non-zero elements. If not specified, non-zero elements will be assigned a value of 1.0. (default: None)

  • fill_value (float, optional) – The fill value for remaining elements in the dense matrix. (default: 0.0)

  • dtype (torch.dtype, optional) – The data type of the returned tensor. (default: None)
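The roles of value and fill_value can be sketched with nested lists in plain Python. The helper to_dense_sketch is hypothetical, omits the dtype argument, and resolves duplicates by letting the last write win, which is a property of this sketch only and not of the library.

```python
from typing import List, Optional


def to_dense_sketch(row: List[int], col: List[int], num_rows: int, num_cols: int,
                    value: Optional[List[float]] = None,
                    fill_value: float = 0.0) -> List[List[float]]:
    """Scatter edge values into a dense num_rows x num_cols matrix."""
    if value is None:
        value = [1.0] * len(row)  # default: every non-zero element is 1.0
    out = [[fill_value] * num_cols for _ in range(num_rows)]
    for r, c, v in zip(row, col, value):
        # With duplicated (r, c) pairs the last write wins here; per the
        # warning above, the library does not guarantee which value is kept.
        out[r][c] = v
    return out
```

For the undirected 3 x 3 example from the class description, this returns [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]].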

to_sparse_coo(value: Optional[Tensor] = None) → Tensor

Converts EdgeIndex into a torch.sparse_coo_tensor.

Parameters:

value (torch.Tensor, optional) – The values for non-zero elements. If not specified, non-zero elements will be assigned a value of 1.0. (default: None)

to_sparse_csr(value: Optional[Tensor] = None) → Tensor

Converts EdgeIndex into a torch.sparse_csr_tensor.

Parameters:

value (torch.Tensor, optional) – The values for non-zero elements. If not specified, non-zero elements will be assigned a value of 1.0. (default: None)

to_sparse_csc(value: Optional[Tensor] = None) → Tensor

Converts EdgeIndex into a torch.sparse_csc_tensor.

Parameters:

value (torch.Tensor, optional) – The values for non-zero elements. If not specified, non-zero elements will be assigned a value of 1.0. (default: None)

to_sparse(*, layout: layout = torch.sparse_coo, value: Optional[Tensor] = None) → Tensor

Converts EdgeIndex into a torch.sparse tensor.

Parameters:
  • layout (torch.layout, optional) – The desired sparse layout. One of torch.sparse_coo, torch.sparse_csr, or torch.sparse_csc. (default: torch.sparse_coo)

  • value (torch.Tensor, optional) – The values for non-zero elements. If not specified, non-zero elements will be assigned a value of 1.0. (default: None)

to_sparse_tensor(value: Optional[Tensor] = None) → SparseTensor

Converts EdgeIndex into a torch_sparse.SparseTensor. Requires that torch-sparse is installed.

Parameters:

value (torch.Tensor, optional) – The values for non-zero elements. (default: None)

matmul(other: EdgeIndex, input_value: Optional[Tensor] = None, other_value: Optional[Tensor] = None, reduce: Literal['sum', 'mean', 'amin', 'amax', 'add', 'min', 'max'] = 'sum', transpose: bool = False) → Tuple[EdgeIndex, Tensor]
matmul(other: Tensor, input_value: Optional[Tensor] = None, other_value: None = None, reduce: Literal['sum', 'mean', 'amin', 'amax', 'add', 'min', 'max'] = 'sum', transpose: bool = False) → Tensor

Performs a matrix multiplication of the matrices input and other. If input is a \((n \times m)\) matrix and other is a \((m \times p)\) tensor, then the output will be a \((n \times p)\) tensor. See torch.matmul() for more information.

input is a sparse matrix as denoted by the indices in EdgeIndex, and input_value corresponds to the values of non-zero elements in input. If not specified, non-zero elements will be assigned a value of 1.0.

other can either be a dense torch.Tensor or a sparse EdgeIndex. If other is a sparse EdgeIndex, then other_value corresponds to the values of its non-zero elements.

This function additionally accepts an optional reduce argument that specifies the reduction operation. See torch.sparse.mm() for more information.

Lastly, the transpose option allows performing the matrix multiplication with input transposed first, i.e.:

\[\textrm{input}^{\top} \cdot \textrm{other}\]

Parameters:
  • other (torch.Tensor or EdgeIndex) – The second matrix to be multiplied, which can be sparse or dense.

  • input_value (torch.Tensor, optional) – The values for non-zero elements of input. If not specified, non-zero elements will be assigned a value of 1.0. (default: None)

  • other_value (torch.Tensor, optional) – The values for non-zero elements of other in case it is sparse. If not specified, non-zero elements will be assigned a value of 1.0. (default: None)

  • reduce (str, optional) – The reduce operation, one of "sum"/"add", "mean", "min"/"amin" or "max"/"amax". (default: "sum")

  • transpose (bool, optional) – If set to True, will perform matrix multiplication based on the transposed input. (default: False)
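The semantics of reduce and transpose for the sparse-dense case can be pinned down with a naive plain-Python reference. spmm_sketch is hypothetical and says nothing about how the library computes the product; in particular, returning zeros for rows without non-zero elements is an assumption of this sketch.

```python
from typing import List, Optional


def spmm_sketch(row: List[int], col: List[int], num_rows: int, num_cols: int,
                other: List[List[float]], input_value: Optional[List[float]] = None,
                reduce: str = "sum", transpose: bool = False) -> List[List[float]]:
    """Naive COO x dense matmul where each output row reduces its contributions."""
    if input_value is None:
        input_value = [1.0] * len(row)  # non-zero elements default to 1.0
    if transpose:  # input^T: swap the roles of rows and columns
        row, col, num_rows, num_cols = col, row, num_cols, num_rows
    assert len(other) == num_cols, "other must have one row per column of input"
    p = len(other[0])
    # For every output row i, collect the contributions value * other[j].
    buckets: List[List[List[float]]] = [[] for _ in range(num_rows)]
    for i, j, v in zip(row, col, input_value):
        buckets[i].append([v * x for x in other[j]])
    reducers = {
        "sum": sum, "add": sum,
        "mean": lambda xs: sum(xs) / len(xs),
        "min": min, "amin": min, "max": max, "amax": max,
    }
    f = reducers[reduce]
    # Reduce column-wise within each bucket; empty rows yield zeros (assumption).
    return [[f(vals) for vals in zip(*b)] if b else [0.0] * p for b in buckets]
```

With the 3 x 3 adjacency from the class description, reduce="sum" adds the features of both neighbors of node 1, while reduce="mean" averages them, which is the usual GNN neighborhood aggregation.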

sparse_narrow(dim: int, start: Union[int, Tensor], length: int) → EdgeIndex

Returns a new EdgeIndex that is a narrowed version of itself. Narrowing is performed by interpreting EdgeIndex as a sparse matrix of shape (num_rows, num_cols).

In contrast to torch.narrow(), the returned tensor does not share the same underlying storage anymore.

Parameters:
  • dim (int) – The dimension along which to narrow.

  • start (int or torch.Tensor) – Index of the element to start the narrowed dimension from.

  • length (int) – Length of the narrowed dimension.
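Narrowing a sparse matrix keeps only the edges whose index along dim lies in [start, start + length). A plain-Python sketch, assuming the kept indices are re-based so that the narrowed dimension has size length; sparse_narrow_sketch is hypothetical. Like the method described above, it builds new lists rather than views into the original storage.

```python
from typing import List, Tuple


def sparse_narrow_sketch(row: List[int], col: List[int], dim: int, start: int,
                         length: int) -> Tuple[List[int], List[int]]:
    """Keep edges with start <= index[dim] < start + length, shifted to start at 0."""
    index = row if dim == 0 else col
    keep = [e for e, i in enumerate(index) if start <= i < start + length]
    # Only the narrowed dimension is shifted; the other one is left unchanged.
    new_row = [row[e] - (start if dim == 0 else 0) for e in keep]
    new_col = [col[e] - (start if dim == 1 else 0) for e in keep]
    return new_row, new_col
```

Narrowing the 3 x 3 example along rows with start=1 and length=2 keeps the three edges leaving nodes 1 and 2, with rows re-based to [0, 0, 1].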