Source code for torch_geometric.utils.spmm

import warnings

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.typing import Adj, SparseTensor, torch_sparse
from torch_geometric.utils import is_torch_sparse_tensor, scatter


@torch.jit._overload
def spmm(src, other, reduce):
    # type: (Tensor, Tensor, str) -> Tensor
    pass


@torch.jit._overload
def spmm(src, other, reduce):
    # type: (SparseTensor, Tensor, str) -> Tensor
    pass
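

# Note (editorial): the two `@torch.jit._overload` stubs above register the
# TorchScript type signatures of `spmm`, so that the scripted function accepts
# either a `torch.Tensor` or a `torch_sparse.SparseTensor` as its first
# argument. The actual dispatch happens at runtime in the implementation below.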


def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
    """Matrix product of a sparse matrix with a dense matrix.

    Args:
        src (Tensor or torch_sparse.SparseTensor): The input sparse matrix,
            either a :pyg:`PyG` :class:`torch_sparse.SparseTensor` or a
            :pytorch:`PyTorch` :class:`torch.sparse.Tensor`.
        other (Tensor): The input dense matrix.
        reduce (str, optional): The reduce operation to use
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`).
            (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`
    """
    reduce = 'sum' if reduce == 'add' else reduce

    if reduce not in ['sum', 'mean', 'min', 'max']:
        raise ValueError(f"`reduce` argument '{reduce}' not supported")

    if isinstance(src, SparseTensor):
        if (torch_geometric.typing.WITH_PT2 and other.dim() == 2
                and not src.is_cuda()):
            # Use the optimized PyTorch `torch.sparse.mm` path:
            csr = src.to_torch_sparse_csr_tensor()
            return torch.sparse.mm(csr, other, reduce)
        return torch_sparse.matmul(src, other, reduce)

    if not is_torch_sparse_tensor(src):
        raise ValueError("`src` must be a `torch_sparse.SparseTensor` "
                         f"or a `torch.sparse.Tensor` (got {type(src)}).")

    # `torch.sparse.mm` only supports reductions on CPU for PyTorch >= 2.0.
    # It will currently throw an error for CUDA tensors:
    if torch_geometric.typing.WITH_PT2:

        if src.is_cuda and (reduce == 'min' or reduce == 'max'):
            raise NotImplementedError(f"`{reduce}` reduction is not yet "
                                      f"supported for 'torch.sparse.Tensor' "
                                      f"on device '{src.device}'")

        # Always convert COO to CSR for more efficient processing:
        if src.layout == torch.sparse_coo:
            warnings.warn(f"Converting sparse tensor to CSR format for more "
                          f"efficient processing. Consider converting your "
                          f"sparse tensor to CSR format beforehand to avoid "
                          f"repeated conversion (got '{src.layout}')")
            src = src.to_sparse_csr()

        # Warn in case of CSC format without gradient computation:
        if src.layout == torch.sparse_csc and not other.requires_grad:
            warnings.warn(f"Converting sparse tensor to CSR format for more "
                          f"efficient processing. Consider converting your "
                          f"sparse tensor to CSR format beforehand to avoid "
                          f"repeated conversion (got '{src.layout}')")

        # Use the default code path for `sum` reduction (works on CPU/GPU):
        if reduce == 'sum':
            return torch.sparse.mm(src, other)

        # Use the default code path with custom reduction (works on CPU):
        if src.layout == torch.sparse_csr and not src.is_cuda:
            return torch.sparse.mm(src, other, reduce)

        # Simulate `mean` reduction by dividing the output by the row degree:
        if reduce == 'mean':
            if src.layout == torch.sparse_csr:
                ptr = src.crow_indices()
                deg = ptr[1:] - ptr[:-1]
            else:
                assert src.layout == torch.sparse_csc
                deg = scatter(torch.ones_like(src.values()),
                              src.row_indices(), dim=0,
                              dim_size=src.size(0), reduce='sum')

            return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)

        # TODO The `torch.sparse.mm` code path with the `reduce` argument does
        # not yet support CSC :(
        if src.layout == torch.sparse_csc:
            warnings.warn(f"Converting sparse tensor to CSR format for more "
                          f"efficient processing. Consider converting your "
                          f"sparse tensor to CSR format beforehand to avoid "
                          f"repeated conversion (got '{src.layout}')")
            src = src.to_sparse_csr()

        return torch.sparse.mm(src, other, reduce)  # pragma: no cover

    # PyTorch < 2.0 only supports the sparse COO format:
    if reduce == 'sum':
        return torch.sparse.mm(src, other)
    elif reduce == 'mean':
        if src.layout == torch.sparse_csr:
            ptr = src.crow_indices()
            deg = ptr[1:] - ptr[:-1]
        elif src.layout == torch.sparse_csc:
            deg = scatter(torch.ones_like(src.values()), src.row_indices(),
                          dim=0, dim_size=src.size(0), reduce='sum')
        else:
            assert src.layout == torch.sparse_coo
            src = src.coalesce()
            deg = scatter(torch.ones_like(src.values()), src.indices()[0],
                          dim=0, dim_size=src.size(0), reduce='sum')

        return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)

    raise ValueError(f"`{reduce}` reduction is not supported for "
                     f"'torch.sparse.Tensor' on device '{src.device}'")
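

A minimal usage sketch (not part of the module source): calling spmm on a
PyTorch CSR tensor, assuming PyTorch >= 2.0 on CPU; the tensor values are
illustrative only.

    import torch
    from torch_geometric.utils import spmm

    # A 3x3 sparse adjacency matrix in CSR format:
    adj = torch.tensor([
        [0., 1., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
    ]).to_sparse_csr()

    x = torch.randn(3, 8)  # Dense node feature matrix.

    out_sum = spmm(adj, x, reduce='sum')    # Equivalent to adj @ x.
    out_mean = spmm(adj, x, reduce='mean')  # Row sums divided by row degree.
    assert out_sum.size() == (3, 8)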