from typing import List
import numpy as np
import torch
from torch import Tensor
import torch_geometric.typing
def lexsort(
    keys: List[Tensor],
    dim: int = -1,
    descending: bool = False,
) -> Tensor:
    r"""Performs an indirect stable sort using a sequence of keys.

    Given multiple sorting keys, returns an array of integer indices that
    describe their sort order.
    The last key in the sequence is used for the primary sort order, the
    second-to-last key for the secondary sort order, and so on.

    Args:
        keys ([torch.Tensor]): The :math:`k` different columns to be sorted.
            The last key is the primary sort key.
        dim (int, optional): The dimension to sort along. (default: :obj:`-1`)
        descending (bool, optional): Controls the sorting order (ascending or
            descending). (default: :obj:`False`)

    Returns:
        torch.Tensor: Indices along :obj:`dim` that sort the keys, with the
        same shape as the keys and on the device of the first key. Positions
        whose keys are all equal keep their original relative order.

    Raises:
        ValueError: If :obj:`keys` is empty.
    """
    if len(keys) < 1:
        raise ValueError("'lexsort' requires at least one key")

    if not torch_geometric.typing.WITH_PT113:
        # Older PyTorch: fall back to NumPy's `lexsort`, which only sorts in
        # ascending order. Descending order is emulated by mapping each key
        # through an order-reversing bijection. `neg()` is only safe for
        # floating-point/complex keys: for integers it overflows on the
        # minimum signed value and wraps around for unsigned dtypes, and it
        # is rejected for `bool`. `bitwise_not()` reverses the order of
        # integer keys without overflow (`~x == -x - 1` for signed,
        # `MAX - x` for unsigned) and is a logical NOT for `bool`.
        if descending:
            keys = [
                k.neg() if k.is_floating_point() or k.is_complex()
                else k.bitwise_not() for k in keys
            ]
        out = np.lexsort([k.detach().cpu().numpy() for k in keys], axis=dim)
        return torch.from_numpy(out).to(keys[0].device)

    # Least-significant-key-first strategy: sort by the first key, then
    # re-sort stably by each following key. Stability keeps the order set by
    # earlier (less significant) keys among ties, so the last key ends up
    # as the primary sort key.
    out = keys[0].argsort(dim=dim, descending=descending, stable=True)
    for k in keys[1:]:
        index = k.gather(dim, out)
        index = index.argsort(dim=dim, descending=descending, stable=True)
        out = out.gather(dim, index)
    return out