Source code for torch_geometric.edge_index

import functools
import typing
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
    get_args,
    overload,
)

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.typing import SparseTensor

# Registry that maps a PyTorch callable to the override registered for it via
# the `implements()` decorator below. `EdgeIndex.__torch_function__` looks up
# incoming functions in this dictionary.
HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}

# Data types accepted by `assert_valid_dtype()`. `torch.int32` is only part of
# the set when `torch_geometric.typing.WITH_PT20` is truthy.
if torch_geometric.typing.WITH_PT20:
    SUPPORTED_DTYPES: Set[torch.dtype] = {
        torch.int32,
        torch.int64,
    }
elif not typing.TYPE_CHECKING:  # pragma: no cover
    SUPPORTED_DTYPES: Set[torch.dtype] = {
        torch.int64,
    }

# All accepted spellings of a reduce operation (canonical names and aliases).
ReduceType = Literal['sum', 'mean', 'amin', 'amax', 'add', 'min', 'max']
# Alias table: 'add' -> 'sum', 'amin' -> 'min', 'amax' -> 'max'.
# NOTE(review): presumably used to normalize names to the PyG spelling; the
# consumer is not visible in this part of the file -- confirm.
PYG_REDUCE: Dict[ReduceType, ReduceType] = {
    'add': 'sum',
    'amin': 'min',
    'amax': 'max'
}
# Alias table: 'add' -> 'sum', 'min' -> 'amin', 'max' -> 'amax'.
# NOTE(review): presumably used to normalize names to the PyTorch spelling;
# the consumer is not visible in this part of the file -- confirm.
TORCH_REDUCE: Dict[ReduceType, ReduceType] = {
    'add': 'sum',
    'min': 'amin',
    'max': 'amax'
}


class SortOrder(Enum):
    r"""The two sort orders an :class:`EdgeIndex` can be tagged with."""
    # Indices are sorted by their row entries:
    ROW = 'row'
    # Indices are sorted by their column entries:
    COL = 'col'


class CatMetadata(NamedTuple):
    r"""Per-input metadata recorded when concatenating edge indices.

    Every field is a list holding one entry per concatenated input.
    """
    # Number of edges of each input:
    nnz: List[int]
    # Sparse matrix size of each input (entries may be unknown, *i.e.*
    # :obj:`None`):
    sparse_size: List[Tuple[Optional[int], Optional[int]]]
    # Sort order of each input (:obj:`None` if not sorted):
    sort_order: List[Optional[SortOrder]]
    # Whether each input was marked as undirected:
    is_undirected: List[bool]


def implements(torch_function: Callable) -> Callable:
    r"""Registers a :pytorch:`PyTorch` function override.

    Returns a decorator that stores the decorated callable in
    :obj:`HANDLED_FUNCTIONS` under the key :obj:`torch_function` and hands the
    decorated callable back unchanged, so that several registrations can be
    stacked on top of one function.

    Args:
        torch_function (Callable): The PyTorch callable to override.
    """
    def _register(override: Callable) -> Callable:
        # Remember the override and return it untouched:
        HANDLED_FUNCTIONS[torch_function] = override
        return override

    return functools.wraps(torch_function)(_register)


def set_tuple_item(
    values: Tuple[Any, ...],
    dim: int,
    value: Any,
) -> Tuple[Any, ...]:
    r"""Returns a copy of :obj:`values` with the entry at :obj:`dim` replaced.

    Args:
        values (tuple): The input tuple.
        dim (int): The position to replace. Negative positions count from the
            end.
        value (Any): The new entry.

    Raises:
        IndexError: If :obj:`dim` lies outside of
            :obj:`[-len(values), len(values))`.
    """
    size = len(values)
    if not (-size <= dim < size):
        raise IndexError("tuple index out of range")

    # Normalize negative positions:
    if dim < 0:
        dim += size

    head, tail = values[:dim], values[dim + 1:]
    return head + (value, ) + tail


def maybe_add(
    value: Sequence[Optional[int]],
    other: Union[int, Sequence[Optional[int]]],
    alpha: int = 1,
) -> Tuple[Optional[int], ...]:
    r"""Computes :obj:`value + alpha * other` element-wise, where any
    :obj:`None` operand yields a :obj:`None` result.

    Args:
        value (Sequence[Optional[int]]): The left-hand side entries.
        other (int or Sequence[Optional[int]]): Either a single integer that
            is applied to every entry, or a sequence of the same length as
            :obj:`value`.
        alpha (int, optional): The multiplier for :obj:`other`.
            (default: :obj:`1`)
    """
    if isinstance(other, int):
        # Scalar case: the same offset is applied to every known entry.
        return tuple(None if v is None else v + alpha * other for v in value)

    assert len(value) == len(other)

    result: List[Optional[int]] = []
    for lhs, rhs in zip(value, other):
        if lhs is None or rhs is None:
            result.append(None)
        else:
            result.append(lhs + alpha * rhs)
    return tuple(result)


def maybe_sub(
    value: Sequence[Optional[int]],
    other: Union[int, Sequence[Optional[int]]],
    alpha: int = 1,
) -> Tuple[Optional[int], ...]:
    r"""Computes :obj:`value - alpha * other` element-wise, where any
    :obj:`None` operand yields a :obj:`None` result.

    Args:
        value (Sequence[Optional[int]]): The left-hand side entries.
        other (int or Sequence[Optional[int]]): Either a single integer that
            is applied to every entry, or a sequence of the same length as
            :obj:`value`.
        alpha (int, optional): The multiplier for :obj:`other`.
            (default: :obj:`1`)
    """
    if isinstance(other, int):
        # Scalar case: the same offset is removed from every known entry.
        return tuple(None if v is None else v - alpha * other for v in value)

    assert len(value) == len(other)

    result: List[Optional[int]] = []
    for lhs, rhs in zip(value, other):
        if lhs is None or rhs is None:
            result.append(None)
        else:
            result.append(lhs - alpha * rhs)
    return tuple(result)


def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
    r"""Expands a compressed pointer vector into an index vector.

    Position :obj:`i` is repeated :obj:`ptr[i + 1] - ptr[i]` times. The result
    has the same data type and device as :obj:`ptr`.

    Args:
        ptr (torch.Tensor): The pointer vector.
        output_size (int, optional): The total number of output elements,
            forwarded to :meth:`torch.Tensor.repeat_interleave`.
            (default: :obj:`None`)
    """
    num_segments = ptr.numel() - 1
    segment_ids = torch.arange(
        num_segments,
        dtype=ptr.dtype,
        device=ptr.device,
    )
    repeats = ptr.diff()
    return segment_ids.repeat_interleave(repeats, output_size=output_size)


def assert_valid_dtype(tensor: Tensor) -> None:
    r"""Ensures that the data type of :obj:`tensor` is contained in
    :obj:`SUPPORTED_DTYPES`.

    Raises:
        ValueError: If the data type is not supported.
    """
    if tensor.dtype in SUPPORTED_DTYPES:
        return

    raise ValueError(f"'EdgeIndex' holds an unsupported data type "
                     f"(got '{tensor.dtype}', but expected one of "
                     f"{SUPPORTED_DTYPES})")


def assert_two_dimensional(tensor: Tensor) -> None:
    r"""Ensures that :obj:`tensor` is two-dimensional with a first dimension
    of size two.

    The check of the first dimension is skipped while
    :meth:`torch.jit.is_tracing` reports an active trace.

    Raises:
        ValueError: If one of the two conditions is violated.
    """
    if tensor.dim() != 2:
        raise ValueError(f"'EdgeIndex' needs to be two-dimensional "
                         f"(got {tensor.dim()} dimensions)")

    if torch.jit.is_tracing():
        return

    if tensor.size(0) != 2:
        raise ValueError(f"'EdgeIndex' needs to have a shape of "
                         f"[2, *] (got {list(tensor.size())})")


def assert_contiguous(tensor: Tensor) -> None:
    r"""Ensures that :obj:`tensor` is contiguous in memory.

    Raises:
        ValueError: If :meth:`torch.Tensor.is_contiguous` returns
            :obj:`False`.
    """
    if tensor.is_contiguous():
        return

    raise ValueError("'EdgeIndex' needs to be contiguous. Please call "
                     "`edge_index.contiguous()` before proceeding.")


def assert_symmetric(size: Tuple[Optional[int], Optional[int]]) -> None:
    r"""Ensures that a fully specified sparse size is square.

    Nothing is checked while :meth:`torch.jit.is_tracing` reports an active
    trace, or if one of the two entries is :obj:`None`.

    Raises:
        ValueError: If both entries are given and differ.
    """
    if torch.jit.is_tracing():
        return

    num_rows, num_cols = size[0], size[1]
    if num_rows is None or num_cols is None:
        return

    if num_rows != num_cols:
        raise ValueError(f"'EdgeIndex' is undirected but received a "
                         f"non-symmetric size (got {list(size)})")


def assert_sorted(func: Callable) -> Callable:
    r"""Decorator that only lets :obj:`func` run on sorted inputs.

    The wrapper inspects the :obj:`is_sorted` attribute of the first
    positional argument and forwards all arguments to :obj:`func` if it is
    truthy.

    Raises:
        ValueError: Raised by the wrapper if the first positional argument is
            not sorted.
    """
    def _guarded(*args: Any, **kwargs: Any) -> Any:
        owner = args[0]
        if owner.is_sorted:
            return func(*args, **kwargs)

        cls_name = owner.__class__.__name__
        raise ValueError(
            f"Cannot call '{func.__name__}' since '{cls_name}' is not "
            f"sorted. Please call `{cls_name}.sort_by(...)` first.")

    return functools.wraps(func)(_guarded)


class EdgeIndex(Tensor):
    r"""A COO :obj:`edge_index` tensor with additional (meta)data attached.

    :class:`EdgeIndex` is a :pytorch:`null` :class:`torch.Tensor`, that holds
    an :obj:`edge_index` representation of shape :obj:`[2, num_edges]`.
    Edges are given as pairwise source and destination node indices in sparse
    COO format.

    While :class:`EdgeIndex` sub-classes a general :pytorch:`null`
    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:

    * :obj:`sparse_size`: The underlying sparse matrix size
    * :obj:`sort_order`: The sort order (if present), either by row or column.
    * :obj:`is_undirected`: Whether edges are bidirectional.

    Additionally, :class:`EdgeIndex` caches data for fast CSR or CSC conversion
    in case its representation is sorted, such as its :obj:`rowptr` or
    :obj:`colptr`, or the permutation vector for going from CSR to CSC or vice
    versa.
    Caches are filled based on demand (*e.g.*, when calling
    :meth:`EdgeIndex.sort_by`), or when explicitly requested via
    :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its
    lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).

    This representation ensures optimal computation in GNN message passing
    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
    workflows.

    .. code-block:: python

        from torch_geometric import EdgeIndex

        edge_index = EdgeIndex(
            [[0, 1, 1, 2],
             [1, 0, 2, 1]],
            sparse_size=(3, 3),
            sort_order='row',
            is_undirected=True,
            device='cpu',
        )
        >>> EdgeIndex([[0, 1, 1, 2],
        ...            [1, 0, 2, 1]])
        assert edge_index.is_sorted_by_row
        assert edge_index.is_undirected

        # Flipping order:
        edge_index = edge_index.flip(0)
        >>> EdgeIndex([[1, 0, 2, 1],
        ...            [0, 1, 1, 2]])
        assert edge_index.is_sorted_by_col
        assert edge_index.is_undirected

        # Filtering:
        mask = torch.tensor([True, True, True, False])
        edge_index = edge_index[:, mask]
        >>> EdgeIndex([[1, 0, 2],
        ...            [0, 1, 1]])
        assert edge_index.is_sorted_by_col
        assert not edge_index.is_undirected

        # Sparse-Dense Matrix Multiplication:
        out = edge_index.flip(0) @ torch.randn(3, 16)
        assert out.size() == (3, 16)
    """
    # See "https://pytorch.org/docs/stable/notes/extending.html"
    # for a basic tutorial on how to subclass `torch.Tensor`.

    # The size of the underlying sparse matrix:
    _sparse_size: Tuple[Optional[int], Optional[int]] = (None, None)

    # Whether the `edge_index` represented is non-sorted (`None`), or sorted
    # based on row or column values.
    _sort_order: Optional[SortOrder] = None

    # Whether the `edge_index` is undirected:
    # NOTE `is_undirected` allows us to assume symmetric adjacency matrix size
    # and to share compressed pointer representations, however, it does not
    # allow us get rid of CSR/CSC permutation vectors since ordering within
    # neighborhoods is not necessarily deterministic.
    _is_undirected: bool = False

    # A cache for its compressed representation:
    _indptr: Optional[Tensor] = None

    # A cache for its transposed representation:
    _T_perm: Optional[Tensor] = None
    _T_index: Tuple[Optional[Tensor], Optional[Tensor]] = (None, None)
    _T_indptr: Optional[Tensor] = None

    # A cached "1"-value vector for `torch.sparse` matrix multiplication:
    _value: Optional[Tensor] = None

    # Whenever we perform a concatenation of edge indices, we cache the
    # original metadata to be able to reconstruct individual edge indices:
    _cat_metadata: Optional[CatMetadata] = None

    def __new__(
        cls: Type,
        data: Any,
        *args: Any,
        sparse_size: Optional[Tuple[Optional[int], Optional[int]]] = None,
        sort_order: Optional[Union[str, SortOrder]] = None,
        is_undirected: bool = False,
        **kwargs: Any,
    ) -> 'EdgeIndex':
        r"""Creates a new :class:`EdgeIndex`.

        Args:
            data (Any): The index data. Non-tensor inputs are converted via
                :meth:`torch.tensor` (receiving :obj:`*args` and
                :obj:`**kwargs`). Sparse COO/CSR/CSC tensors are converted
                into their COO index representation. If an
                :class:`EdgeIndex` is passed, its metadata and caches are
                inherited.
            *args (Any): Positional arguments for :meth:`torch.tensor`. Only
                allowed if :obj:`data` is not a tensor.
            sparse_size (tuple, optional): The size of the underlying sparse
                matrix. (default: :obj:`None`)
            sort_order (str or SortOrder, optional): The sort order.
                (default: :obj:`None`)
            is_undirected (bool, optional): Whether edges are bidirectional.
                (default: :obj:`False`)
            **kwargs (Any): Keyword arguments for :meth:`torch.tensor`. Only
                allowed if :obj:`data` is not a tensor.

        Raises:
            TypeError: If :obj:`data` is a tensor and additional positional
                or keyword arguments are given.
            ValueError: If the resulting data has an unsupported data type,
                is not of shape :obj:`[2, *]`, is not contiguous, or is
                undirected with a non-symmetric size.
        """
        if not isinstance(data, Tensor):
            data = torch.tensor(data, *args, **kwargs)
        elif len(args) > 0:
            raise TypeError(
                f"new() received an invalid combination of arguments - got "
                f"(Tensor, {', '.join(str(type(arg)) for arg in args)})")
        elif len(kwargs) > 0:
            raise TypeError(f"new() received invalid keyword arguments - got "
                            f"{set(kwargs.keys())})")

        assert isinstance(data, Tensor)

        indptr: Optional[Tensor] = None

        if isinstance(data, cls):  # If passed `EdgeIndex`, inherit metadata:
            indptr = data._indptr
            sparse_size = sparse_size or data.sparse_size()
            sort_order = sort_order or data.sort_order
            is_undirected = is_undirected or data.is_undirected

        # Convert `torch.sparse` tensors to `EdgeIndex` representation:
        if data.layout == torch.sparse_coo:
            sort_order = SortOrder.ROW
            sparse_size = sparse_size or (data.size(0), data.size(1))
            data = data.indices()

        if data.layout == torch.sparse_csr:
            indptr = data.crow_indices()
            col = data.col_indices()

            assert isinstance(indptr, Tensor)
            row = ptr2index(indptr, output_size=col.numel())

            sort_order = SortOrder.ROW
            sparse_size = sparse_size or (data.size(0), data.size(1))
            # The pointer vector is only valid for the original row count:
            if sparse_size[0] is not None and sparse_size[0] != data.size(0):
                indptr = None
            data = torch.stack([row, col], dim=0)

        if (torch_geometric.typing.WITH_PT112
                and data.layout == torch.sparse_csc):
            row = data.row_indices()
            indptr = data.ccol_indices()

            assert isinstance(indptr, Tensor)
            col = ptr2index(indptr, output_size=row.numel())

            sort_order = SortOrder.COL
            sparse_size = sparse_size or (data.size(0), data.size(1))
            # The pointer vector is only valid for the original column count:
            if sparse_size[1] is not None and sparse_size[1] != data.size(1):
                indptr = None
            data = torch.stack([row, col], dim=0)

        assert_valid_dtype(data)
        assert_two_dimensional(data)
        assert_contiguous(data)

        if sparse_size is None:
            sparse_size = (None, None)

        if is_undirected:
            assert_symmetric(sparse_size)
            # A square matrix lets us infer the missing dimension:
            if sparse_size[0] is not None and sparse_size[1] is None:
                sparse_size = (sparse_size[0], sparse_size[0])
            elif sparse_size[0] is None and sparse_size[1] is not None:
                sparse_size = (sparse_size[1], sparse_size[1])

        if torch_geometric.typing.WITH_PT112:
            out = super().__new__(cls, data)
        else:
            out = Tensor._make_subclass(cls, data)

        # Attach metadata:
        assert isinstance(out, EdgeIndex)
        out._sparse_size = sparse_size
        out._sort_order = None if sort_order is None else SortOrder(sort_order)
        out._is_undirected = is_undirected
        out._indptr = indptr

        if isinstance(data, cls):  # If passed `EdgeIndex`, inherit metadata:
            out._T_perm = data._T_perm
            out._T_index = data._T_index
            out._T_indptr = data._T_indptr
            # Inherit the cached value vector from the input (the previous
            # `out._value = out._value` was a no-op):
            out._value = data._value

            # Reset metadata if cache is invalidated:
            num_rows = sparse_size[0]
            if num_rows is not None and num_rows != data.sparse_size(0):
                out._indptr = None

            num_cols = sparse_size[1]
            if num_cols is not None and num_cols != data.sparse_size(1):
                out._T_indptr = None

        return out

    # Validation ##############################################################

    def validate(self) -> 'EdgeIndex':
        r"""Validates the :class:`EdgeIndex` representation.

        In particular, it ensures that

        * it only holds valid indices.
        * the sort order is correctly set.
        * indices are bidirectional in case it is specified as undirected.

        Raises:
            ValueError: If any of the checks fails.
        """
        assert_valid_dtype(self)
        assert_two_dimensional(self)
        assert_contiguous(self)
        if self.is_undirected:
            assert_symmetric(self.sparse_size())

        if self.numel() > 0 and self.min() < 0:
            raise ValueError(f"'{self.__class__.__name__}' contains negative "
                             f"indices (got {int(self.min())})")

        if (self.numel() > 0 and self.num_rows is not None
                and self[0].max() >= self.num_rows):
            raise ValueError(f"'{self.__class__.__name__}' contains larger "
                             f"indices than its number of rows "
                             f"(got {int(self[0].max())}, but expected values "
                             f"smaller than {self.num_rows})")

        if (self.numel() > 0 and self.num_cols is not None
                and self[1].max() >= self.num_cols):
            raise ValueError(f"'{self.__class__.__name__}' contains larger "
                             f"indices than its number of columns "
                             f"(got {int(self[1].max())}, but expected values "
                             f"smaller than {self.num_cols})")

        if self.is_sorted_by_row and (self[0].diff() < 0).any():
            raise ValueError(f"'{self.__class__.__name__}' is not sorted by "
                             f"row indices")

        if self.is_sorted_by_col and (self[1].diff() < 0).any():
            raise ValueError(f"'{self.__class__.__name__}' is not sorted by "
                             f"column indices")

        if self.is_undirected:
            # Compare the sorted linearized indices of the matrix and of its
            # transpose:
            flat_index1 = (self[0] * self.get_num_rows() + self[1]).sort()[0]
            flat_index2 = (self[1] * self.get_num_cols() + self[0]).sort()[0]
            if not torch.equal(flat_index1, flat_index2):
                raise ValueError(f"'{self.__class__.__name__}' is not "
                                 f"undirected")

        return self

    # Properties ##############################################################

    @overload
    def sparse_size(self) -> Tuple[Optional[int], Optional[int]]:
        pass

    @overload
    def sparse_size(self, dim: int) -> Optional[int]:
        pass

    def sparse_size(
        self,
        dim: Optional[int] = None,
    ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]:
        r"""The size of the underlying sparse matrix.
        If :obj:`dim` is specified, returns an integer holding the size of that
        sparse dimension.

        Args:
            dim (int, optional): The dimension for which to retrieve the size.
                (default: :obj:`None`)
        """
        if dim is not None:
            return self._sparse_size[dim]
        return self._sparse_size

    @property
    def num_rows(self) -> Optional[int]:
        r"""The number of rows of the underlying sparse matrix."""
        return self._sparse_size[0]

    @property
    def num_cols(self) -> Optional[int]:
        r"""The number of columns of the underlying sparse matrix."""
        return self._sparse_size[1]

    @property
    def sort_order(self) -> Optional[str]:
        r"""The sort order of indices, either :obj:`"row"`, :obj:`"col"` or
        :obj:`None`.
        """
        return None if self._sort_order is None else self._sort_order.value

    @property
    def is_sorted(self) -> bool:
        r"""Returns whether indices are either sorted by rows or columns."""
        return self._sort_order is not None

    @property
    def is_sorted_by_row(self) -> bool:
        r"""Returns whether indices are sorted by rows."""
        return self._sort_order == SortOrder.ROW

    @property
    def is_sorted_by_col(self) -> bool:
        r"""Returns whether indices are sorted by columns."""
        return self._sort_order == SortOrder.COL

    @property
    def is_undirected(self) -> bool:
        r"""Returns whether indices are bidirectional."""
        return self._is_undirected

    # Cache Interface #########################################################

    @overload
    def get_sparse_size(self) -> torch.Size:
        pass

    @overload
    def get_sparse_size(self, dim: int) -> int:
        pass

    def get_sparse_size(
        self,
        dim: Optional[int] = None,
    ) -> Union[torch.Size, int]:
        r"""The size of the underlying sparse matrix.
        Automatically computed and cached when not explicitly set.
        If :obj:`dim` is specified, returns an integer holding the size of that
        sparse dimension.

        Args:
            dim (int, optional): The dimension for which to retrieve the size.
                (default: :obj:`None`)
        """
        if dim is not None:
            size = self._sparse_size[dim]
            if size is not None:
                return size

            if self.is_undirected:
                # Undirected graphs are square, so both dimensions can be
                # inferred from the overall maximum index:
                size = int(self.max()) + 1 if self.numel() > 0 else 0
                self._sparse_size = (size, size)
                return size

            size = int(self[dim].max()) + 1 if self.numel() > 0 else 0
            self._sparse_size = set_tuple_item(self._sparse_size, dim, size)
            return size

        return torch.Size((self.get_sparse_size(0), self.get_sparse_size(1)))

    def get_num_rows(self) -> int:
        r"""The number of rows of the underlying sparse matrix.
        Automatically computed and cached when not explicitly set.
        """
        return self.get_sparse_size(0)

    def get_num_cols(self) -> int:
        r"""The number of columns of the underlying sparse matrix.
        Automatically computed and cached when not explicitly set.
        """
        return self.get_sparse_size(1)

    @assert_sorted
    def get_indptr(self) -> Tensor:
        r"""Returns the compressed index representation in case
        :class:`EdgeIndex` is sorted.
        """
        if self._indptr is not None:
            return self._indptr

        # Undirected graphs share pointers with their transposed version:
        if self.is_undirected and self._T_indptr is not None:
            return self._T_indptr

        dim = 0 if self.is_sorted_by_row else 1
        self._indptr = torch._convert_indices_from_coo_to_csr(
            self[dim],
            self.get_sparse_size(dim),
            out_int32=self.dtype != torch.int64,
        )

        return self._indptr

    @assert_sorted
    def _sort_by_transpose(self) -> Tuple[Tuple[Tensor, Tensor], Tensor]:
        r"""Returns the indices sorted by the opposite dimension together with
        the permutation that produces them. Results are cached in
        :obj:`_T_index` and :obj:`_T_perm`.
        """
        from torch_geometric.utils import index_sort

        # The dimension that is currently *not* sorted:
        dim = 1 if self.is_sorted_by_row else 0

        if self._T_perm is None:
            index, perm = index_sort(self[dim], self.get_sparse_size(dim))
            self._T_index = set_tuple_item(self._T_index, dim, index)
            self._T_perm = perm

        if self._T_index[1 - dim] is None:
            self._T_index = set_tuple_item(  #
                self._T_index, 1 - dim, self[1 - dim][self._T_perm])

        row, col = self._T_index
        assert row is not None and col is not None

        return (row, col), self._T_perm

    @assert_sorted
    def get_csr(self) -> Tuple[Tuple[Tensor, Tensor], Union[Tensor, slice]]:
        r"""Returns the compressed CSR representation
        :obj:`(rowptr, col), perm` in case :class:`EdgeIndex` is sorted.
        """
        if self.is_sorted_by_row:
            return (self.get_indptr(), self[1]), slice(None, None, None)

        assert self.is_sorted_by_col
        (row, col), perm = self._sort_by_transpose()

        if self._T_indptr is not None:
            rowptr = self._T_indptr
        elif self.is_undirected and self._indptr is not None:
            rowptr = self._indptr
        else:
            rowptr = self._T_indptr = torch._convert_indices_from_coo_to_csr(
                row,
                self.get_num_rows(),
                out_int32=self.dtype != torch.int64,
            )

        return (rowptr, col), perm

    @assert_sorted
    def get_csc(self) -> Tuple[Tuple[Tensor, Tensor], Union[Tensor, slice]]:
        r"""Returns the compressed CSC representation
        :obj:`(colptr, row), perm` in case :class:`EdgeIndex` is sorted.
        """
        if self.is_sorted_by_col:
            return (self.get_indptr(), self[0]), slice(None, None, None)

        assert self.is_sorted_by_row
        (row, col), perm = self._sort_by_transpose()

        if self._T_indptr is not None:
            colptr = self._T_indptr
        elif self.is_undirected and self._indptr is not None:
            colptr = self._indptr
        else:
            colptr = self._T_indptr = torch._convert_indices_from_coo_to_csr(
                col,
                self.get_num_cols(),
                out_int32=self.dtype != torch.int64,
            )

        return (colptr, row), perm

    def _get_value(self, dtype: Optional[torch.dtype] = None) -> Tensor:
        r"""Returns a (cached) vector of ones with one entry per edge.

        The cached vector is re-used if its data type matches :obj:`dtype`
        (or the default data type if :obj:`dtype` is :obj:`None`).
        """
        if self._value is not None:
            if (dtype or torch.get_default_dtype()) == self._value.dtype:
                return self._value

        # Expanded tensors are not yet supported in all PyTorch code paths :(
        # value = torch.ones(1, dtype=dtype, device=self.device)
        # value = value.expand(self.size(1))
        self._value = torch.ones(self.size(1), dtype=dtype, device=self.device)

        return self._value

    def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex':
        r"""Fills the cache with (meta)data information.

        Args:
            no_transpose (bool, optional): If set to :obj:`True`, will not fill
                the cache with information about the transposed
                :class:`EdgeIndex`. (default: :obj:`False`)
        """
        self.get_sparse_size()

        if self.is_sorted_by_row:
            self.get_csr()
            if not no_transpose:
                self.get_csc()
        elif self.is_sorted_by_col:
            self.get_csc()
            if not no_transpose:
                self.get_csr()

        return self

    # Methods #################################################################

    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`EdgeIndex` representation back to a
        :class:`torch.Tensor` representation.
        """
        return self.as_subclass(Tensor)

    def sort_by(
        self,
        sort_order: Union[str, SortOrder],
        stable: bool = False,
    ) -> 'SortReturnType':
        r"""Sorts the elements by row or column indices.

        Args:
            sort_order (str): The sort order, either :obj:`"row"` or
                :obj:`"col"`.
            stable (bool, optional): Makes the sorting routine stable, which
                guarantees that the order of equivalent elements is preserved.
                (default: :obj:`False`)
        """
        from torch_geometric.utils import index_sort

        sort_order = SortOrder(sort_order)

        if self._sort_order == sort_order:  # Nothing to do.
            return SortReturnType(self, slice(None, None, None))

        if self.is_sorted:
            (row, col), perm = self._sort_by_transpose()
            edge_index = torch.stack([row, col], dim=0)

        # Otherwise, perform sorting:
        elif sort_order == SortOrder.ROW:
            row, perm = index_sort(self[0], self.get_num_rows(), stable)
            edge_index = torch.stack([row, self[1][perm]], dim=0)

        else:
            col, perm = index_sort(self[1], self.get_num_cols(), stable)
            edge_index = torch.stack([self[0][perm], col], dim=0)

        out = self.__class__(edge_index)

        # We can inherit metadata and (mostly) cache:
        out._sparse_size = self.sparse_size()
        out._sort_order = sort_order
        out._is_undirected = self.is_undirected

        out._indptr = self._indptr
        out._T_indptr = self._T_indptr

        # NOTE We cannot copy CSR<>CSC permutations since we don't require that
        # local neighborhoods are sorted, and thus they may run out of sync.

        out._value = self._value

        return SortReturnType(out, perm)

    def to_dense(  # type: ignore
        self,
        value: Optional[Tensor] = None,
        fill_value: float = 0.0,
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        r"""Converts :class:`EdgeIndex` into a dense :class:`torch.Tensor`.

        .. warning::

            In case of duplicated edges, the behavior is non-deterministic (one
            of the values from :obj:`value` will be picked arbitrarily). For
            deterministic behavior, consider calling
            :meth:`~torch_geometric.utils.coalesce` beforehand.

        Args:
            value (torch.Tensor, optional): The values for non-zero elements.
                If not specified, non-zero elements will be assigned a value of
                :obj:`1.0`. (default: :obj:`None`)
            fill_value (float, optional): The fill value for remaining elements
                in the dense matrix. (default: :obj:`0.0`)
            dtype (torch.dtype, optional): The data type of the returned
                tensor. (default: :obj:`None`)
        """
        dtype = value.dtype if value is not None else dtype

        size = self.get_sparse_size()
        if value is not None and value.dim() > 1:
            # Multi-dimensional values extend the dense output shape:
            size = size + value.size()[1:]  # type: ignore

        out = torch.full(size, fill_value, dtype=dtype, device=self.device)
        out[self[0], self[1]] = value if value is not None else 1

        return out

    def to_sparse_coo(self, value: Optional[Tensor] = None) -> Tensor:
        r"""Converts :class:`EdgeIndex` into a :pytorch:`null`
        :class:`torch.sparse_coo_tensor`.

        Args:
            value (torch.Tensor, optional): The values for non-zero elements.
                If not specified, non-zero elements will be assigned a value of
                :obj:`1.0`. (default: :obj:`None`)
        """
        value = self._get_value() if value is None else value
        out = torch.sparse_coo_tensor(
            indices=self.as_tensor(),
            values=value,
            size=self.get_sparse_size(),
            device=self.device,
            requires_grad=value.requires_grad,
        )

        if self.is_sorted_by_row:
            out = out._coalesced_(True)

        return out

    def to_sparse_csr(  # type: ignore
        self,
        value: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Converts :class:`EdgeIndex` into a :pytorch:`null`
        :class:`torch.sparse_csr_tensor`.

        Args:
            value (torch.Tensor, optional): The values for non-zero elements.
                If not specified, non-zero elements will be assigned a value of
                :obj:`1.0`. (default: :obj:`None`)
        """
        (rowptr, col), perm = self.get_csr()
        value = self._get_value() if value is None else value[perm]

        return torch.sparse_csr_tensor(
            crow_indices=rowptr,
            col_indices=col,
            values=value,
            size=self.get_sparse_size(),
            device=self.device,
            requires_grad=value.requires_grad,
        )

    def to_sparse_csc(  # type: ignore
        self,
        value: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Converts :class:`EdgeIndex` into a :pytorch:`null`
        :class:`torch.sparse_csc_tensor`.

        Args:
            value (torch.Tensor, optional): The values for non-zero elements.
                If not specified, non-zero elements will be assigned a value of
                :obj:`1.0`. (default: :obj:`None`)

        Raises:
            NotImplementedError: If
                :obj:`torch_geometric.typing.WITH_PT112` is not set.
        """
        if not torch_geometric.typing.WITH_PT112:
            raise NotImplementedError(
                "'to_sparse_csc' not supported for PyTorch < 1.12")

        (colptr, row), perm = self.get_csc()
        value = self._get_value() if value is None else value[perm]

        return torch.sparse_csc_tensor(
            ccol_indices=colptr,
            row_indices=row,
            values=value,
            size=self.get_sparse_size(),
            device=self.device,
            requires_grad=value.requires_grad,
        )

    def to_sparse(  # type: ignore
        self,
        *,
        layout: torch.layout = torch.sparse_coo,
        value: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Converts :class:`EdgeIndex` into a :pytorch:`null`
        :class:`torch.sparse` tensor.

        Args:
            layout (torch.layout, optional): The desired sparse layout. One of
                :obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`, or
                :obj:`torch.sparse_csc`. (default: :obj:`torch.sparse_coo`)
            value (torch.Tensor, optional): The values for non-zero elements.
                If not specified, non-zero elements will be assigned a value of
                :obj:`1.0`. (default: :obj:`None`)

        Raises:
            ValueError: If :obj:`layout` is not supported.
        """
        if layout is None or layout == torch.sparse_coo:
            return self.to_sparse_coo(value)
        if layout == torch.sparse_csr:
            return self.to_sparse_csr(value)
        if torch_geometric.typing.WITH_PT112 and layout == torch.sparse_csc:
            return self.to_sparse_csc(value)

        raise ValueError(f"Unexpected tensor layout (got '{layout}')")

    def to_sparse_tensor(
        self,
        value: Optional[Tensor] = None,
    ) -> SparseTensor:
        r"""Converts :class:`EdgeIndex` into a
        :class:`torch_sparse.SparseTensor`.
        Requires that :obj:`torch-sparse` is installed.

        Args:
            value (torch.Tensor, optional): The values for non-zero elements.
                (default: :obj:`None`)
        """
        return SparseTensor(
            row=self[0],
            col=self[1],
            rowptr=self._indptr if self.is_sorted_by_row else None,
            value=value,
            sparse_sizes=self.get_sparse_size(),
            is_sorted=self.is_sorted_by_row,
            trust_data=True,
        )

    # TODO investigate how to avoid overlapping return types here.
    @overload
    def matmul(  # type: ignore
        self,
        other: 'EdgeIndex',
        input_value: Optional[Tensor] = None,
        other_value: Optional[Tensor] = None,
        reduce: ReduceType = 'sum',
        transpose: bool = False,
    ) -> Tuple['EdgeIndex', Tensor]:
        pass

    @overload
    def matmul(
        self,
        other: Tensor,
        input_value: Optional[Tensor] = None,
        other_value: None = None,
        reduce: ReduceType = 'sum',
        transpose: bool = False,
    ) -> Tensor:
        pass

    def matmul(
        self,
        other: Union[Tensor, 'EdgeIndex'],
        input_value: Optional[Tensor] = None,
        other_value: Optional[Tensor] = None,
        reduce: ReduceType = 'sum',
        transpose: bool = False,
    ) -> Union[Tensor, Tuple['EdgeIndex', Tensor]]:
        r"""Performs a matrix multiplication of the matrices :obj:`input` and
        :obj:`other`.
        If :obj:`input` is a :math:`(n \times m)` matrix and :obj:`other` is a
        :math:`(m \times p)` tensor, then the output will be a
        :math:`(n \times p)` tensor.
        See :meth:`torch.matmul` for more information.

        :obj:`input` is a sparse matrix as denoted by the indices in
        :class:`EdgeIndex`, and :obj:`input_value` corresponds to the values
        of non-zero elements in :obj:`input`.
        If not specified, non-zero elements will be assigned a value of
        :obj:`1.0`.

        :obj:`other` can either be a dense :class:`torch.Tensor` or a sparse
        :class:`EdgeIndex`.
        If :obj:`other` is a sparse :class:`EdgeIndex`, then :obj:`other_value`
        corresponds to the values of its non-zero elements.

        This function additionally accepts an optional :obj:`reduce` argument
        that allows specification of an optional reduction operation.
        See :meth:`torch.sparse.mm` for more information.

        Lastly, the :obj:`transpose` option allows to perform matrix
        multiplication where :obj:`input` will be first transposed, *i.e.*:

        .. math::

            \textrm{input}^{\top} \cdot \textrm{other}

        Args:
            other (torch.Tensor or EdgeIndex): The second matrix to be
                multiplied, which can be sparse or dense.
            input_value (torch.Tensor, optional): The values for non-zero
                elements of :obj:`input`.
                If not specified, non-zero elements will be assigned a value of
                :obj:`1.0`. (default: :obj:`None`)
            other_value (torch.Tensor, optional): The values for non-zero
                elements of :obj:`other` in case it is sparse.
                If not specified, non-zero elements will be assigned a value of
                :obj:`1.0`. (default: :obj:`None`)
            reduce (str, optional): The reduce operation, one of
                :obj:`"sum"`/:obj:`"add"`, :obj:`"mean"`,
                :obj:`"min"`/:obj:`"amin"` or :obj:`"max"`/:obj:`"amax"`.
                (default: :obj:`"sum"`)
            transpose (bool, optional): If set to :obj:`True`, will perform
                matrix multiplication based on the transposed :obj:`input`.
                (default: :obj:`False`)
        """
        # Delegates to the module-level `matmul` function:
        return matmul(self, other, input_value, other_value, reduce, transpose)

    def sparse_narrow(
        self,
        dim: int,
        start: Union[int, Tensor],
        length: int,
    ) -> 'EdgeIndex':
        r"""Returns a new :class:`EdgeIndex` that is a narrowed version of
        itself. Narrowing is performed by interpreting :class:`EdgeIndex` as a
        sparse matrix of shape :obj:`(num_rows, num_cols)`.

        In contrast to :meth:`torch.narrow`, the returned tensor does not share
        the same underlying storage anymore.

        Args:
            dim (int): The dimension along which to narrow.
            start (int or torch.Tensor): Index of the element to start the
                narrowed dimension from.
            length (int): Length of the narrowed dimension.

        Raises:
            ValueError: If :obj:`dim` does not refer to dimension 0 or 1, or
                if :obj:`start` is negative.
        """
        dim = dim + 2 if dim < 0 else dim
        if dim != 0 and dim != 1:
            raise ValueError(f"Expected dimension to be 0 or 1 (got {dim})")

        if start < 0:
            raise ValueError(f"Expected 'start' value to be positive "
                             f"(got {start})")

        if dim == 0:
            (rowptr, col), _ = self.get_csr()
            rowptr = rowptr.narrow(0, start, length + 1)

            if rowptr.numel() < 2:
                # Empty selection:
                row, col = self[0, :0], self[1, :0]
                rowptr = None
                num_rows = 0
            else:
                col = col[rowptr[0]:rowptr[-1]]
                rowptr = rowptr - rowptr[0]
                num_rows = rowptr.numel() - 1

                row = torch.arange(
                    num_rows,
                    dtype=col.dtype,
                    device=col.device,
                ).repeat_interleave(
                    rowptr.diff(),
                    output_size=col.numel(),
                )

            edge_index = EdgeIndex(
                torch.stack([row, col], dim=0),
                sparse_size=(num_rows, self.sparse_size(1)),
                sort_order='row',
            )
            edge_index._indptr = rowptr
            return edge_index

        else:  # dim == 1:
            (colptr, row), _ = self.get_csc()
            colptr = colptr.narrow(0, start, length + 1)

            if colptr.numel() < 2:
                # Empty selection:
                row, col = self[0, :0], self[1, :0]
                colptr = None
                num_cols = 0
            else:
                row = row[colptr[0]:colptr[-1]]
                colptr = colptr - colptr[0]
                num_cols = colptr.numel() - 1

                col = torch.arange(
                    num_cols,
                    dtype=row.dtype,
                    device=row.device,
                ).repeat_interleave(
                    colptr.diff(),
                    output_size=row.numel(),
                )

            edge_index = EdgeIndex(
                torch.stack([row, col], dim=0),
                sparse_size=(self.sparse_size(0), num_cols),
                sort_order='col',
            )
            edge_index._indptr = colptr
            return edge_index

    @classmethod
    def __torch_function__(
        cls: Type,
        func: Callable,
        types: Tuple[Type, ...],
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> Any:
        r"""Dispatches :obj:`func` to its registered override in
        :obj:`HANDLED_FUNCTIONS`, or falls back to the regular
        :class:`torch.Tensor` implementation otherwise.
        """
        # `EdgeIndex` should be treated as a regular PyTorch tensor for all
        # standard PyTorch functionalities. However,
        # * some of its metadata can be transferred to new functions, e.g.,
        #   `torch.cat(dim=1)` can inherit the sparse matrix size, or
        #   `torch.narrow(dim=1)` can inherit cached pointers.
        # * not all operations lead to valid `EdgeIndex` tensors again, e.g.,
        #   `torch.sum()` does not yield a `EdgeIndex` as its output, or
        #   `torch.cat(dim=0) violates the [2, *] shape assumption.

        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that
        # implement specific functions for valid `EdgeIndex` routines.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))

        # For all other PyTorch functions, we return a vanilla PyTorch tensor.
        _types = tuple(Tensor if issubclass(t, cls) else t for t in types)
        return Tensor.__torch_function__(func, _types, args, kwargs)
class SortReturnType(NamedTuple):
    r"""The return type of :meth:`EdgeIndex.sort_by`."""
    # The sorted `EdgeIndex`:
    values: EdgeIndex
    # The permutation that was applied (or a full slice if the input was
    # already sorted accordingly):
    indices: Union[Tensor, slice]


@implements(Tensor.__repr__)
def __repr__(
    tensor: EdgeIndex,
    *,
    tensor_contents: Optional[str] = None,
) -> str:
    r"""Returns the string representation of an :class:`EdgeIndex`, which is
    the regular tensor representation extended by :obj:`sparse_size`,
    :obj:`nnz`, :obj:`sort_order` and :obj:`is_undirected` suffixes.
    """
    # Monkey-patch `torch._tensor_str._add_suffixes`. There might exist better
    # solutions to attach additional metadata, but this seems to be the most
    # straightforward one to inherit most of the `torch.Tensor` print logic:
    orig_fn = torch._tensor_str._add_suffixes

    def _add_suffixes(
        tensor_str: str,
        suffixes: List[str],
        indent: int,
        force_newline: bool,
    ) -> str:

        num_rows, num_cols = tensor.sparse_size()
        if num_rows is not None or num_cols is not None:
            # Only unknown (`None`) sizes are shown as '?'; a size of zero is
            # a valid, known size:
            rows_repr = '?' if num_rows is None else num_rows
            cols_repr = '?' if num_cols is None else num_cols
            size_repr = f"({rows_repr}, {cols_repr})"
            suffixes.append(f'sparse_size={size_repr}')
        suffixes.append(f'nnz={tensor.size(1)}')
        if tensor.is_sorted:
            suffixes.append(f'sort_order={tensor.sort_order}')
        if tensor.is_undirected:
            suffixes.append('is_undirected=True')

        return orig_fn(tensor_str, suffixes, indent, force_newline)

    torch._tensor_str._add_suffixes = _add_suffixes
    try:
        out = torch._tensor_str._str(tensor, tensor_contents=tensor_contents)
    finally:
        # Always undo the monkey-patch, even if printing fails:
        torch._tensor_str._add_suffixes = orig_fn

    return out


def apply_(
    tensor: EdgeIndex,
    fn: Callable,
    *args: Any,
    **kwargs: Any,
) -> EdgeIndex:
    r"""Applies :obj:`fn` to :obj:`tensor` and to all of its cached tensors
    (except for the cached value vector), and returns the result as a new
    :class:`EdgeIndex` that carries over the metadata of :obj:`tensor`.

    Args:
        tensor (EdgeIndex): The input edge indices.
        fn (Callable): The function to apply. It is called as
            :obj:`fn(t, *args, **kwargs)` for every cached tensor :obj:`t`.
        *args (Any): Additional positional arguments for :obj:`fn`.
        **kwargs (Any): Additional keyword arguments for :obj:`fn`.
    """
    out = Tensor.__torch_function__(fn, (Tensor, ), (tensor, ) + args, kwargs)
    out = out.as_subclass(EdgeIndex)

    # Copy metadata:
    out._sparse_size = tensor.sparse_size()
    out._sort_order = tensor._sort_order
    out._is_undirected = tensor._is_undirected

    # Convert cache (but do not consider `_value`):
    if tensor._indptr is not None:
        out._indptr = fn(tensor._indptr, *args, **kwargs)

    if tensor._T_perm is not None:
        out._T_perm = fn(tensor._T_perm, *args, **kwargs)

    _T_row, _T_col = tensor._T_index
    if _T_row is not None:
        _T_row = fn(_T_row, *args, **kwargs)
    if _T_col is not None:
        _T_col = fn(_T_col, *args, **kwargs)
    out._T_index = (_T_row, _T_col)

    if tensor._T_indptr is not None:
        out._T_indptr = fn(tensor._T_indptr, *args, **kwargs)

    return out
@implements(torch.clone)
@implements(Tensor.clone)
def clone(tensor: EdgeIndex) -> EdgeIndex:
    r"""Clones `tensor` together with its metadata and cache (via `apply_`)."""
    return apply_(tensor, Tensor.clone)


@implements(Tensor.to)
def to(
    tensor: EdgeIndex,
    *args: Any,
    **kwargs: Any,
) -> Union[EdgeIndex, Tensor]:
    r"""Applies `Tensor.to` to `tensor` and its cache.

    The result stays an `EdgeIndex` only if its data type is part of
    `SUPPORTED_DTYPES`; otherwise, `out.as_tensor()` is returned.
    """
    out = apply_(tensor, Tensor.to, *args, **kwargs)
    return out if out.dtype in SUPPORTED_DTYPES else out.as_tensor()


@implements(Tensor.int)
def _int(tensor: EdgeIndex) -> EdgeIndex:
    r"""Converts `tensor` to `torch.int32` (see `to`)."""
    return to(tensor, torch.int32)


@implements(Tensor.long)
def long(tensor: EdgeIndex, *args: Any, **kwargs: Any) -> EdgeIndex:
    r"""Converts `tensor` to `torch.int64` (see `to`).

    Additional arguments are accepted but not forwarded.
    """
    return to(tensor, torch.int64)


@implements(Tensor.cpu)
def cpu(tensor: EdgeIndex, *args: Any, **kwargs: Any) -> EdgeIndex:
    r"""Applies `Tensor.cpu` to `tensor` and its cache (via `apply_`)."""
    return apply_(tensor, Tensor.cpu, *args, **kwargs)


@implements(Tensor.cuda)
def cuda(  # pragma: no cover
    tensor: EdgeIndex,
    *args: Any,
    **kwargs: Any,
) -> EdgeIndex:
    r"""Applies `Tensor.cuda` to `tensor` and its cache (via `apply_`)."""
    return apply_(tensor, Tensor.cuda, *args, **kwargs)


@implements(Tensor.share_memory_)
def share_memory_(tensor: EdgeIndex) -> EdgeIndex:
    r"""Applies `Tensor.share_memory_` to `tensor` and its cache."""
    return apply_(tensor, Tensor.share_memory_)


@implements(Tensor.contiguous)
def contiguous(tensor: EdgeIndex) -> EdgeIndex:
    r"""Applies `Tensor.contiguous` to `tensor` and its cache."""
    return apply_(tensor, Tensor.contiguous)


@implements(torch.cat)
def cat(
    tensors: List[Union[EdgeIndex, Tensor]],
    dim: int = 0,
    *,
    out: Optional[Tensor] = None,
) -> Union[EdgeIndex, Tensor]:
    r"""Concatenates `tensors` along dimension `dim`.

    An `EdgeIndex` is only returned if the concatenation happens along the
    last dimension (`dim=1` or `dim=-1`) and every element of `tensors` is an
    `EdgeIndex`. In that case:

    * each entry of the sparse size is the maximum over all inputs, or `None`
      as soon as one input does not know its size in that dimension,
    * the output is marked as undirected only if all inputs are undirected,
    * the per-input metadata is stored in `_cat_metadata`.

    A single-element `tensors` returns that element itself without copying.
    """
    if len(tensors) == 1:
        return tensors[0]

    output = Tensor.__torch_function__(torch.cat, (Tensor, ), (tensors, dim),
                                       dict(out=out))

    if dim != 1 and dim != -1:  # No valid `EdgeIndex` anymore.
        return output

    if any([not isinstance(tensor, EdgeIndex) for tensor in tensors]):
        return output

    output = output.as_subclass(EdgeIndex)

    nnz_list = [t.size(1) for t in tensors]
    sparse_size_list = [t.sparse_size() for t in tensors]  # type: ignore
    sort_order_list = [t._sort_order for t in tensors]  # type: ignore
    is_undirected_list = [t.is_undirected for t in tensors]  # type: ignore

    # Post-process `sparse_size`:
    total_num_rows: Optional[int] = 0
    for num_rows, _ in sparse_size_list:
        if num_rows is None:
            total_num_rows = None
            break
        assert isinstance(total_num_rows, int)
        total_num_rows = max(num_rows, total_num_rows)

    total_num_cols: Optional[int] = 0
    for _, num_cols in sparse_size_list:
        if num_cols is None:
            total_num_cols = None
            break
        assert isinstance(total_num_cols, int)
        total_num_cols = max(num_cols, total_num_cols)

    # Store the accumulated maxima, not the loop variables (which only hold
    # the size of the last inspected input):
    output._sparse_size = (total_num_rows, total_num_cols)

    # Post-process `is_undirected`:
    output._is_undirected = all(is_undirected_list)

    output._cat_metadata = CatMetadata(
        nnz=nnz_list,
        sparse_size=sparse_size_list,
        sort_order=sort_order_list,
        is_undirected=is_undirected_list,
    )

    return output


@implements(torch.flip)
@implements(Tensor.flip)
def flip(
    input: EdgeIndex,
    dims: Union[int, List[int], Tuple[int, ...]],
) -> Union[EdgeIndex, Tensor]:
    r"""Flips `input` along `dims` and adjusts metadata and cache.

    If the first dimension is flipped, the sparse size is reversed. If *only*
    the first dimension is flipped, the sort order is swapped between row and
    column, and the cached pointers are exchanged with their transposed
    counterparts.
    """
    if isinstance(dims, int):
        dims = [dims]
    assert isinstance(dims, (tuple, list))

    out = Tensor.__torch_function__(torch.flip, (Tensor, ), (input, dims))
    out = out.as_subclass(EdgeIndex)

    out._value = input._value
    out._is_undirected = input.is_undirected

    # Flip metadata and cache:
    if 0 in dims or -2 in dims:
        out._sparse_size = input.sparse_size()[::-1]

    if len(dims) == 1 and (dims[0] == 0 or dims[0] == -2):
        if input.is_sorted_by_row:
            out._sort_order = SortOrder.COL
        elif input.is_sorted_by_col:
            out._sort_order = SortOrder.ROW

        out._indptr = input._T_indptr
        out._T_perm = input._T_perm
        out._T_index = input._T_index[::-1]
        out._T_indptr = input._indptr

    return out


@implements(torch.index_select)
@implements(Tensor.index_select)
def index_select(
    input: EdgeIndex,
    dim: int,
    index: Tensor,
    *,
    out: Optional[Tensor] = None,
) -> Union[EdgeIndex, Tensor]:
    r"""Selects entries of `input` along `dim`.

    Only a selection along the last dimension yields an `EdgeIndex` (which
    inherits the sparse size); otherwise, a plain tensor is returned.
    """
    output = Tensor.__torch_function__(  #
        torch.index_select, (Tensor, ), (input, dim, index), dict(out=out))

    if dim == 1 or dim == -1:
        output = output.as_subclass(EdgeIndex)
        output._sparse_size = input.sparse_size()

    return output


@implements(torch.narrow)
@implements(Tensor.narrow)
def narrow(
    input: EdgeIndex,
    dim: int,
    start: Union[int, Tensor],
    length: int,
) -> Union[EdgeIndex, Tensor]:
    r"""Narrows `input` along `dim`.

    Only narrowing along the last dimension yields an `EdgeIndex` (which
    inherits sparse size and sort order); otherwise, a plain tensor is
    returned.
    """
    out = Tensor.__torch_function__(  #
        torch.narrow, (Tensor, ), (input, dim, start, length))

    if dim == 1 or dim == -1:
        out = out.as_subclass(EdgeIndex)
        out._sparse_size = input.sparse_size()
        # NOTE We could potentially maintain `rowptr`/`colptr` attributes here,
        # but it is not really clear if this is worth it. The most important
        # information, the sort order, needs to be maintained though:
        out._sort_order = input._sort_order

    return out


@implements(Tensor.__getitem__)
def getitem(input: EdgeIndex, index: Any) -> Union[EdgeIndex, Tensor]:
    r"""Indexes `input` and keeps the `EdgeIndex` type where this is valid.

    The result is an `EdgeIndex` only for `input[:, i]` or `input[..., i]`
    where `i` is a mask, an index tensor or a slice. The sparse size is always
    inherited in these cases; the sort order is inherited for masks and for
    slices with a positive (or no) step.
    """
    out = Tensor.__torch_function__(  #
        Tensor.__getitem__, (Tensor, ), (input, index))

    # There exists 3 possible index types that map back to a valid `EdgeIndex`,
    # and all include selecting/filtering in the last dimension only:
    def is_last_dim_select(i: Any) -> bool:
        # Maps to true for `__getitem__` requests of the form
        # `tensor[..., index]` or `tensor[:, index]`.
        if not isinstance(i, tuple) or len(i) != 2:
            return False
        # Identity check: `==` would invoke `__eq__` of arbitrary index
        # objects (e.g., element-wise comparison for array-like indices).
        if i[0] is Ellipsis:
            return True
        if not isinstance(i[0], slice):
            return False
        return i[0].start is None and i[0].stop is None and i[0].step is None

    is_valid = is_last_dim_select(index)

    # 1. `edge_index[:, mask]` or `edge_index[..., mask]`.
    if (is_valid and isinstance(index[1], Tensor)
            and index[1].dtype in (torch.bool, torch.uint8)):
        out = out.as_subclass(EdgeIndex)
        out._sparse_size = input.sparse_size()
        out._sort_order = input._sort_order

    # 2. `edge_index[:, index]` or `edge_index[..., index]`.
    elif is_valid and isinstance(index[1], Tensor):
        out = out.as_subclass(EdgeIndex)
        out._sparse_size = input.sparse_size()

    # 3. `edge_index[:, slice]` or `edge_index[..., slice]`.
    elif is_valid and isinstance(index[1], slice):
        out = out.as_subclass(EdgeIndex)
        out._sparse_size = input.sparse_size()
        if index[1].step is None or index[1].step > 0:
            out._sort_order = input._sort_order

    return out


def postprocess_add_(
    input: EdgeIndex,
    other: Union[int, Tensor],
    out: Tensor,
    alpha: int = 1,
) -> Union[EdgeIndex, Tensor]:
    r"""Attaches `EdgeIndex` metadata to the result `out` of an addition.

    `out` is returned unchanged if its data type is not supported or if it is
    not of shape `[2, *]`. Otherwise, it is re-wrapped as an `EdgeIndex` and
    metadata is derived from `input` depending on the type of `other`:

    * integer or tensor with at most one element: sparse size (via
      `maybe_add`), sort order, undirected flag and `_T_perm` are kept.
    * tensor of shape `[2, 1]`: as above, but the undirected flag is only
      kept if both offsets are equal.
    * `EdgeIndex`: only the sparse size is derived.
    """
    if out.dtype not in SUPPORTED_DTYPES:
        return out
    if out.dim() != 2 or out.size(0) != 2:
        return out

    output: EdgeIndex = out.as_subclass(EdgeIndex)

    if isinstance(other, int):
        size = maybe_add(input._sparse_size, other, alpha)
        assert len(size) == 2
        output._sparse_size = size
        output._sort_order = input._sort_order
        output._is_undirected = input.is_undirected
        output._T_perm = input._T_perm

    elif isinstance(other, Tensor) and other.numel() <= 1:
        size = maybe_add(input._sparse_size, int(other), alpha)
        assert len(size) == 2
        output._sparse_size = size
        output._sort_order = input._sort_order
        output._is_undirected = input.is_undirected
        output._T_perm = input._T_perm

    elif isinstance(other, Tensor) and other.size() == (2, 1):
        size = maybe_add(input._sparse_size, other.view(-1).tolist(), alpha)
        assert len(size) == 2
        output._sparse_size = size
        output._sort_order = input._sort_order
        output._T_perm = input._T_perm
        # A different offset per row breaks the symmetry of the edges:
        if torch.equal(other[0], other[1]):
            output._is_undirected = input.is_undirected

    elif isinstance(other, EdgeIndex):
        size = maybe_add(input._sparse_size, other._sparse_size, alpha)
        assert len(size) == 2
        output._sparse_size = size

    return output


@implements(torch.add)
@implements(Tensor.add)
def add(
    input: EdgeIndex,
    other: Union[int, Tensor],
    *,
    alpha: int = 1,
    out: Optional[Tensor] = None,
) -> Union[EdgeIndex, Tensor]:
    r"""Computes `input + alpha * other` via `torch.add` and post-processes
    the result with `postprocess_add_`.
    """
    output = Tensor.__torch_function__(  #
        torch.add, (Tensor, ), (input, other), dict(alpha=alpha, out=out))

    return postprocess_add_(input, other, output, alpha)


@implements(Tensor.add_)
def add_(
    input: EdgeIndex,
    other: Union[int, Tensor],
    *,
    alpha: int = 1,
) -> Union[EdgeIndex, Tensor]:
    r"""Calls `Tensor.add_` on `input` and post-processes the result with
    `postprocess_add_`.
    """
    output = Tensor.__torch_function__(  #
        Tensor.add_, (Tensor, ), (input, other), dict(alpha=alpha))

    return postprocess_add_(input, other, output, alpha)


def postprocess_sub_(
    input: EdgeIndex,
    other: Union[int, Tensor],
    out: Tensor,
    alpha: int = 1,
) -> Union[EdgeIndex, Tensor]:
    r"""Attaches `EdgeIndex` metadata to the result `out` of a subtraction.

    `out` is returned unchanged if its data type is not supported or if it is
    not of shape `[2, *]`. Otherwise, it is re-wrapped as an `EdgeIndex` and
    metadata is derived from `input` depending on the type of `other`:

    * integer or tensor with at most one element: sparse size (via
      `maybe_sub`), sort order, undirected flag and `_T_perm` are kept.
    * tensor of shape `[2, 1]`: as above, but the undirected flag is only
      kept if both offsets are equal.

    For any other `other`, no metadata is derived.
    """
    if out.dtype not in SUPPORTED_DTYPES:
        return out
    if out.dim() != 2 or out.size(0) != 2:
        return out

    output: EdgeIndex = out.as_subclass(EdgeIndex)

    if isinstance(other, int):
        size = maybe_sub(input._sparse_size, other, alpha)
        assert len(size) == 2
        output._sparse_size = size
        output._sort_order = input._sort_order
        output._is_undirected = input.is_undirected
        output._T_perm = input._T_perm

    elif isinstance(other, Tensor) and other.numel() <= 1:
        size = maybe_sub(input._sparse_size, int(other), alpha)
        assert len(size) == 2
        output._sparse_size = size
        output._sort_order = input._sort_order
        output._is_undirected = input.is_undirected
        output._T_perm = input._T_perm

    elif isinstance(other, Tensor) and other.size() == (2, 1):
        size = maybe_sub(input._sparse_size, other.view(-1).tolist(), alpha)
        assert len(size) == 2
        output._sparse_size = size
        output._sort_order = input._sort_order
        output._T_perm = input._T_perm
        # A different offset per row breaks the symmetry of the edges:
        if torch.equal(other[0], other[1]):
            output._is_undirected = input.is_undirected

    return output


@implements(torch.sub)
@implements(Tensor.sub)
def sub(
    input: EdgeIndex,
    other: Union[int, Tensor],
    *,
    alpha: int = 1,
    out: Optional[Tensor] = None,
) -> Union[EdgeIndex, Tensor]:
    r"""Computes `input - alpha * other` via `torch.sub` and post-processes
    the result with `postprocess_sub_`.
    """
    output = Tensor.__torch_function__(  #
        torch.sub, (Tensor, ), (input, other), dict(alpha=alpha, out=out))

    return postprocess_sub_(input, other, output, alpha)


@implements(Tensor.sub_)
def sub_(
    input: EdgeIndex,
    other: Union[int, Tensor],
    *,
    alpha: int = 1,
) -> Union[EdgeIndex, Tensor]:
    r"""Calls `Tensor.sub_` on `input` and post-processes the result with
    `postprocess_sub_`.
    """
    output = Tensor.__torch_function__(  #
        Tensor.sub_, (Tensor, ), (input, other), dict(alpha=alpha))

    return postprocess_sub_(input, other, output, alpha)


# Sparse-Dense Matrix Multiplication ##########################################


def _torch_sparse_spmm(
    input: EdgeIndex,
    other: Tensor,
    value: Optional[Tensor] = None,
    reduce: ReduceType = 'sum',
    transpose: bool = False,
) -> Tensor:
    r"""Sparse-dense matrix multiplication via the `torch-sparse` operators.

    Args:
        input: The sparse matrix; needs to be sorted by row
            (`transpose=False`) or by column (`transpose=True`).
        other: The dense matrix.
        value: Optional edge values, passed through to the operator.
        reduce: The reduction; `PYG_REDUCE` aliases are resolved first.
        transpose: Whether to multiply with the transposed sparse matrix.

    Raises:
        NotImplementedError: If `reduce` is none of `'sum'`, `'mean'`,
            `'min'` or `'max'` after alias resolution.
    """
    # `torch-sparse` still provides a faster sparse-dense matrix multiplication
    # code path on GPUs (after all these years...):
    assert torch_geometric.typing.WITH_TORCH_SPARSE
    reduce = PYG_REDUCE[reduce] if reduce in PYG_REDUCE else reduce

    # Optional arguments for backpropagation:
    colptr: Optional[Tensor] = None
    perm: Optional[Tensor] = None

    if not transpose:
        assert input.is_sorted_by_row
        (rowptr, col), _ = input.get_csr()
        row = input[0]
        if other.requires_grad and reduce in ['sum', 'mean']:
            (colptr, _), perm = input.get_csc()
    else:
        # The roles of rows and columns are swapped for the transposed case:
        assert input.is_sorted_by_col
        (rowptr, col), _ = input.get_csc()
        row = input[1]
        if other.requires_grad and reduce in ['sum', 'mean']:
            (colptr, _), perm = input.get_csr()

    if reduce == 'sum':
        return torch.ops.torch_sparse.spmm_sum(  #
            row, rowptr, col, value, colptr, perm, other)

    if reduce == 'mean':
        rowcount = rowptr.diff() if other.requires_grad else None
        return torch.ops.torch_sparse.spmm_mean(  #
            row, rowptr, col, value, rowcount, colptr, perm, other)

    if reduce == 'min':
        return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)[0]

    if reduce == 'max':
        return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)[0]

    raise NotImplementedError


class _TorchSPMM(torch.autograd.Function):
    r"""Sparse-dense matrix multiplication based on native PyTorch sparse
    tensors, with a custom backward pass for the dense operand `other`.
    """
    @staticmethod
    def forward(
        ctx: Any,
        input: EdgeIndex,
        other: Tensor,
        value: Optional[Tensor] = None,
        reduce: ReduceType = 'sum',
        transpose: bool = False,
    ) -> Tensor:
        r"""Computes the product of the (optionally transposed) sparse matrix
        with `other`. `TORCH_REDUCE` aliases of `reduce` are resolved first.
        """
        reduce = TORCH_REDUCE[reduce] if reduce in TORCH_REDUCE else reduce

        value = value.detach() if value is not None else value
        if other.requires_grad:
            # State is only needed if a gradient w.r.t. `other` is requested:
            other = other.detach()
            ctx.save_for_backward(input, value)
            ctx.reduce = reduce
            ctx.transpose = transpose

        if not transpose:
            assert input.is_sorted_by_row
            adj = input.to_sparse_csr(value)
        else:
            assert input.is_sorted_by_col
            adj = input.to_sparse_csc(value).t()

        if torch_geometric.typing.WITH_PT20 and not other.is_cuda:
            return torch.sparse.mm(adj, other, reduce)
        else:  # pragma: no cover
            assert reduce == 'sum'
            return adj @ other

    @staticmethod
    def backward(
        ctx: Any,
        *grad_outputs: Any,
    ) -> Tuple[None, Optional[Tensor], None, None, None]:
        r"""Computes the gradient w.r.t. `other` (only for `reduce='sum'`).

        Raises:
            NotImplementedError: If a gradient w.r.t. `value` is requested.
        """
        grad_out, = grad_outputs

        other_grad: Optional[Tensor] = None
        if ctx.needs_input_grad[1]:
            input, value = ctx.saved_tensors
            assert ctx.reduce == 'sum'

            if not ctx.transpose:
                if value is None and input.is_undirected:
                    adj = input.to_sparse_csr(value)
                else:
                    (colptr, row), perm = input.get_csc()
                    if value is not None:
                        value = value[perm]
                    else:
                        value = input._get_value()
                    adj = torch.sparse_csr_tensor(
                        crow_indices=colptr,
                        col_indices=row,
                        values=value,
                        size=input.get_sparse_size()[::-1],
                        device=input.device,
                    )
            else:
                if value is None and input.is_undirected:
                    adj = input.to_sparse_csc(value).t()
                else:
                    (rowptr, col), perm = input.get_csr()
                    if value is not None:
                        value = value[perm]
                    else:
                        value = input._get_value()
                    adj = torch.sparse_csr_tensor(
                        crow_indices=rowptr,
                        col_indices=col,
                        values=value,
                        size=input.get_sparse_size()[::-1],
                        device=input.device,
                    )

            other_grad = adj @ grad_out

        if ctx.needs_input_grad[2]:
            raise NotImplementedError("Gradient computation for 'value' not "
                                      "yet supported")

        return None, other_grad, None, None, None


def _scatter_spmm(
    input: EdgeIndex,
    other: Tensor,
    value: Optional[Tensor] = None,
    reduce: ReduceType = 'sum',
    transpose: bool = False,
) -> Tensor:
    r"""Sparse-dense matrix multiplication via gather and `scatter`.

    Rows of `other` are gathered per edge, optionally scaled by `value`, and
    reduced into the target index with `reduce`.
    """
    from torch_geometric.utils import scatter

    if not transpose:
        other_j = other[input[1]]
        index = input[0]
    else:
        other_j = other[input[0]]
        index = input[1]

    other_j = other_j * value.view(-1, 1) if value is not None else other_j
    # NOTE(review): `dim_size=other.size(0)` presumably assumes a square
    # sparse matrix -- confirm behavior for non-square sparse sizes.
    return scatter(other_j, index, 0, dim_size=other.size(0), reduce=reduce)


def _spmm(
    input: EdgeIndex,
    other: Tensor,
    value: Optional[Tensor] = None,
    reduce: ReduceType = 'sum',
    transpose: bool = False,
) -> Tensor:
    r"""Validates the inputs of a sparse-dense matrix multiplication and
    dispatches to `_torch_sparse_spmm`, `_TorchSPMM` or `_scatter_spmm`.

    Raises:
        ValueError: If `reduce` is not part of `ReduceType`, or if `input` is
            not sorted by row (`transpose=False`) or by column
            (`transpose=True`).
    """
    if reduce not in get_args(ReduceType):
        raise ValueError(f"`reduce='{reduce}'` is not a valid reduction")

    if not transpose and not input.is_sorted_by_row:
        cls_name = input.__class__.__name__
        raise ValueError(f"'matmul(..., transpose=False)' requires "
                         f"'{cls_name}' to be sorted by rows")

    if transpose and not input.is_sorted_by_col:
        cls_name = input.__class__.__name__
        raise ValueError(f"'matmul(..., transpose=True)' requires "
                         f"'{cls_name}' to be sorted by colums")

    if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()
            and other.is_cuda):  # pragma: no cover
        return _torch_sparse_spmm(input, other, value, reduce, transpose)

    # `_TorchSPMM` cannot compute gradients w.r.t. `value`:
    if value is not None and value.requires_grad:
        if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling():
            return _torch_sparse_spmm(input, other, value, reduce, transpose)
        return _scatter_spmm(input, other, value, reduce, transpose)

    if reduce == 'sum' or reduce == 'add':
        return _TorchSPMM.apply(input, other, value, 'sum', transpose)

    if reduce == 'mean':
        # Compute the mean as a sum that is normalized by the number of
        # entries per pointer segment (clamped to avoid division by zero):
        out = _TorchSPMM.apply(input, other, value, 'sum', transpose)
        count = input.get_indptr().diff()
        return out / count.clamp_(min=1).to(out.dtype).view(-1, 1)

    if (torch_geometric.typing.WITH_PT20 and not other.is_cuda
            and not other.requires_grad):
        return _TorchSPMM.apply(input, other, value, reduce, transpose)

    if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling():
        return _torch_sparse_spmm(input, other, value, reduce, transpose)

    return _scatter_spmm(input, other, value, reduce, transpose)


def matmul(
    input: EdgeIndex,
    other: Union[Tensor, EdgeIndex],
    input_value: Optional[Tensor] = None,
    other_value: Optional[Tensor] = None,
    reduce: ReduceType = 'sum',
    transpose: bool = False,
) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:
    r"""Multiplies the sparse matrix `input` with `other`.

    If `other` is a dense tensor, a sparse-dense multiplication is performed
    (see `_spmm`) and a dense tensor is returned. If `other` is an
    `EdgeIndex`, a sparse-sparse multiplication is performed and a tuple of
    the resulting row-sorted `EdgeIndex` and its values is returned.

    Args:
        input: The left-hand sparse matrix.
        other: The right-hand dense or sparse matrix.
        input_value: Optional values of `input`.
        other_value: Optional values of `other`; only supported if `other` is
            an `EdgeIndex`.
        reduce: The reduction; sparse-sparse multiplication only supports
            `'sum'` and `'add'`.
        transpose: Whether to multiply with the transposed `input`.

    Raises:
        ValueError: If `other_value` is given for a dense `other`.
        NotImplementedError: If `reduce` is unsupported for sparse-sparse
            multiplication, or if the result has an unexpected layout.
    """
    if not isinstance(other, EdgeIndex):
        if other_value is not None:
            raise ValueError("'other_value' not supported for sparse-dense "
                             "matrix multiplication")
        return _spmm(input, other, input_value, reduce, transpose)

    if reduce not in ['sum', 'add']:
        raise NotImplementedError(f"`reduce='{reduce}'` not yet supported for "
                                  f"sparse-sparse matrix multiplication")

    # Transposing is skipped for undirected inputs without explicit values:
    transpose &= not input.is_undirected or input_value is not None

    if torch_geometric.typing.WITH_WINDOWS:  # pragma: no cover
        sparse_input = input.to_sparse_coo(input_value)
    elif input.is_sorted_by_col:
        sparse_input = input.to_sparse_csc(input_value)
    else:
        sparse_input = input.to_sparse_csr(input_value)

    if transpose:
        sparse_input = sparse_input.t()

    if torch_geometric.typing.WITH_WINDOWS:  # pragma: no cover
        other = other.to_sparse_coo(other_value)
    elif other.is_sorted_by_col:
        other = other.to_sparse_csc(other_value)
    else:
        other = other.to_sparse_csr(other_value)

    out = torch.matmul(sparse_input, other)

    rowptr: Optional[Tensor] = None
    if out.layout == torch.sparse_csr:
        rowptr = out.crow_indices().to(input.dtype)
        col = out.col_indices().to(input.dtype)
        edge_index = torch._convert_indices_from_csr_to_coo(
            rowptr, col, out_int32=rowptr.dtype != torch.int64)

    elif out.layout == torch.sparse_coo:  # pragma: no cover
        out = out.coalesce()
        edge_index = out.indices()

    else:
        raise NotImplementedError

    edge_index = edge_index.as_subclass(EdgeIndex)
    edge_index._sort_order = SortOrder.ROW
    edge_index._sparse_size = (out.size(0), out.size(1))
    edge_index._indptr = rowptr

    return edge_index, out.values()


@implements(torch.mm)
@implements(torch.matmul)
@implements(Tensor.matmul)
def _matmul1(
    input: EdgeIndex,
    other: Union[Tensor, EdgeIndex],
) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:
    r"""Override of `torch.mm`/`torch.matmul` that forwards to `matmul`."""
    return matmul(input, other)


@implements(torch.sparse.mm)
def _matmul2(
    mat1: EdgeIndex,
    mat2: Union[Tensor, EdgeIndex],
    reduce: ReduceType = 'sum',
) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:
    r"""Override of `torch.sparse.mm` that forwards to `matmul`."""
    return matmul(mat1, mat2, reduce=reduce)