# Source code for torch_geometric.nn.pool.topk_pool

from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK
from torch_geometric.typing import OptTensor

class TopKPooling(torch.nn.Module):
    r""":math:`\mathrm{top}_k` pooling operator from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_, `"Towards Sparse
    Hierarchical Graph Classifiers" <https://arxiv.org/abs/1811.01287>`_
    and `"Understanding Attention and Generalization in Graph Neural
    Networks" <https://arxiv.org/abs/1905.02850>`_ papers.

    If :obj:`min_score` :math:`\tilde{\alpha}` is :obj:`None`, computes:

    .. math::
        \mathbf{y} &= \sigma \left( \frac{\mathbf{X}\mathbf{p}}{\|
        \mathbf{p} \|} \right)

        \mathbf{i} &= \mathrm{top}_k(\mathbf{y})

        \mathbf{X}^{\prime} &= (\mathbf{X} \odot
        \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}

        \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}

    If :obj:`min_score` :math:`\tilde{\alpha}` is a value in :obj:`[0, 1]`,
    computes:

    .. math::
        \mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})

        \mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}

        \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}

        \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},

    where nodes are dropped based on a learnable projection score
    :math:`\mathbf{p}`.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float or int): The graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`, or the value
            of :math:`k` itself, depending on whether the type of :obj:`ratio`
            is :obj:`float` or :obj:`int`.
            This value is ignored if :obj:`min_score` is not :obj:`None`.
            (default: :obj:`0.5`)
        min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
            which is used to compute indices of pooled nodes
            :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
            When this value is not :obj:`None`, the :obj:`ratio` argument is
            ignored. (default: :obj:`None`)
        multiplier (float, optional): Coefficient by which features gets
            multiplied after pooling. This can be useful for large graphs and
            when :obj:`min_score` is used. (default: :obj:`1`)
        nonlinearity (str or callable, optional): The non-linearity
            :math:`\sigma`. (default: :obj:`"tanh"`)
    """
    def __init__(
        self,
        in_channels: int,
        ratio: Union[int, float] = 0.5,
        min_score: Optional[float] = None,
        multiplier: float = 1.,
        nonlinearity: Union[str, Callable] = 'tanh',
    ):
        super().__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier

        # Node selection (scoring + top-k / thresholding) is delegated to
        # `SelectTopK`; edge filtering to the selected node set is delegated
        # to `FilterEdges`.
        self.select = SelectTopK(in_channels, ratio, min_score, nonlinearity)
        self.connect = FilterEdges()

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.select.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Optional[Tensor] = None,
        batch: Optional[Tensor] = None,
        attn: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The node feature matrix.
            edge_index (torch.Tensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example. (default: :obj:`None`)
            attn (torch.Tensor, optional): Optional node-level matrix to use
                for computing attention scores instead of using the node
                feature matrix :obj:`x`. (default: :obj:`None`)
        """
        if batch is None:
            # Assume a single example if no batch assignment is given.
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        select_out = self.select(attn, batch)

        perm = select_out.node_index
        score = select_out.weight
        assert score is not None

        # Gate the kept node features by their (non-linearized) scores.
        x = x[perm] * score.view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        connect_out = self.connect(select_out, edge_index, edge_attr, batch)

        return (x, connect_out.edge_index, connect_out.edge_attr,
                connect_out.batch, perm, score)

    def __repr__(self) -> str:
        if self.min_score is None:
            ratio = f'ratio={self.ratio}'
        else:
            ratio = f'min_score={self.min_score}'

        return (f'{self.__class__.__name__}({self.in_channels}, {ratio}, '
                f'multiplier={self.multiplier})')