Source code for torch_geometric.nn.pool.mem_pool

from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Conv2d, KLDivLoss, Linear, Parameter

from torch_geometric.utils import to_dense_batch

EPS = 1e-15


class MemPooling(torch.nn.Module):
    r"""Memory based pooling layer from the `"Memory-Based Graph Networks"
    <https://arxiv.org/abs/2002.09518>`_ paper, which learns a coarsened graph
    representation based on soft cluster assignments.

    .. math::
        S_{i,j}^{(h)} &= \frac{
        (1+{\| \mathbf{x}_i-\mathbf{k}^{(h)}_j \|}^2 / \tau)^{
        -\frac{1+\tau}{2}}}{
        \sum_{k=1}^K (1 + {\| \mathbf{x}_i-\mathbf{k}^{(h)}_k \|}^2 / \tau)^{
        -\frac{1+\tau}{2}}}

        \mathbf{S} &= \textrm{softmax}(\textrm{Conv2d}
        (\Vert_{h=1}^H \mathbf{S}^{(h)})) \in \mathbb{R}^{N \times K}

        \mathbf{X}^{\prime} &= \mathbf{S}^{\top} \mathbf{X} \mathbf{W} \in
        \mathbb{R}^{K \times F^{\prime}}

    where :math:`H` denotes the number of heads, and :math:`K` denotes the
    number of clusters.

    Args:
        in_channels (int): Size of each input sample :math:`F`.
        out_channels (int): Size of each output sample :math:`F^{\prime}`.
        heads (int): The number of heads :math:`H`.
        num_clusters (int): The number of clusters :math:`K` per head.
        tau (float, optional): The temperature :math:`\tau`.
            (default: :obj:`1.`)
    """
    def __init__(self, in_channels: int, out_channels: int, heads: int,
                 num_clusters: int, tau: float = 1.):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.num_clusters = num_clusters
        self.tau = tau

        self.k = Parameter(torch.empty(heads, num_clusters, in_channels))
        self.conv = Conv2d(heads, 1, kernel_size=1, padding=0, bias=False)
        self.lin = Linear(in_channels, out_channels, bias=False)

        self.reset_parameters()
    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        torch.nn.init.uniform_(self.k.data, -1., 1.)
        self.conv.reset_parameters()
        self.lin.reset_parameters()
    @staticmethod
    def kl_loss(S: Tensor) -> Tensor:
        r"""The additional KL divergence-based loss.

        .. math::
            P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^N S_{n,j}}{\sum_{k=1}^K
            S_{i,k}^2 / \sum_{n=1}^N S_{n,k}}

            \mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert
            \mathbf{S})
        """
        S_2 = S**2
        P = S_2 / S.sum(dim=1, keepdim=True)
        denom = P.sum(dim=2, keepdim=True)
        # Avoid division by zero for padded nodes whose assignments are all zero:
        denom[S.sum(dim=2, keepdim=True) == 0.0] = 1.0
        P /= denom

        loss = KLDivLoss(reduction='batchmean', log_target=False)
        return loss(S.clamp(EPS).log(), P.clamp(EPS))
    def forward(
        self,
        x: Tensor,
        batch: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        max_num_nodes: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The node feature tensor of shape
                :math:`\mathbf{X} \in \mathbb{R}^{N \times F}` or
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example.
                Should not be provided in case node features already have
                shape :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`.
                (default: :obj:`None`)
            mask (torch.Tensor, optional): A mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}`, which
                indicates valid nodes for each graph when using node features
                of shape :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times
                F}`. (default: :obj:`None`)
            max_num_nodes (int, optional): The size of the :math:`N` node
                dimension. Automatically calculated if not given.
                (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given. (default: :obj:`None`)
        """
        if x.dim() <= 2:
            x, mask = to_dense_batch(x, batch, max_num_nodes=max_num_nodes,
                                     batch_size=batch_size)
        elif mask is None:
            mask = x.new_ones((x.size(0), x.size(1)), dtype=torch.bool)

        (B, N, _), H, K = x.size(), self.heads, self.num_clusters

        # Compute (unnormalized) soft assignments of nodes to the cluster keys:
        dist = torch.cdist(self.k.view(H * K, -1), x.view(B * N, -1), p=2)**2
        dist = (1. + dist / self.tau).pow(-(self.tau + 1.0) / 2.0)
        dist = dist.view(H, K, B, N).permute(2, 0, 3, 1)  # [B, H, N, K]

        # Normalize per head, aggregate heads via 1x1 convolution, and mask
        # out assignments of padded nodes:
        S = dist / dist.sum(dim=-1, keepdim=True)
        S = self.conv(S).squeeze(dim=1).softmax(dim=-1)  # [B, N, K]
        S = S * mask.view(B, N, 1)

        x = self.lin(S.transpose(1, 2) @ x)

        return x, S
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads}, '
                f'num_clusters={self.num_clusters})')
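
A minimal usage sketch follows (not part of the module above): the layer is instantiated with hypothetical sizes, applied to a toy two-graph batch, and the auxiliary :meth:`kl_loss` is evaluated on the returned soft assignments. The shapes follow the :meth:`forward` documentation; the random data, the chosen hyperparameters, and how the auxiliary loss is weighted against the task loss are illustrative assumptions only.

import torch

from torch_geometric.nn import MemPooling

# Toy batch: two graphs with 6 and 4 nodes, each node with F = 8 features.
x = torch.randn(10, 8)
batch = torch.tensor([0] * 6 + [1] * 4)

# Hypothetical sizes: F = 8, F' = 16, H = 5 heads, K = 4 clusters per head.
pool = MemPooling(in_channels=8, out_channels=16, heads=5, num_clusters=4)

# x_pooled has shape [B, K, F'] = [2, 4, 16];
# S holds the soft assignments with shape [B, N_max, K] = [2, 6, 4].
x_pooled, S = pool(x, batch)

# Auxiliary clustering objective on the assignments; typically combined with
# (or alternated with) the downstream task loss during training.
aux_loss = MemPooling.kl_loss(S)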