Source code for torch_geometric.utils.softmax

from typing import Optional

from torch import Tensor
from torch_scatter import scatter, segment_csr, gather_csr

from .num_nodes import maybe_num_nodes

[docs]def softmax(src: Tensor, index: Tensor, ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None) -> Tensor: r"""Computes a sparsely evaluated softmax. Given a value tensor :attr:`src`, this function first groups the values along the first dimension based on the indices specified in :attr:`index`, and then proceeds to compute the softmax individually for each group. Args: src (Tensor): The source tensor. index (LongTensor): The indices of elements for applying the softmax. ptr (LongTensor, optional): If given, computes the softmax based on sorted inputs in CSR representation. (default: :obj:`None`) num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) :rtype: :class:`Tensor` """ if ptr is None: N = maybe_num_nodes(index, num_nodes) out = src - scatter(src, index, dim=0, dim_size=N, reduce='max')[index] out = out.exp() out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index] return out / (out_sum + 1e-16) else: out = src - gather_csr(segment_csr(src, ptr, reduce='max'), ptr) out = out.exp() out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr) return out / (out_sum + 1e-16)