Source code for torch_geometric.nn.pool.edge_pool

from typing import Callable, List, NamedTuple, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.utils import coalesce, scatter, softmax

class UnpoolInfo(NamedTuple):
    edge_index: Tensor
    cluster: Tensor
    batch: Tensor
    new_edge_score: Tensor

[docs]class EdgePooling(torch.nn.Module): r"""The edge pooling operator from the `"Towards Graph Pooling by Edge Contraction" <>`__ and `"Edge Contraction Pooling for Graph Neural Networks" <>`__ papers. In short, a score is computed for each edge. Edges are contracted iteratively according to that score unless one of their nodes has already been part of a contracted edge. To duplicate the configuration from the `"Towards Graph Pooling by Edge Contraction" <>`__ paper, use either :func:`EdgePooling.compute_edge_score_softmax` or :func:`EdgePooling.compute_edge_score_tanh`, and set :obj:`add_to_edge_score` to :obj:`0.0`. To duplicate the configuration from the `"Edge Contraction Pooling for Graph Neural Networks" <>`__ paper, set :obj:`dropout` to :obj:`0.2`. Args: in_channels (int): Size of each input sample. edge_score_method (callable, optional): The function to apply to compute the edge score from raw edge scores. By default, this is the softmax over all incoming edges for each node. This function takes in a :obj:`raw_edge_score` tensor of shape :obj:`[num_nodes]`, an :obj:`edge_index` tensor and the number of nodes :obj:`num_nodes`, and produces a new tensor of the same size as :obj:`raw_edge_score` describing normalized edge scores. Included functions are :func:`EdgePooling.compute_edge_score_softmax`, :func:`EdgePooling.compute_edge_score_tanh`, and :func:`EdgePooling.compute_edge_score_sigmoid`. (default: :func:`EdgePooling.compute_edge_score_softmax`) dropout (float, optional): The probability with which to drop edge scores during training. (default: :obj:`0.0`) add_to_edge_score (float, optional): A value to be added to each computed edge score. Adding this greatly helps with unpooling stability. (default: :obj:`0.5`) """ def __init__( self, in_channels: int, edge_score_method: Optional[Callable] = None, dropout: Optional[float] = 0.0, add_to_edge_score: float = 0.5, ): super().__init__() self.in_channels = in_channels if edge_score_method is None: edge_score_method = self.compute_edge_score_softmax self.compute_edge_score = edge_score_method self.add_to_edge_score = add_to_edge_score self.dropout = dropout self.lin = torch.nn.Linear(2 * in_channels, 1) self.reset_parameters()
[docs] def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin.reset_parameters()
[docs] @staticmethod def compute_edge_score_softmax( raw_edge_score: Tensor, edge_index: Tensor, num_nodes: int, ) -> Tensor: r"""Normalizes edge scores via softmax application.""" return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)
[docs] @staticmethod def compute_edge_score_tanh( raw_edge_score: Tensor, edge_index: Optional[Tensor] = None, num_nodes: Optional[int] = None, ) -> Tensor: r"""Normalizes edge scores via hyperbolic tangent application.""" return torch.tanh(raw_edge_score)
[docs] @staticmethod def compute_edge_score_sigmoid( raw_edge_score: Tensor, edge_index: Optional[Tensor] = None, num_nodes: Optional[int] = None, ) -> Tensor: r"""Normalizes edge scores via sigmoid application.""" return torch.sigmoid(raw_edge_score)
[docs] def forward( self, x: Tensor, edge_index: Tensor, batch: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]: r"""Forward pass. Args: x (torch.Tensor): The node features. edge_index (torch.Tensor): The edge indices. batch (torch.Tensor): The batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. Return types: * **x** *(torch.Tensor)* - The pooled node features. * **edge_index** *(torch.Tensor)* - The coarsened edge indices. * **batch** *(torch.Tensor)* - The coarsened batch vector. * **unpool_info** *(UnpoolInfo)* - Information that is consumed by :func:`EdgePooling.unpool` for unpooling. """ e =[x[edge_index[0]], x[edge_index[1]]], dim=-1) e = self.lin(e).view(-1) e = F.dropout(e, p=self.dropout, e = self.compute_edge_score(e, edge_index, x.size(0)) e = e + self.add_to_edge_score x, edge_index, batch, unpool_info = self._merge_edges( x, edge_index, batch, e) return x, edge_index, batch, unpool_info
def _merge_edges( self, x: Tensor, edge_index: Tensor, batch: Tensor, edge_score: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]: cluster = torch.empty_like(batch) perm: List[int] = torch.argsort(edge_score, descending=True).tolist() # Iterate through all edges, selecting it if it is not incident to # another already chosen edge. mask = torch.ones(x.size(0), dtype=torch.bool) i = 0 new_edge_indices: List[int] = [] edge_index_cpu = edge_index.cpu() for edge_idx in perm: source = int(edge_index_cpu[0, edge_idx]) if not bool(mask[source]): continue target = int(edge_index_cpu[1, edge_idx]) if not bool(mask[target]): continue new_edge_indices.append(edge_idx) cluster[source] = i mask[source] = False if source != target: cluster[target] = i mask[target] = False i += 1 # The remaining nodes are simply kept: j = int(mask.sum()) cluster[mask] = torch.arange(i, i + j, device=x.device) i += j # We compute the new features as an addition of the old ones. new_x = scatter(x, cluster, dim=0, dim_size=i, reduce='sum') new_edge_score = edge_score[new_edge_indices] if int(mask.sum()) > 0: remaining_score = x.new_ones( (new_x.size(0) - len(new_edge_indices), )) new_edge_score =[new_edge_score, remaining_score]) new_x = new_x * new_edge_score.view(-1, 1) new_edge_index = coalesce(cluster[edge_index], num_nodes=new_x.size(0)) new_batch = x.new_empty(new_x.size(0), dtype=torch.long) new_batch = new_batch.scatter_(0, cluster, batch) unpool_info = UnpoolInfo(edge_index=edge_index, cluster=cluster, batch=batch, new_edge_score=new_edge_score) return new_x, new_edge_index, new_batch, unpool_info
[docs] def unpool( self, x: Tensor, unpool_info: UnpoolInfo, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Unpools a previous edge pooling step. For unpooling, :obj:`x` should be of same shape as those produced by this layer's :func:`forward` function. Then, it will produce an unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`. Args: x (torch.Tensor): The node features. unpool_info (UnpoolInfo): Information that has been produced by :func:`EdgePooling.forward`. Return types: * **x** *(torch.Tensor)* - The unpooled node features. * **edge_index** *(torch.Tensor)* - The new edge indices. * **batch** *(torch.Tensor)* - The new batch vector. """ new_x = x / unpool_info.new_edge_score.view(-1, 1) new_x = new_x[unpool_info.cluster] return new_x, unpool_info.edge_index, unpool_info.batch
def __repr__(self) -> str: return f'{self.__class__.__name__}({self.in_channels})'