# Source code for torch_geometric.nn.pool.topk_pool

from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.nn.pool.select import SelectTopK
from torch_geometric.utils.num_nodes import maybe_num_nodes

def filter_adj(
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    perm: Tensor,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Restricts the graph connectivity to the nodes selected by :obj:`perm`.

    Each node listed in :obj:`perm` is relabeled to its position within
    :obj:`perm`; every other node is discarded. An edge is kept only when
    both of its endpoints are kept. Kept edges stay in their input order.

    Args:
        edge_index (torch.Tensor): The edge indices; row ``0`` and row ``1``
            hold the two endpoints of each edge.
        edge_attr (torch.Tensor, optional): Edge features, filtered along
            dim ``0`` with the same mask as the edges.
        perm (torch.Tensor): Indices of the nodes to keep.
        num_nodes (int, optional): Number of nodes in the input graph.
            When :obj:`None`, it is derived via :func:`maybe_num_nodes`.

    Returns:
        (torch.Tensor, torch.Tensor or None): The relabeled, filtered edge
        indices and the filtered edge features (:obj:`None` if
        :obj:`edge_attr` was :obj:`None`).
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # Old node id -> new node id; dropped nodes map to -1.
    relabel = perm.new_full((num_nodes, ), -1)
    relabel[perm] = torch.arange(perm.size(0), dtype=torch.long,
                                 device=perm.device)

    first = relabel[edge_index[0]]
    second = relabel[edge_index[1]]
    # An edge survives only if neither endpoint was dropped.
    keep = (first >= 0) & (second >= 0)

    filtered_attr = None if edge_attr is None else edge_attr[keep]
    filtered_index = torch.stack([first[keep], second[keep]], dim=0)
    return filtered_index, filtered_attr

class TopKPooling(torch.nn.Module):
    r""":math:`\mathrm{top}_k` pooling operator from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_, `"Towards Sparse Hierarchical Graph
    Classifiers" <https://arxiv.org/abs/1811.01287>`_ and `"Understanding
    Attention and Generalization in Graph Neural Networks"
    <https://arxiv.org/abs/1905.02850>`_ papers.

    If :obj:`min_score` :math:`\tilde{\alpha}` is :obj:`None`, computes:

    .. math::
        \mathbf{y} &= \sigma \left( \frac{\mathbf{X}\mathbf{p}}{\|
        \mathbf{p} \|} \right)

        \mathbf{i} &= \mathrm{top}_k(\mathbf{y})

        \mathbf{X}^{\prime} &= (\mathbf{X} \odot
        \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}

        \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}

    If :obj:`min_score` :math:`\tilde{\alpha}` is a value in :obj:`[0, 1]`,
    computes:

    .. math::
        \mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})

        \mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}

        \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}

        \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},

    where nodes are dropped based on a learnable projection score
    :math:`\mathbf{p}`.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float or int): The graph pooling ratio, which is used to
            compute :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`, or the
            value of :math:`k` itself, depending on whether the type of
            :obj:`ratio` is :obj:`float` or :obj:`int`.
            This value is ignored if :obj:`min_score` is not :obj:`None`.
            (default: :obj:`0.5`)
        min_score (float, optional): Minimal node score
            :math:`\tilde{\alpha}` which is used to compute indices of pooled
            nodes :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
            When this value is not :obj:`None`, the :obj:`ratio` argument is
            ignored. (default: :obj:`None`)
        multiplier (float, optional): Coefficient by which features get
            multiplied after pooling. This can be useful for large graphs and
            when :obj:`min_score` is used. (default: :obj:`1`)
        nonlinearity (str or callable, optional): The non-linearity
            :math:`\sigma`. (default: :obj:`"tanh"`)
    """
    def __init__(
        self,
        in_channels: int,
        ratio: Union[int, float] = 0.5,
        min_score: Optional[float] = None,
        multiplier: float = 1.,
        nonlinearity: Union[str, Callable] = 'tanh',
    ):
        super().__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier

        # Node scoring and selection are delegated to SelectTopK; this module
        # only applies the resulting scores and filters the connectivity.
        self.select = SelectTopK(in_channels, ratio, min_score, nonlinearity)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.select.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Optional[Tensor] = None,
        batch: Optional[Tensor] = None,
        attn: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
        r"""
        Args:
            x (torch.Tensor): The node feature matrix.
            edge_index (torch.Tensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example. (default: :obj:`None`)
            attn (torch.Tensor, optional): Optional node-level matrix to use
                for computing attention scores instead of using the node
                feature matrix :obj:`x`. (default: :obj:`None`)

        Returns:
            A tuple ``(x, edge_index, edge_attr, batch, perm, score)``. Here
            ``x`` holds the kept nodes' features scaled by their scores (and
            by :obj:`multiplier` if it is not ``1``). ``edge_index`` and
            ``edge_attr`` are restricted to the kept nodes, and ``batch``
            assigns each kept node to its example. ``perm`` holds the indices
            of the kept nodes, and ``score`` holds their selection scores.
        """
        if batch is None:
            # Treat the whole input as a single graph.
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        select_output = self.select(attn, batch)

        perm = select_output.node_index
        score = select_output.weight
        assert score is not None  # narrows the Optional for type checkers

        # Gate the kept nodes' features by their (column-broadcast) scores.
        x = x[perm] * score.view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                           num_nodes=select_output.num_nodes)

        return x, edge_index, edge_attr, batch, perm, score

    def __repr__(self) -> str:
        if self.min_score is None:
            ratio = f'ratio={self.ratio}'
        else:
            ratio = f'min_score={self.min_score}'

        return (f'{self.__class__.__name__}({self.in_channels}, {ratio}, '
                f'multiplier={self.multiplier})')