torch_geometric.nn.aggr.QuantileAggregation
- class QuantileAggregation(q: Union[float, List[float]], interpolation: str = 'linear', fill_value: float = 0.0)
Bases: Aggregation
An aggregation operator that returns the feature-wise \(q\)-th quantile of a set \(\mathcal{X}\).
That is, for every feature \(d\), it computes
\[{\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases} x_{\pi_i,d} & i = q \cdot n, \\ f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot n < i + 1, \end{cases}\]
where \(n = |\mathcal{X}|\), \(x_{\pi_1,d} \le \dots \le x_{\pi_i,d} \le \dots \le x_{\pi_n,d}\) are the values of feature \(d\) sorted in ascending order, and \(f(a, b)\) is the interpolation function selected by interpolation.
- Parameters:
q (float or list) – The quantile value(s) \(q\). Can be a scalar or a list of scalars in the range \([0, 1]\). If more than one quantile is passed, the results are concatenated.
interpolation (str) – The interpolation method applied if the quantile point \(q\cdot n\) lies between two values \(a \le b\). Can be one of the following:
  "lower": Returns the lower value, \(a\).
  "higher": Returns the higher value, \(b\).
  "midpoint": Returns the average of the two values, \((a + b) / 2\).
  "nearest": Returns the value whose index is nearest to the quantile point.
  "linear": Returns a linear combination of the two values, defined as \(f(a, b) = a + (b - a)\cdot(q\cdot n - i)\).
(default: "linear")
fill_value (float, optional) – The value returned for an output index to which no input entry is assigned, i.e., an empty group. (default: 0.0)
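The five interpolation rules can be illustrated on a single group of values. The following is a minimal pure-Python sketch, not the library's implementation: the helper name `interpolated_quantile` is hypothetical, and it assumes the common convention (used by `torch.quantile` and by NumPy's default method) of placing the quantile point at the fractional position \(q\cdot(n-1)\) of the 0-indexed sorted values.

```python
import math
from typing import Sequence


def interpolated_quantile(values: Sequence[float], q: float,
                          interpolation: str = "linear") -> float:
    """Hypothetical helper: q-th quantile of `values` using one of five rules.

    Assumption: the quantile point is q * (n - 1) on the 0-indexed sorted
    values, as in torch.quantile and NumPy's default method.
    """
    if not values:
        raise ValueError("values must be non-empty")
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must lie in [0, 1]")

    xs = sorted(values)
    point = q * (len(xs) - 1)  # fractional position of the quantile point
    i = math.floor(point)      # index of the lower neighbour a
    j = math.ceil(point)       # index of the upper neighbour b
    a, b = xs[i], xs[j]

    if interpolation == "lower":
        return a
    if interpolation == "higher":
        return b
    if interpolation == "midpoint":
        return (a + b) / 2
    if interpolation == "nearest":
        # Python's round() rounds exact halves to the nearest even integer.
        return xs[round(point)]
    if interpolation == "linear":
        return a + (b - a) * (point - i)
    raise ValueError(f"unknown interpolation: {interpolation!r}")


for mode in ["lower", "higher", "midpoint", "nearest", "linear"]:
    print(mode, interpolated_quantile([4.0, 1.0, 3.0, 2.0], q=0.5, interpolation=mode))
```

With \(n = 4\) and \(q = 0.5\), the quantile point in this convention is 1.5, so \(a = 2.0\) and \(b = 3.0\): "lower" gives 2.0, "higher" gives 3.0, "midpoint" and "linear" both give 2.5, and "nearest" gives 3.0 because `round(1.5)` is 2.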
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor
Forward pass.
- Parameters:
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The group index of each element of x along dimension dim, i.e., the output position that element is aggregated into. One of index or ptr must be defined. (default: None)
ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)
dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)
dim (int, optional) – The dimension in which to aggregate. (default: -2)
max_num_elements (int, optional) – The maximum number of elements within a single aggregation group. (default: None)
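To make the roles of index, dim_size, fill_value and a list-valued q concrete, here is a pure-Python reference sketch of the aggregation over the rows of a 2-D input (the default dim=-2). It is an illustration under stated assumptions, not the vectorized PyG implementation: the helper `quantile_aggregate_reference` is hypothetical, it only implements the "linear" rule with the \(q\cdot(n-1)\) position convention, and it assumes that for a list of quantiles the per-quantile results are concatenated along the feature dimension in the order given in q.

```python
import math
from typing import List, Optional, Sequence, Union


def quantile_aggregate_reference(
    x: Sequence[Sequence[float]],
    index: Sequence[int],
    q: Union[float, Sequence[float]],
    dim_size: Optional[int] = None,
    fill_value: float = 0.0,
) -> List[List[float]]:
    """Hypothetical reference: row i of `x` belongs to group index[i].

    Returns one output row per group. Each row holds len(q) * num_features
    values (assumed layout: all features for q[0], then all for q[1], ...).
    """
    qs = [q] if isinstance(q, (int, float)) else list(q)
    num_features = len(x[0]) if x else 0
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0

    # Collect the rows of x that belong to each group.
    groups: List[List[Sequence[float]]] = [[] for _ in range(dim_size)]
    for row, g in zip(x, index):
        groups[g].append(row)

    out: List[List[float]] = []
    for rows in groups:
        out_row: List[float] = []
        for qv in qs:
            for d in range(num_features):
                if not rows:
                    # Empty group: no entry exists, so fill_value is used.
                    out_row.append(fill_value)
                    continue
                xs = sorted(r[d] for r in rows)  # feature d, ascending
                point = qv * (len(xs) - 1)
                i = math.floor(point)
                a, b = xs[i], xs[math.ceil(point)]
                out_row.append(a + (b - a) * (point - i))  # "linear" rule
        out.append(out_row)
    return out


x = [[1.0, 10.0], [3.0, 30.0], [2.0, 20.0], [5.0, 50.0]]
index = [0, 0, 0, 2]  # no row is assigned to group 1
print(quantile_aggregate_reference(x, index, q=[0.0, 0.5], dim_size=3))
# → [[1.0, 10.0, 2.0, 20.0], [0.0, 0.0, 0.0, 0.0], [5.0, 50.0, 5.0, 50.0]]
```

Group 1 receives no rows, so its output row consists of fill_value only; with two quantiles and two features, every output row has four entries.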