Source code for torch_geometric.utils.scatter

import torch_scatter


def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along dimension :attr:`dim`.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (one of :obj:`"add"`,
    :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).

    For :obj:`"min"` and :obj:`"max"`, output positions that receive no
    contribution from :attr:`src` are set to :obj:`0`.

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with
            size :attr:`dim_size` in dimension :attr:`dim`. If set to
            :obj:`None`, a minimal sized output tensor is returned.
            (default: :obj:`None`)

    Raises:
        ValueError: If :attr:`name` is not a supported aggregation.

    :rtype: :class:`Tensor`
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if name not in ('add', 'mean', 'min', 'max'):
        raise ValueError(
            'Unsupported aggregation {!r}; expected one of '
            "'add', 'mean', 'min' or 'max'.".format(name))

    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    out = op(src, index, dim, None, dim_size)
    # Some scatter ops return a tuple; only its first element (the
    # aggregated values) is returned to the caller.
    out = out[0] if isinstance(out, tuple) else out

    if name in ('min', 'max'):
        # Empty output slots may hold an extreme fill value from the backend.
        # Identify them by counting contributions per slot rather than by
        # thresholding the values, which would also zero legitimate results
        # beyond +/-10000. The count is scattered with the same ``dim`` and
        # output size so its shape matches ``out`` exactly.
        count = torch_scatter.scatter_add(src.new_ones(src.size()), index,
                                          dim, None, out.size(dim))
        out[count == 0] = 0

    return out