# torch_geometric.nn.pool.TopKPooling

class TopKPooling(in_channels: int, ratio: Union[int, float] = 0.5, min_score: Optional[float] = None, multiplier: float = 1.0, nonlinearity: Union[str, Callable] = 'tanh')

Bases: Module

$$\mathrm{top}_k$$ pooling operator from the “Graph U-Nets”, “Towards Sparse Hierarchical Graph Classifiers” and “Understanding Attention and Generalization in Graph Neural Networks” papers.

If min_score $$\tilde{\alpha}$$ is None, computes:

\begin{align}\begin{aligned}
\mathbf{y} &= \sigma \left( \frac{\mathbf{X}\mathbf{p}}{\| \mathbf{p} \|} \right)\\
\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}
\end{aligned}\end{align}
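The top-k mode can be traced with a small, dependency-free sketch for a single graph. It is not the library's implementation: plain Python lists stand in for tensors, the projection vector `p` is passed in rather than learned, σ defaults to `math.tanh`, and each kept row is multiplied by its score $$\mathbf{y}$$ with σ applied only once, which is an assumption of this sketch. The function name `topk_pool_sketch` is hypothetical.

```python
import math


def topk_pool_sketch(x, edge_index, p, ratio=0.5, nonlinearity=math.tanh):
    """Hypothetical single-graph sketch of the top-k mode (min_score is None).

    x          -- N feature rows, each a list of F floats
    edge_index -- list of (src, dst) node-index pairs
    p          -- projection vector of length F (learnable in the real module)
    """
    n = len(x)
    p_norm = math.sqrt(sum(v * v for v in p))
    # y = sigma(X p / ||p||): one scalar score per node
    y = [nonlinearity(sum(a * b for a, b in zip(row, p)) / p_norm) for row in x]
    # k = ceil(ratio * N) for a float ratio
    k = math.ceil(ratio * n)
    # i = top_k(y): indices of the k highest scores, best first
    perm = sorted(range(n), key=lambda j: y[j], reverse=True)[:k]
    # X' = (X * y)[i]: keep the selected rows, each gated by its own score
    x_out = [[v * y[j] for v in x[j]] for j in perm]
    # A' = A[i, i]: keep edges between kept nodes, relabeled to 0..k-1
    new_id = {old: new for new, old in enumerate(perm)}
    edges_out = [(new_id[src], new_id[dst]) for src, dst in edge_index
                 if src in new_id and dst in new_id]
    return x_out, edges_out, perm


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
edge_index = [(0, 1), (0, 2), (2, 0), (2, 3)]
x_out, edges_out, perm = topk_pool_sketch(x, edge_index, p=[1.0, 0.5])
print(perm)       # [2, 0]
print(edges_out)  # [(1, 0), (0, 1)]
```

For an edge list, $$\mathbf{A}_{\mathbf{i},\mathbf{i}}$$ amounts to keeping only the edges whose two endpoints both survive and renumbering them to the positions of the kept nodes.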

If min_score $$\tilde{\alpha}$$ is a value in [0, 1], computes:

\begin{align}\begin{aligned}
\mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})\\
\mathbf{i} &= \{\, i : \mathbf{y}_i > \tilde{\alpha} \,\}\\
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},
\end{aligned}\end{align}

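where $$\mathbf{p}$$ is a learnable projection vector and nodes are dropped based on their projection scores $$\mathbf{y}$$. For a mini-batch of graphs, the node selection (and, in the min_score case, the softmax) is computed separately for each graph.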
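A matching sketch of the min_score mode, under the same assumptions as a plain-list, single-graph illustration (the function name `min_score_pool_sketch` is hypothetical). It transcribes the formulas literally, so if no score exceeds `min_score` it returns an empty graph; it makes no claim about how the library itself handles that case.

```python
import math


def min_score_pool_sketch(x, edge_index, p, min_score):
    """Hypothetical single-graph sketch of the min_score mode."""
    # raw projection scores X p (no division by ||p|| in this mode)
    raw = [sum(a * b for a, b in zip(row, p)) for row in x]
    # y = softmax(X p), made numerically stable by subtracting the maximum
    m = max(raw)
    exps = [math.exp(r - m) for r in raw]
    total = sum(exps)
    y = [e / total for e in exps]
    # i = { i : y_i > min_score }; kept nodes stay in their original order
    perm = [j for j, score in enumerate(y) if score > min_score]
    # X' = (X * y)[i]
    x_out = [[v * y[j] for v in x[j]] for j in perm]
    # A' = A[i, i], with kept nodes relabeled to 0..len(perm)-1
    new_id = {old: new for new, old in enumerate(perm)}
    edges_out = [(new_id[src], new_id[dst]) for src, dst in edge_index
                 if src in new_id and dst in new_id]
    return x_out, edges_out, perm, y


x = [[2.0], [1.0], [0.0]]
edge_index = [(0, 1), (1, 2), (2, 0)]
x_out, edges_out, perm, y = min_score_pool_sketch(x, edge_index, p=[1.0], min_score=0.2)
print(perm)       # [0, 1]
print(edges_out)  # [(0, 1)]
```

Unlike the top-k mode, the number of kept nodes here is not fixed in advance: it depends on how many softmax scores clear the threshold $$\tilde{\alpha}$$.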

Parameters:
• in_channels (int) – Size of each input sample.

• ratio (float or int) – The graph pooling ratio, which is used to compute $$k = \lceil \mathrm{ratio} \cdot N \rceil$$, or the value of $$k$$ itself, depending on whether the type of ratio is float or int. This value is ignored if min_score is not None. (default: 0.5)

• min_score (float, optional) – Minimal node score $$\tilde{\alpha}$$ which is used to compute the indices of pooled nodes $$\mathbf{i} = \{ i : \mathbf{y}_i > \tilde{\alpha} \}$$. When this value is not None, the ratio argument is ignored. (default: None)

• multiplier (float, optional) – Coefficient by which the features get multiplied after pooling. This can be useful for large graphs when min_score is used: the softmax scores of a graph's N nodes sum to one, so the gated features shrink as N grows. (default: 1)

• nonlinearity (str or callable, optional) – The non-linearity $$\sigma$$. (default: "tanh")
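How `ratio` turns into a node count, and why `multiplier` helps in the min_score case, can be sketched as follows. Both helpers (`num_kept`, `apply_multiplier`) are hypothetical; clamping an integer ratio to the number of nodes and choosing `multiplier = N` are assumptions made only for illustration.

```python
import math


def num_kept(ratio, n):
    """Hypothetical helper: how many of a graph's n nodes top-k mode keeps."""
    if isinstance(ratio, int):
        # An int ratio is k itself; clamping to n is an assumption of this sketch.
        return min(ratio, n)
    # A float ratio gives k = ceil(ratio * n).
    return math.ceil(ratio * n)


def apply_multiplier(x_pooled, multiplier=1.0):
    """Hypothetical helper: scale every pooled feature by `multiplier`."""
    if multiplier == 1.0:
        return x_pooled
    return [[multiplier * v for v in row] for row in x_pooled]


print(num_kept(0.5, 7))  # 4, since ceil(3.5) = 4
print(num_kept(3, 7))    # 3
print(num_kept(10, 7))   # 7 under the clamping assumption

# With min_score, softmax scores over n nodes sum to 1, so with equal raw
# scores each node's features are gated by 1/n and shrink as n grows.
n = 1000
gated = [[1.0 / n]]  # one feature of value 1.0 gated by a score of 1/n
restored = apply_multiplier(gated, multiplier=float(n))
print(restored)      # [[1.0]]
```

Because a float and an int with the same numeric value (for example `1.0` and `1`) select different branches, the type of `ratio` matters, not just its value.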

reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, attn: Optional[Tensor] = None)
Parameters: