from typing import Optional
from torch import Tensor
from torch_scatter import scatter, segment_csr, gather_csr
from .num_nodes import maybe_num_nodes
[docs]def softmax(src: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
num_nodes: Optional[int] = None) -> Tensor:
r"""Computes a sparsely evaluated softmax.
Given a value tensor :attr:`src`, this function first groups the values
along the first dimension based on the indices specified in :attr:`index`,
and then proceeds to compute the softmax individually for each group.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements for applying the softmax.
ptr (LongTensor, optional): If given, computes the softmax based on
sorted inputs in CSR representation. (default: :obj:`None`)
num_nodes (int, optional): The number of nodes, *i.e.*
:obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
:rtype: :class:`Tensor`
"""
if ptr is None:
N = maybe_num_nodes(index, num_nodes)
out = src - scatter(src, index, dim=0, dim_size=N, reduce='max')[index]
out = out.exp()
out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
return out / (out_sum + 1e-16)
else:
out = src - gather_csr(segment_csr(src, ptr, reduce='max'), ptr)
out = out.exp()
out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
return out / (out_sum + 1e-16)