from typing import Optional
from torch import Tensor
import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import scatter, segment
from torch_geometric.utils.num_nodes import maybe_num_nodes


def softmax(
    src: Tensor,
    index: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    dim: int = 0,
) -> Tensor:
r"""Computes a sparsely evaluated softmax.
Given a value tensor :attr:`src`, this function first groups the values
along the first dimension based on the indices specified in :attr:`index`,
and then proceeds to compute the softmax individually for each group.
Args:
src (Tensor): The source tensor.
index (LongTensor, optional): The indices of elements for applying the
softmax. (default: :obj:`None`)
ptr (LongTensor, optional): If given, computes the softmax based on
sorted inputs in CSR representation. (default: :obj:`None`)
num_nodes (int, optional): The number of nodes, *i.e.*
:obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
dim (int, optional): The dimension in which to normalize.
(default: :obj:`0`)
:rtype: :class:`Tensor`
Examples:
>>> src = torch.tensor([1., 1., 1., 1.])
>>> index = torch.tensor([0, 0, 1, 2])
>>> ptr = torch.tensor([0, 2, 3, 4])
>>> softmax(src, index)
tensor([0.5000, 0.5000, 1.0000, 1.0000])
>>> softmax(src, None, ptr)
tensor([0.5000, 0.5000, 1.0000, 1.0000])
>>> src = torch.randn(4, 4)
>>> ptr = torch.tensor([0, 4])
>>> softmax(src, index, dim=-1)
tensor([[0.7404, 0.2596, 1.0000, 1.0000],
[0.1702, 0.8298, 1.0000, 1.0000],
[0.7607, 0.2393, 1.0000, 1.0000],
[0.8062, 0.1938, 1.0000, 1.0000]])
"""
    # Fast path: dispatch to the fused CSR softmax kernel in `pyg_lib` when
    # `ptr` is given, the input lives on the CPU, and the kernel is available.
    if (ptr is not None and src.device.type == 'cpu'
            and torch_geometric.typing.WITH_SOFTMAX
            and not is_compiling()):  # pragma: no cover
        return pyg_lib.ops.softmax_csr(src, ptr, dim)

    if (ptr is not None and
            (ptr.dim() == 1 or (ptr.dim() > 1 and index is None) or
            (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()))):
        # CSR path: reduce over the contiguous segments delimited by `ptr`.
        # The per-segment maximum is subtracted before exponentiation for
        # numerical stability, and both it and the per-segment sum are
        # broadcast back to each element via `repeat_interleave`.
        dim = dim + src.dim() if dim < 0 else dim
        size = ([1] * dim) + [-1]
        count = ptr[1:] - ptr[:-1]
        ptr = ptr.view(size)
        src_max = segment(src.detach(), ptr, reduce='max')
        src_max = src_max.repeat_interleave(count, dim=dim)
        out = (src - src_max).exp()
        out_sum = segment(out, ptr, reduce='sum') + 1e-16
        out_sum = out_sum.repeat_interleave(count, dim=dim)
    elif index is not None:
        # COO path: reduce per group via `scatter` on `index`, then gather the
        # per-group maximum and sum back to each element with `index_select`.
        N = maybe_num_nodes(index, num_nodes)
        src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max')
        out = src - src_max.index_select(dim, index)
        out = out.exp()
        out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + 1e-16
        out_sum = out_sum.index_select(dim, index)
    else:
        raise NotImplementedError("'softmax' requires 'index' to be specified")

    return out / out_sum
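

# Illustrative usage sketch (not part of the library API): it exercises the
# grouped-softmax semantics described in the docstring by normalizing per-edge
# scores over the edges that share a destination node, the typical attention
# use case. All tensor values below are made up.
if __name__ == "__main__":
    import torch

    edge_score = torch.randn(6)             # one raw score per edge
    dst = torch.tensor([0, 0, 1, 1, 1, 2])  # destination node of each edge
    alpha = softmax(edge_score, dst)        # normalized within each destination

    # Every group of edges pointing at the same node now sums to one.
    group_sums = torch.zeros(3).scatter_add_(0, dst, alpha)
    assert torch.allclose(group_sums, torch.ones(3))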