Source code for torch_geometric.nn.dense.dmon_pool

from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.dense.mincut_pool import _rank3_trace

EPS = 1e-15

class DMoNPooling(torch.nn.Module):
    r"""The spectral modularity pooling operator from the `"Graph Clustering
    with Graph Neural Networks" <https://arxiv.org/abs/2006.16904>`_ paper.

    .. math::
        \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
    \times N \times C}`.
    Returns the learned cluster assignment matrix, the pooled node feature
    matrix, the coarsened symmetrically normalized adjacency matrix, and three
    auxiliary objectives: (1) the spectral loss

    .. math::
        \mathcal{L}_s = - \frac{1}{2m}
        \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}

    where :math:`\mathbf{B} = \mathbf{A} - \frac{\mathbf{d}\mathbf{d}^{\top}}
    {2m}` is the modularity matrix, :math:`\mathbf{d}` is the node degree
    vector and :math:`m` is half the total degree, (2) the orthogonality loss

    .. math::
        \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
        {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}}
        \right\|}_F

    where :math:`C` is the number of clusters, and (3) the cluster loss

    .. math::
        \mathcal{L}_c = \frac{\sqrt{C}}{n}
        {\left\|\sum_i \mathbf{S}_i\right\|}_F - 1

    where :math:`n` is the number of valid nodes of the graph and
    :math:`\mathbf{S}_i` is the assignment row of node :math:`i`.
    All three losses are computed per graph and averaged over the batch.

    .. note::

        For an example of using :class:`DMoNPooling`, see the ``examples/``
        directory of the PyTorch Geometric repository.

    Args:
        channels (int or List[int]): Size of each input sample. If given as a
            list, will construct an MLP based on the given feature sizes.
        k (int): The number of clusters.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
    """
    def __init__(self, channels: Union[int, List[int]], k: int,
                 dropout: float = 0.0):
        super().__init__()

        if isinstance(channels, int):
            channels = [channels]

        # Imported lazily (as in the original code).
        from torch_geometric.nn.models.mlp import MLP

        # ``channels + [k]`` builds a new list, so a caller-supplied list is
        # never mutated. The MLP maps node features to ``k`` cluster logits.
        self.mlp = MLP(channels + [k], act=None, norm=None)

        self.dropout = dropout

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.mlp.reset_parameters()

    def forward(
        self,
        x: Tensor,
        adj: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
                Note that the cluster assignment matrix
                :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` is
                being created within this method.
                A 2-D input is treated as a batch of size one.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
                A 2-D input is treated as a batch of size one.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)

        Returns:
            A tuple ``(s, out, out_adj, spectral_loss, ortho_loss,
            cluster_loss)``, where:

            - ``s`` is the masked soft assignment of shape ``B x N x C``.
            - ``out`` holds the pooled features of shape ``B x C x F``.
            - ``out_adj`` is the coarsened adjacency of shape ``B x C x C``.
              Its diagonal is zeroed and it is symmetrically normalized.
            - The three losses are scalars averaged over the batch.

        :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,
            :class:`torch.Tensor`, :class:`torch.Tensor`,
            :class:`torch.Tensor`, :class:`torch.Tensor`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj

        # Cluster logits -> (dropout during training only) -> soft assignment.
        s = self.mlp(x)
        s = F.dropout(s, self.dropout, training=self.training)
        s = torch.softmax(s, dim=-1)

        (batch_size, num_nodes, _), C = x.size(), s.size(-1)

        if mask is None:
            mask = torch.ones(batch_size, num_nodes, dtype=torch.bool,
                              device=x.device)

        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)  # B x N x 1
        # Zero out padded nodes so they contribute to no cluster.
        x, s = x * mask, s * mask

        out = F.selu(torch.matmul(s.transpose(1, 2), x))  # B x C x F
        out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)

        # Spectral loss:
        degrees = torch.einsum('ijk->ij', adj)  # B X N
        degrees = degrees.unsqueeze(-1) * mask  # B x N x 1
        degrees_t = degrees.transpose(1, 2)  # B x 1 x N

        m = torch.einsum('ijk->i', degrees) / 2  # B
        # A graph without edges has m == 0, which would turn the loss into
        # 0 / 0 = NaN and poison the batch mean. Clamping leaves every graph
        # with edges unchanged; an edge-less graph then contributes 0.
        m = m.clamp(min=EPS)
        m_expand = m.view(-1, 1, 1).expand(-1, C, C)  # B x C x C

        ca = torch.matmul(s.transpose(1, 2), degrees)  # B x C x 1
        cb = torch.matmul(degrees_t, s)  # B x 1 x C

        # S^T B S = S^T A S - (S^T d)(d^T S) / (2m)
        normalizer = torch.matmul(ca, cb) / 2 / m_expand
        decompose = out_adj - normalizer
        spectral_loss = -_rank3_trace(decompose) / 2 / m
        spectral_loss = spectral_loss.mean()

        # Orthogonality regularization:
        ss = torch.matmul(s.transpose(1, 2), s)
        i_s = torch.eye(C).type_as(ss)
        ortho_loss = torch.norm(
            ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -
            i_s / torch.norm(i_s), dim=(-1, -2))
        ortho_loss = ortho_loss.mean()

        # Cluster loss:
        cluster_size = torch.einsum('ijk->ik', s)  # B x C
        cluster_loss = torch.norm(input=cluster_size, dim=1)  # B
        # Sum over both the node and the trailing singleton axis so the
        # per-graph node count has shape (B,). ``mask.sum(dim=1)`` has shape
        # (B, 1), which would broadcast the division to a (B, B) matrix.
        num_valid_nodes = mask.sum(dim=(1, 2))  # B
        cluster_loss = cluster_loss / num_valid_nodes * torch.norm(i_s) - 1
        cluster_loss = cluster_loss.mean()

        # Fix and normalize coarsened adjacency matrix:
        # drop cluster self-connections, then apply D^{-1/2} A D^{-1/2}.
        ind = torch.arange(C, device=out_adj.device)
        out_adj[:, ind, ind] = 0
        d = torch.einsum('ijk->ij', out_adj)
        d = torch.sqrt(d)[:, None] + EPS  # B x 1 x C
        out_adj = (out_adj / d) / d.transpose(1, 2)

        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.mlp.in_channels}, '
                f'num_clusters={self.mlp.out_channels})')