Source code for torch_geometric.utils.scatter

import torch
import torch_scatter


def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along dimension :attr:`dim`.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).

    For :obj:`"max"`, output positions that receive no contribution are set
    to :obj:`0`. For :obj:`"add"` and :obj:`"mean"`, :obj:`0` is passed to
    :mod:`torch_scatter` as the fill value.

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with
            size :attr:`dim_size` in dimension :attr:`dim`. If set to
            :obj:`None`, a minimal sized output tensor is returned.
            (default: :obj:`None`)

    Raises:
        AssertionError: If :attr:`name` is not a supported aggregation.

    :rtype: :class:`Tensor`
    """
    assert name in ['add', 'mean', 'max'], (
        'Unknown aggregation {!r}; expected "add", "mean" or "max"'.format(
            name))

    if name == 'max':
        # Start from the smallest representable value of src's dtype so
        # every real element wins the max comparison.
        info = torch.finfo if torch.is_floating_point(src) else torch.iinfo
        fill_value = info(src.dtype).min
    else:
        fill_value = 0

    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    out = op(src, index, dim, None, dim_size, fill_value)
    # Some torch_scatter ops return a tuple (presumably scatter_max returning
    # values plus arg-indices); only the aggregated values are kept.
    if isinstance(out, tuple):
        out = out[0]

    if name == 'max':
        # Positions that received no element still hold the sentinel
        # ``fill_value`` and are reset to 0. They are identified by counting
        # contributions rather than by ``out == fill_value``, so a genuine
        # element equal to the dtype minimum is not mistaken for "empty".
        # Counts are float so they cannot wrap around to 0 for small integer
        # dtypes (e.g. uint8 with 256 contributions).
        count = torch_scatter.scatter_add(
            torch.ones_like(src, dtype=torch.float), index, dim, None,
            dim_size, 0)
        out[count == 0] = 0

    return out