Source code for torch_geometric.utils.spmm

from typing import Union

import torch
from torch import Tensor
from torch_sparse import SparseTensor, matmul

from .sparse import is_torch_sparse_tensor


# TorchScript overload declaration for the case where `src` is a plain
# `Tensor`; its signature is given by the type comment inside the stub.
# The actual implementation is the undecorated `spmm` further down.
# NOTE(review): TorchScript overload stubs appear to require a body of only
# `pass` (or `...`) -- do not add a docstring or any other statement here.
@torch.jit._overload
def spmm(src, other, reduce):
    # type: (Tensor, Tensor, str) -> Tensor
    pass


# TorchScript overload declaration for the case where `src` is a
# `torch_sparse.SparseTensor`; its signature is given by the type comment
# inside the stub. The actual implementation is the undecorated `spmm` below.
# NOTE(review): TorchScript appears to require overload stubs to contain only
# `pass` (or `...`) -- do not add a docstring or any other statement here.
@torch.jit._overload
def spmm(src, other, reduce):
    # type: (SparseTensor, Tensor, str) -> Tensor
    pass


def spmm(
    src: Union[SparseTensor, Tensor],
    other: Tensor,
    reduce: str = "sum",
) -> Tensor:
    """Matrix product of sparse matrix with dense matrix.

    Args:
        src (Tensor or torch_sparse.SparseTensor): The input sparse matrix,
            either a :class:`torch_sparse.SparseTensor` or a
            :class:`torch.sparse.Tensor`.
        other (Tensor): The input dense matrix.
        reduce (str, optional): The reduce operation to use
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`).
            :obj:`"add"` is also accepted; for :class:`torch.sparse.Tensor`
            inputs it behaves exactly like :obj:`"sum"`.
            (default: :obj:`"sum"`)

    Raises:
        ValueError: If :obj:`reduce` is not one of the accepted values, if
            :obj:`src` is neither a :class:`torch_sparse.SparseTensor` nor a
            :class:`torch.sparse.Tensor`, or if :obj:`src` is a
            :class:`torch.sparse.Tensor` and :obj:`reduce` is anything other
            than :obj:`"sum"`/:obj:`"add"`.

    :rtype: :class:`Tensor`
    """
    # Validate explicitly rather than with `assert`, which is stripped when
    # Python runs with `-O` and would let an invalid `reduce` slip through.
    if reduce not in ['sum', 'add', 'mean', 'min', 'max']:
        raise ValueError(f"`reduce` argument '{reduce}' not supported")

    # `torch_sparse.SparseTensor` supports every reduction via its own matmul.
    if isinstance(src, SparseTensor):
        return matmul(src, other, reduce)

    if not is_torch_sparse_tensor(src):
        raise ValueError("`src` must be a `torch_sparse.SparseTensor` "
                         f"or a `torch.sparse.Tensor` (got {type(src)}).")

    # Native PyTorch sparse tensors only offer a summing matrix product.
    if reduce in ['sum', 'add']:
        return torch.sparse.mm(src, other)

    # TODO: Support `mean` reduction for PyTorch SparseTensor
    raise ValueError(f"`{reduce}` reduction is not supported for "
                     f"`torch.sparse.Tensor`.")