import torch
import torch_scatter
def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along dimension :attr:`dim`.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).

    For :obj:`"max"`, output positions that end up holding no value are
    filled with :obj:`0`.

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in dimension :attr:`dim`. If set to
            :attr:`None`, a minimal sized output tensor is returned.
            (default: :obj:`None`)

    Raises:
        ValueError: If :attr:`name` is not one of the supported aggregations.

    :rtype: :class:`Tensor`
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``
    # and an unsupported name would then silently reach ``getattr`` below.
    if name not in ('add', 'mean', 'max'):
        raise ValueError(
            "Unsupported aggregation {!r}; expected 'add', 'mean' or "
            "'max'.".format(name))

    if name == 'max':
        # Seed the output with the smallest representable value of src's
        # dtype so every real element wins the max comparison.
        dtype_info = torch.finfo if torch.is_floating_point(src) else torch.iinfo
        fill_value = dtype_info(src.dtype).min
    else:
        fill_value = 0

    scatter_fn = getattr(torch_scatter, 'scatter_{}'.format(name))
    out = scatter_fn(src, index, dim, None, dim_size, fill_value)
    # Some torch_scatter ops return a tuple (presumably (values, argmax) for
    # scatter_max); only the aggregated values are returned to the caller.
    if isinstance(out, tuple):
        out = out[0]

    if name == 'max':
        # Slots still holding the sentinel (presumably those that received no
        # element) are reset to 0.
        # NOTE: a genuine maximum equal to the dtype minimum is
        # indistinguishable from the sentinel and is zeroed as well.
        out[out == fill_value] = 0
    return out