from torch_geometric.typing import OptTensor
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.utils import softmax
from .topk_pool import topk, filter_adj
[docs]class PANPooling(torch.nn.Module):
r"""The path integral based pooling operator from the
`"Path Integral Based Convolution and Pooling for Graph Neural Networks"
<https://arxiv.org/abs/2006.16811>`_ paper.
PAN pooling performs top-:math:`k` pooling where global node importance is
measured based on node features and the MET matrix:
.. math::
{\rm score} = \beta_1 \mathbf{X} \cdot \mathbf{p} + \beta_2
{\rm deg}(M)
Args:
in_channels (int): Size of each input sample.
ratio (float): Graph pooling ratio, which is used to compute
:math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
This value is ignored if min_score is not None.
(default: :obj:`0.5`)
min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
which is used to compute indices of pooled nodes
:math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
When this value is not :obj:`None`, the :obj:`ratio` argument is
ignored. (default: :obj:`None`)
multiplier (float, optional): Coefficient by which features gets
multiplied after pooling. This can be useful for large graphs and
when :obj:`min_score` is used. (default: :obj:`1.0`)
nonlinearity (torch.nn.functional, optional): The nonlinearity to use.
(default: :obj:`torch.tanh`)
"""
def __init__(self, in_channels, ratio=0.5, min_score=None, multiplier=1.0,
nonlinearity=torch.tanh):
super().__init__()
self.in_channels = in_channels
self.ratio = ratio
self.min_score = min_score
self.multiplier = multiplier
self.nonlinearity = nonlinearity
self.p = Parameter(torch.Tensor(in_channels))
self.beta = Parameter(torch.Tensor(2))
self.reset_parameters()
[docs] def reset_parameters(self):
self.p.data.fill_(1)
self.beta.data.fill_(0.5)
[docs] def forward(self, x: Tensor, M: SparseTensor, batch: OptTensor = None):
""""""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)
row, col, edge_weight = M.coo()
score1 = (x * self.p).sum(dim=-1)
score2 = scatter_add(edge_weight, col, dim=0, dim_size=x.size(0))
score = self.beta[0] * score1 + self.beta[1] * score2
if self.min_score is None:
score = self.nonlinearity(score)
else:
score = softmax(score, batch)
perm = topk(score, self.ratio, batch, self.min_score)
x = x[perm] * score[perm].view(-1, 1)
x = self.multiplier * x if self.multiplier != 1 else x
edge_index = torch.stack([col, row], dim=0)
edge_index, edge_attr = filter_adj(edge_index, edge_weight, perm,
num_nodes=score.size(0))
return x, edge_index, edge_attr, batch[perm], perm, score[perm]
def __repr__(self) -> str:
if self.min_score is None:
ratio = f'ratio={self.ratio}'
else:
ratio = f'min_score={self.min_score}'
return (f'{self.__class__.__name__}({self.in_channels}, {ratio}, '
f'multiplier={self.multiplier})')