Source code for torch_geometric.nn.pool.asap

from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear

from torch_geometric.nn import LEConv
from torch_geometric.nn.pool.select import SelectTopK
from torch_geometric.utils import (
    add_remaining_self_loops,
    remove_self_loops,
    scatter,
    softmax,
    to_edge_index,
    to_torch_coo_tensor,
    to_torch_csr_tensor,
)


class ASAPooling(torch.nn.Module):
    r"""The Adaptive Structure Aware Pooling operator from the `"ASAP:
    Adaptive Structure Aware Pooling for Learning Hierarchical Graph
    Representations" <https://arxiv.org/abs/1911.07979>`_ paper.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float or int): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`, or the value
            of :math:`k` itself, depending on whether the type of
            :obj:`ratio` is :obj:`float` or :obj:`int`.
            (default: :obj:`0.5`)
        GNN (torch.nn.Module, optional): A graph neural network layer for
            capturing intra-cluster properties.
            Especially helpful for graphs with a high neighborhood degree
            (one of :class:`torch_geometric.nn.conv.GraphConv`,
            :class:`torch_geometric.nn.conv.GCNConv` or
            any GNN which supports the :obj:`edge_weight` parameter).
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients, which exposes each node to a
            stochastically sampled neighborhood during training.
            (default: :obj:`0`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        add_self_loops (bool, optional): If set to :obj:`True`, will add self
            loops to the new graph connectivity. (default: :obj:`False`)
        **kwargs (optional): Additional parameters for initializing the graph
            neural network layer.
    """
    def __init__(self, in_channels: int, ratio: Union[float, int] = 0.5,
                 GNN: Optional[Callable] = None, dropout: float = 0.0,
                 negative_slope: float = 0.2, add_self_loops: bool = False,
                 **kwargs):
        super().__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.GNN = GNN
        self.add_self_loops = add_self_loops

        self.lin = Linear(in_channels, in_channels)
        self.att = Linear(2 * in_channels, 1)
        self.gnn_score = LEConv(self.in_channels, 1)
        if self.GNN is not None:
            self.gnn_intra_cluster = GNN(self.in_channels, self.in_channels,
                                         **kwargs)
        else:
            self.gnn_intra_cluster = None

        self.select = SelectTopK(1, ratio)

        self.reset_parameters()
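    # Editorial note on the submodules above: `self.lin` and `self.att`
    # realize the master-query attention that scores each neighbor against a
    # max-pooled, linearly transformed cluster representation;
    # `self.gnn_score` (an :class:`LEConv`) produces the scalar cluster
    # fitness; and `self.select` keeps the top k = ceil(ratio * N) clusters
    # (or k = ratio when `ratio` is an int), following the ASAP paper.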
    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.lin.reset_parameters()
        self.att.reset_parameters()
        self.gnn_score.reset_parameters()
        if self.gnn_intra_cluster is not None:
            self.gnn_intra_cluster.reset_parameters()
        self.select.reset_parameters()
    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: Optional[Tensor] = None,
        batch: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The node feature matrix.
            edge_index (torch.Tensor): The edge indices.
            edge_weight (torch.Tensor, optional): The edge weights.
                (default: :obj:`None`)
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example. (default: :obj:`None`)

        Return types:
            * **x** (*torch.Tensor*): The pooled node embeddings.
            * **edge_index** (*torch.Tensor*): The coarsened edge indices.
            * **edge_weight** (*torch.Tensor, optional*): The coarsened edge
              weights.
            * **batch** (*torch.Tensor*): The coarsened batch vector.
            * **index** (*torch.Tensor*): The top-:math:`k` node indices of
              nodes which are kept after pooling.
        """
        N = x.size(0)

        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value=1., num_nodes=N)

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        x_pool = x
        if self.gnn_intra_cluster is not None:
            x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                            edge_weight=edge_weight)

        # Master-query attention: score each neighbor j of cluster i against
        # a max-pooled (and linearly transformed) cluster representation.
        x_pool_j = x_pool[edge_index[0]]
        x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')
        x_q = self.lin(x_q)[edge_index[1]]

        score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)
        score = F.leaky_relu(score, self.negative_slope)
        score = softmax(score, edge_index[1], num_nodes=N)

        # Sample attention coefficients stochastically.
        score = F.dropout(score, p=self.dropout, training=self.training)

        # Weighted aggregation of neighbors into cluster representations.
        v_j = x[edge_index[0]] * score.view(-1, 1)
        x = scatter(v_j, edge_index[1], dim=0, reduce='sum')

        # Cluster selection.
        fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)
        perm = self.select(fitness, batch).node_index
        x = x[perm] * fitness[perm].view(-1, 1)
        batch = batch[perm]

        # Graph coarsening: A' = S^T A S, with the assignment matrix S
        # column-restricted to the kept clusters.
        A = to_torch_csr_tensor(edge_index, edge_weight, size=(N, N))
        S = to_torch_coo_tensor(edge_index, score, size=(N, N))
        S = S.index_select(1, perm).to_sparse_csr()

        A = S.t().to_sparse_csr() @ (A @ S)

        if edge_weight is None:
            edge_index, _ = to_edge_index(A)
        else:
            edge_index, edge_weight = to_edge_index(A)

        if self.add_self_loops:
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, num_nodes=A.size(0))
        else:
            edge_index, edge_weight = remove_self_loops(
                edge_index, edge_weight)

        return x, edge_index, edge_weight, batch, perm
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'ratio={self.ratio})')
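
# ---------------------------------------------------------------------------
# Minimal usage sketch (editorial addition, not part of the original module):
# it builds a small directed cycle graph with random features and runs one
# pooling step. All shapes and values below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_nodes, in_channels = 10, 16
    x = torch.randn(num_nodes, in_channels)
    edge_index = torch.tensor([  # a directed 10-node cycle
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
    ])

    pool = ASAPooling(in_channels, ratio=0.5)
    out, edge_index_out, edge_weight_out, batch_out, perm = pool(
        x, edge_index)

    print(out.size())   # torch.Size([5, 16]): k = ceil(0.5 * 10) = 5 clusters
    print(perm.size())  # torch.Size([5]): indices of the kept cluster centers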