Source code for torch_geometric.utils._softmax

from typing import Optional

from torch import Tensor

import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import scatter, segment
from torch_geometric.utils.num_nodes import maybe_num_nodes


def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:
    r"""Computes a sparsely evaluated softmax.

    Given a value tensor :attr:`src`, this function first groups the values
    along dimension :attr:`dim`, either based on the indices specified in
    :attr:`index` or on the CSR offsets given in :attr:`ptr`, and then
    proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): The indices of elements for applying the
            softmax. (default: :obj:`None`)
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. Only used when grouping is
            done via :attr:`index`. (default: :obj:`None`)
        dim (int, optional): The dimension in which to normalize.
            (default: :obj:`0`)

    :rtype: :class:`Tensor`

    Raises:
        NotImplementedError: If neither :attr:`index` nor :attr:`ptr` is
            given.

    Examples:
        >>> src = torch.tensor([1., 1., 1., 1.])
        >>> index = torch.tensor([0, 0, 1, 2])
        >>> ptr = torch.tensor([0, 2, 3, 4])
        >>> softmax(src, index)
        tensor([0.5000, 0.5000, 1.0000, 1.0000])

        >>> softmax(src, None, ptr)
        tensor([0.5000, 0.5000, 1.0000, 1.0000])

        >>> src = torch.randn(4, 4)
        >>> softmax(src, index, dim=-1)
        tensor([[0.7404, 0.2596, 1.0000, 1.0000],
                [0.1702, 0.8298, 1.0000, 1.0000],
                [0.7607, 0.2393, 1.0000, 1.0000],
                [0.8062, 0.1938, 1.0000, 1.0000]])
    """
    # Fast path: dedicated CSR softmax kernel from pyg-lib, only taken for
    # CPU tensors outside of compilation.
    if (ptr is not None and src.device.type == 'cpu'
            and torch_geometric.typing.WITH_SOFTMAX
            and not is_compiling()):  # pragma: no cover
        return pyg_lib.ops.softmax_csr(src, ptr, dim)

    if (ptr is not None
            and (ptr.dim() == 1 or (ptr.dim() > 1 and index is None) or
                 (torch_geometric.typing.WITH_TORCH_SCATTER
                  and not is_compiling()))):
        # CSR path: each group is the contiguous slice ptr[i]:ptr[i + 1]
        # along `dim`.
        dim = dim + src.dim() if dim < 0 else dim
        size = ([1] * dim) + [-1]
        # Number of elements per group, used to expand per-group results
        # back to one value per element.
        count = ptr[1:] - ptr[:-1]
        # Prepend `dim` singleton axes to `ptr`; presumably `segment` infers
        # the reduction axis from the rank of `ptr` — verify against
        # `segment`'s implementation.
        ptr = ptr.view(size)
        # Subtract the per-group maximum for numerical stability. It is
        # detached because softmax is invariant to a per-group shift, so the
        # gradient w.r.t. `src` is unaffected.
        src_max = segment(src.detach(), ptr, reduce='max')
        src_max = src_max.repeat_interleave(count, dim=dim)
        out = (src - src_max).exp()
        # The small epsilon guards the final division against a zero sum.
        out_sum = segment(out, ptr, reduce='sum') + 1e-16
        out_sum = out_sum.repeat_interleave(count, dim=dim)
    elif index is not None:
        # Index path: groups are defined by the (not necessarily sorted)
        # group id of every element along `dim`.
        N = maybe_num_nodes(index, num_nodes)
        # Same detached max-subtraction trick as in the CSR path.
        src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max')
        out = src - src_max.index_select(dim, index)
        out = out.exp()
        out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + 1e-16
        out_sum = out_sum.index_select(dim, index)
    else:
        raise NotImplementedError("'softmax' requires 'index' to be specified")

    return out / out_sum