class QuantileAggregation(q: Union[float, List[float]], interpolation: str = 'linear', fill_value: float = 0.0)

Bases: Aggregation

An aggregation operator that returns the feature-wise \(q\)-th quantile of a set \(\mathcal{X}\). That is, for every feature \(d\), it computes

\[\begin{split}{\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases} x_{\pi_i,d} & i = q \cdot n, \\ f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot n < i + 1,\\ \end{cases}\end{split}\]

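where \(n = |\mathcal{X}|\) is the number of elements, the permutation \(\pi\) sorts the \(d\)-th feature in ascending order, i.e., \(x_{\pi_1,d} \le \dots \le x_{\pi_i,d} \le \dots \le x_{\pi_n,d}\), and \(f(a, b)\) is the interpolation function selected by the interpolation argument.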
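The feature-wise definition above can be sketched in plain Python for a single set \(\mathcal{X}\) given as a list of rows. Each feature is sorted on its own, so the sorting permutation differs per feature. One assumption to flag: the formula writes the quantile point as \(q\cdot n\) over 1-based positions, while this sketch uses the 0-based convention \(h = q\cdot(n-1)\) of numpy.quantile and torch.quantile, so that \(q = 0\) and \(q = 1\) map to the minimum and maximum. The helper name feature_quantile is hypothetical and only the "linear" mode is shown.

```python
from typing import List, Sequence


def feature_quantile(rows: Sequence[Sequence[float]], q: float) -> List[float]:
    """Linearly interpolated q-th quantile of every feature (column) of rows.

    Assumption: the quantile point is h = q * (n - 1) on a 0-based sorted
    column (numpy.quantile / torch.quantile convention), not PyG internals.
    """
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must lie in [0, 1]")
    n = len(rows)
    num_features = len(rows[0])
    out = []
    for d in range(num_features):
        # Sort feature d independently: the permutation pi depends on d.
        col = sorted(row[d] for row in rows)
        h = q * (n - 1)
        i = int(h)  # floor, since h >= 0
        if i == h:
            out.append(col[i])  # quantile point hits an element exactly
        else:
            a, b = col[i], col[i + 1]
            out.append(a + (b - a) * (h - i))  # linear interpolation
    return out


X = [[1.0, 40.0],
     [4.0, 10.0],
     [2.0, 30.0],
     [3.0, 20.0]]
print(feature_quantile(X, 0.5))  # [2.5, 25.0]
```

Note that the median of the first feature (2.5) and of the second (25.0) come from different rows, which is what "feature-wise" means here.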

  • q (float or list) – The quantile value(s) \(q\). Can be a scalar or a list of scalars in the range \([0, 1]\). If more than one quantile is passed, the results are concatenated.

  • interpolation (str) –

    Interpolation method applied if the quantile point \(q\cdot n\) lies between two values \(a \le b\). Can be one of the following:

    • "lower": Returns the lower of the two values, \(a\).

    • "higher": Returns the higher of the two values, \(b\).

    • "midpoint": Returns the average of the two values.

    • "nearest": Returns the value whose index is nearest to the quantile point.

    • "linear": Returns a linear combination of the two elements, defined as \(f(a, b) = a + (b - a)\cdot(q\cdot n - i)\).

    (default: "linear")

  • fill_value (float, optional) – The value written to the output when no entry is found for a given index, i.e., when the set is empty (default: 0.0).
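The interpolation modes, list-valued q, and fill_value described above can be illustrated on a single feature column. This is a sketch under stated assumptions: quantile_1d and quantiles are hypothetical helpers, not part of PyG, and the quantile point again follows the 0-based \(h = q\cdot(n-1)\) convention of torch.quantile. When \(h\) is an integer, \(a = b\) and every mode returns that element.

```python
import math
from typing import List, Sequence, Union


def quantile_1d(values: Sequence[float], q: float, interpolation: str = "linear",
                fill_value: float = 0.0) -> float:
    """q-th quantile of one feature column under the given interpolation mode."""
    if not values:
        return fill_value  # empty set: no entry found for this index
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must lie in [0, 1]")
    col = sorted(values)
    h = q * (len(col) - 1)  # assumed 0-based quantile point
    lo, hi = math.floor(h), math.ceil(h)
    a, b = col[lo], col[hi]  # a <= b; equal when h is an integer
    if interpolation == "lower":
        return a
    if interpolation == "higher":
        return b
    if interpolation == "midpoint":
        return (a + b) / 2
    if interpolation == "nearest":
        return col[round(h)]  # Python's round() breaks .5 ties toward the even index
    if interpolation == "linear":
        return a + (b - a) * (h - lo)
    raise ValueError(f"unknown interpolation: {interpolation!r}")


def quantiles(values: Sequence[float], q: Union[float, Sequence[float]],
              interpolation: str = "linear", fill_value: float = 0.0) -> List[float]:
    """Accept a scalar or a list of q; results for a list are concatenated."""
    qs = [q] if isinstance(q, (int, float)) else list(q)
    return [quantile_1d(values, qi, interpolation, fill_value) for qi in qs]


feature = [10.0, 40.0, 20.0, 30.0]
modes = ["lower", "higher", "midpoint", "nearest", "linear"]
print({m: quantile_1d(feature, 0.5, m) for m in modes})
# {'lower': 20.0, 'higher': 30.0, 'midpoint': 25.0, 'nearest': 30.0, 'linear': 25.0}
print(quantiles(feature, [0.0, 0.5, 1.0]))  # [10.0, 25.0, 40.0]
print(quantiles([], 0.5, fill_value=-1.0))  # [-1.0]
```

With four values and \(q = 0.5\), the quantile point falls halfway between 20.0 and 30.0, which is why only "lower", "higher", and "nearest" pick an actual element.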

forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor
  • x (torch.Tensor) – The source tensor.

  • index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

  • ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

  • dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

  • dim (int, optional) – The dimension in which to aggregate. (default: -2)
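The roles of index, ptr, and dim_size can be mimicked on plain lists. This sketch assumes a 2-D input aggregated along its rows (the default dim=-2), linear interpolation only, and that when only index is given the number of output groups defaults to max(index) + 1, a common scatter convention. grouped_quantile and ptr_to_index are hypothetical helpers, not PyG's tensor implementation.

```python
import math
from typing import List, Optional, Sequence


def _linear_quantile(col: List[float], q: float) -> float:
    col = sorted(col)
    h = q * (len(col) - 1)  # assumed 0-based quantile point
    lo, hi = math.floor(h), math.ceil(h)
    return col[lo] + (col[hi] - col[lo]) * (h - lo)


def ptr_to_index(ptr: Sequence[int]) -> List[int]:
    """CSR offsets [0, 2, 2, 5] -> group ids [0, 0, 2, 2, 2]."""
    index: List[int] = []
    for group, (start, end) in enumerate(zip(ptr[:-1], ptr[1:])):
        index.extend([group] * (end - start))
    return index


def grouped_quantile(
    x: List[List[float]],
    q: float,
    index: Optional[List[int]] = None,
    ptr: Optional[Sequence[int]] = None,
    dim_size: Optional[int] = None,
    fill_value: float = 0.0,
) -> List[List[float]]:
    """Feature-wise q-th quantile of the rows of x that share a group id."""
    if index is None and ptr is None:
        raise ValueError("one of index or ptr must be defined")
    if index is None:
        index = ptr_to_index(ptr)
        dim_size = len(ptr) - 1 if dim_size is None else dim_size
    if dim_size is None:
        # Assumption: infer the output size from the largest group id.
        dim_size = max(index) + 1 if index else 0
    num_features = len(x[0]) if x else 0
    groups: List[List[List[float]]] = [[] for _ in range(dim_size)]
    for row, g in zip(x, index):
        groups[g].append(row)
    out = []
    for rows in groups:
        if not rows:  # empty group: fill every feature with fill_value
            out.append([fill_value] * num_features)
        else:
            out.append([_linear_quantile([r[d] for r in rows], q)
                        for d in range(num_features)])
    return out


x = [[1.0, 10.0], [3.0, 30.0], [5.0, 0.0], [7.0, 20.0], [9.0, 40.0]]
print(grouped_quantile(x, 0.5, index=[0, 0, 2, 2, 2], dim_size=3))
# [[2.0, 20.0], [0.0, 0.0], [7.0, 20.0]]
print(grouped_quantile(x, 0.5, ptr=[0, 2, 2, 5]))  # same groups via CSR offsets
```

Group 1 receives no rows, so its output row is filled with the default fill_value of 0.0; the ptr variant describes the same grouping because the rows are already sorted by group.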