# Source code for torch_geometric.nn.pool.pan_pool

from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK
from torch_geometric.typing import OptTensor, SparseTensor
from torch_geometric.utils import scatter


class PANPooling(torch.nn.Module):
    r"""The path integral based pooling operator from the
    `"Path Integral Based Convolution and Pooling for Graph Neural Networks"
    <https://arxiv.org/abs/2006.16811>`_ paper.

    PAN pooling performs top-:math:`k` pooling where global node importance
    is measured based on node features and the MET matrix:

    .. math::
        {\rm score} = \beta_1 \mathbf{X} \cdot \mathbf{p} + \beta_2
        {\rm deg}(\mathbf{M})

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
            This value is ignored if :obj:`min_score` is not :obj:`None`.
            (default: :obj:`0.5`)
        min_score (float, optional): Minimal node score
            :math:`\tilde{\alpha}` which is used to compute indices of pooled
            nodes :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
            When this value is not :obj:`None`, the :obj:`ratio` argument is
            ignored. (default: :obj:`None`)
        multiplier (float, optional): Coefficient by which features get
            multiplied after pooling. This can be useful for large graphs and
            when :obj:`min_score` is used. (default: :obj:`1.0`)
        nonlinearity (str or callable, optional): The non-linearity to use.
            (default: :obj:`"tanh"`)
    """
    def __init__(
        self,
        in_channels: int,
        ratio: float = 0.5,
        min_score: Optional[float] = None,
        multiplier: float = 1.0,
        nonlinearity: Union[str, Callable] = 'tanh',
    ):
        super().__init__()

        # Hyper-parameters, kept as attributes (``__repr__`` reads them).
        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier

        # ``p`` projects node features to a scalar; ``beta`` holds the two
        # mixing coefficients (beta_1 for the feature term, beta_2 for the
        # MET-degree term). Registration order (p, then beta) is kept so the
        # parameter/state_dict ordering is unchanged.
        self.p = Parameter(torch.empty(in_channels))
        self.beta = Parameter(torch.empty(2))

        # Top-k node selection on a single score channel, and the edge
        # reconnection step applied after selection.
        self.select = SelectTopK(1, ratio, min_score, nonlinearity)
        self.connect = FilterEdges()

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        # ``p`` starts at all-ones and both ``beta`` coefficients at 0.5.
        for param, init_value in ((self.p, 1), (self.beta, 0.5)):
            param.data.fill_(init_value)
        self.select.reset_parameters()

    def forward(
        self,
        x: Tensor,
        M: SparseTensor,
        batch: OptTensor = None,
    ) -> Tuple[Tensor, Tensor, Tensor, OptTensor, Tensor, Tensor]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The node feature matrix.
            M (SparseTensor): The MET matrix :math:`\mathbf{M}`. It must
                carry stored values (edge weights).
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example. (default: :obj:`None`)

        Returns:
            A 6-tuple ``(x, edge_index, edge_weight, batch, perm, score)``:

            - the kept node features, scaled by their selection score (and by
              :obj:`multiplier` when it is not 1);
            - the edge indices and edge weights produced by ``self.connect``;
            - the batch vector produced by ``self.connect``;
            - the indices of the kept nodes;
            - the selection scores of the kept nodes.
        """
        num_nodes = x.size(0)
        if batch is None:
            # No batch vector given: every node belongs to graph 0.
            batch = x.new_zeros(num_nodes, dtype=torch.long)

        row, col, edge_weight = M.coo()
        assert edge_weight is not None

        # Feature term: dot product of each node's features with ``p``.
        feature_score = (x * self.p).sum(dim=-1)
        # Structural term: sum of M's values over each column, i.e. deg(M).
        degree_score = scatter(edge_weight, col, 0, dim_size=num_nodes,
                               reduce='sum')
        combined_score = (self.beta[0] * feature_score +
                          self.beta[1] * degree_score)

        select_out = self.select(combined_score, batch)
        perm = select_out.node_index
        score = select_out.weight
        assert score is not None

        # Keep the selected nodes and gate their features by their score.
        x = x[perm] * score.view(-1, 1)
        if self.multiplier != 1:
            x = self.multiplier * x

        # Note the (col, row) order: edges are taken transposed relative to
        # M's COO (row, col) layout.
        edge_index = torch.stack([col, row], dim=0)
        connect_out = self.connect(select_out, edge_index, edge_weight, batch)
        pooled_edge_weight = connect_out.edge_attr
        assert pooled_edge_weight is not None

        return (x, connect_out.edge_index, pooled_edge_weight,
                connect_out.batch, perm, score)

    def __repr__(self) -> str:
        if self.min_score is not None:
            criterion = f'min_score={self.min_score}'
        else:
            criterion = f'ratio={self.ratio}'
        return (f'{self.__class__.__name__}({self.in_channels}, {criterion}, '
                f'multiplier={self.multiplier})')