from typing import List, Optional, Union
import torch
from torch import Tensor
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import cumsum
class QuantileAggregation(Aggregation):
    r"""An aggregation operator that returns the feature-wise :math:`q`-th
    quantile of a set :math:`\mathcal{X}`.

    That is, for every feature :math:`d`, it computes

    .. math::
        {\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases}
            x_{\pi_i,d} & i = q \cdot (n - 1), \\
            f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot (n - 1) < i + 1,\\
        \end{cases}

    where :math:`x_{\pi_0,d} \le \dots \le x_{\pi_i,d} \le \dots \le
    x_{\pi_{n-1},d}` are the :math:`n` values of feature :math:`d` sorted in
    ascending order (indexed from zero), and :math:`f(a, b)` is an
    interpolation function defined by :obj:`interpolation`.

    Args:
        q (float or list): The quantile value(s) :math:`q`. Can be a scalar or
            a list of scalars in the range :math:`[0, 1]`. If more than one
            quantile is passed, the results are concatenated, *e.g.*, an
            input of shape :obj:`[N, F]` aggregated over :obj:`dim=-2` yields
            an output of shape :obj:`[num_groups, len(q) * F]`.
        interpolation (str): Interpolation method applied if the quantile point
            :math:`q \cdot (n - 1)` lies between two values
            :math:`a \le b`. Can be one of the following:

            * :obj:`"lower"`: Returns the one with lowest value.

            * :obj:`"higher"`: Returns the one with highest value.

            * :obj:`"midpoint"`: Returns the average of the two values.

            * :obj:`"nearest"`: Returns the one whose index is nearest to the
              quantile point. Exact ties are rounded half to even (see
              :func:`torch.round`) on the index within the set.

            * :obj:`"linear"`: Returns a linear combination of the two
              elements, defined as
              :math:`f(a, b) = a + (b - a)\cdot(q \cdot (n - 1) - i)`.

            (default: :obj:`"linear"`)
        fill_value (float, optional): The default value in the case no entry is
            found for a given index (default: :obj:`0.0`).
    """
    interpolations = {'linear', 'lower', 'higher', 'nearest', 'midpoint'}

    def __init__(self, q: Union[float, List[float]],
                 interpolation: str = 'linear', fill_value: float = 0.0):
        super().__init__()

        qs = [q] if not isinstance(q, (list, tuple)) else q
        if len(qs) == 0:
            raise ValueError("Provide at least one quantile value for `q`.")
        if not all(0. <= quantile <= 1. for quantile in qs):
            raise ValueError("`q` must be in the range [0, 1].")
        if interpolation not in self.interpolations:
            raise ValueError(f"Invalid interpolation method "
                             f"got ('{interpolation}')")

        self._q = q  # Kept verbatim for `__repr__`.
        # Quantiles are stored as a floating-point column `[num_q, 1]`. The
        # explicit dtype keeps integer inputs such as `q=1` or `q=[0, 1]` from
        # producing an integer buffer, which would break the fractional
        # position arithmetic (e.g. `frac()`) in `forward`:
        self.register_buffer(
            'q', torch.tensor(qs, dtype=torch.get_default_dtype()).view(-1, 1))
        self.interpolation = interpolation
        self.fill_value = fill_value

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        r"""Computes the quantile(s) of :obj:`x` for every group in
        :obj:`index`.

        Args:
            x (torch.Tensor): The source tensor.
            index (torch.Tensor): The group of every element of :obj:`x`
                along :obj:`dim`. Required; groups do not need to be sorted
                or contiguous.
            ptr (torch.Tensor, optional): Unused; group boundaries are always
                derived from :obj:`index`. (default: :obj:`None`)
            dim_size (int, optional): The minimum number of output groups,
                which allows for trailing empty groups.
                (default: :obj:`None`)
            dim (int, optional): The dimension to aggregate over.
                (default: :obj:`-2`)
        """
        dim = x.dim() + dim if dim < 0 else dim

        self.assert_index_present(index)
        assert index is not None  # Required for TorchScript.

        num_q = self.q.numel()
        count = torch.bincount(index, minlength=dim_size or 0)
        num_groups = count.numel()

        # Broadcast shape that lays out a 1-D tensor along `dim`:
        shape = [1] * x.dim()
        shape[dim] = -1

        if x.size(dim) == 0:
            # There is nothing to select from, so every group is empty and
            # gets `fill_value` below (indexing would be out of bounds):
            size = list(x.size())
            size[dim] = num_groups * num_q
            quantile = x.new_zeros(size)
        else:
            quantile = self._group_quantiles(x, index, count, dim, shape)

        # If the number of elements is zero, fill with pre-defined value:
        mask = (count == 0).repeat_interleave(
            num_q, output_size=num_q * num_groups).view(shape)
        out = quantile.masked_fill(mask, self.fill_value)

        if num_q > 1:
            # Split `dim` into `[num_groups, num_q]` and merge the quantile
            # axis into the following dimension (if any):
            size = list(out.shape)
            size = size[:dim] + [size[dim] // num_q, -1] + size[dim + 2:]
            out = out.view(size)

        return out

    def _group_quantiles(self, x: Tensor, index: Tensor, count: Tensor,
                         dim: int, shape: List[int]) -> Tensor:
        r"""Returns the unmasked quantiles of every group as
        :obj:`num_groups * num_q` entries along :obj:`dim` (group-major,
        quantile-minor). Entries of empty groups hold arbitrary values that
        the caller masks out. Requires :obj:`x.size(dim) > 0`."""
        num_q = self.q.numel()

        # Start offset of every group once elements are grouped by index.
        # Empty (dangling) groups may point past the end, so clamp them onto a
        # valid dummy position; their results are masked by the caller:
        start = cumsum(count)[:-1].clamp(max=x.size(dim) - 1)
        start = start.repeat_interleave(num_q)

        # Fractional position of every quantile *within* its group. Empty
        # groups use position 0 rather than `-q`, which would yield negative
        # indices. Keeping this apart from the integer `start` offsets makes
        # rounding independent of where a group sits in the batch and avoids
        # float precision loss for large inputs:
        pos = (count - 1).clamp(min=0).view(-1, 1) * self.q.view(1, -1)
        pos = pos.reshape(-1)

        # Two sorts: the first one on the value, the second (stable) on the
        # indices, which yields contiguous groups sorted by value:
        index = index.view(shape).expand_as(x)
        x, perm = torch.sort(x, dim=dim)
        index = index.take_along_dim(perm, dim=dim)
        index, perm = torch.sort(index, dim=dim, stable=True)
        x = x.take_along_dim(perm, dim=dim)

        # Compute the quantile interpolations:
        if self.interpolation == 'lower':
            return x.index_select(dim, start + pos.floor().long())
        if self.interpolation == 'higher':
            return x.index_select(dim, start + pos.ceil().long())
        if self.interpolation == 'nearest':
            # `round()` rounds half to even on the in-group position:
            return x.index_select(dim, start + pos.round().long())

        l_quant = x.index_select(dim, start + pos.floor().long())
        r_quant = x.index_select(dim, start + pos.ceil().long())
        if self.interpolation == 'linear':
            return l_quant + (r_quant - l_quant) * pos.frac().view(shape)
        return 0.5 * l_quant + 0.5 * r_quant  # 'midpoint'

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(q={self._q})')