# Source code for torch_geometric.nn.dense.dmon_pool

from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.dense.mincut_pool import _rank3_trace

EPS = 1e-15

class DMoNPooling(torch.nn.Module):
    r"""The spectral modularity pooling operator from the `"Graph Clustering
    with Graph Neural Networks" <https://arxiv.org/abs/2006.16904>`_ paper.

    .. math::
        \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
    \times N \times C}`.
    Returns the learned cluster assignment matrix, the pooled node feature
    matrix, the coarsened symmetrically normalized adjacency matrix, and
    three auxiliary objectives: (1) the spectral loss

    .. math::
        \mathcal{L}_s = - \frac{1}{2m}
        \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}

    where :math:`\mathbf{B} = \mathbf{A} - \frac{\mathbf{d}\mathbf{d}^{\top}}{2m}`
    is the modularity matrix (with node degree vector :math:`\mathbf{d}` and
    number of edges :math:`m`), (2) the orthogonality loss

    .. math::
        \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
        {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} - \frac{\mathbf{I}_C}{\sqrt{C}}
        \right\|}_F

    where :math:`C` is the number of clusters, and (3) the cluster loss

    .. math::
        \mathcal{L}_c = \frac{\sqrt{C}}{n}
        {\left\|\sum_i\mathbf{C}_i^{\top}\right\|}_F - 1.
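
    The three objectives are returned separately; one common choice (a
    sketch, assuming a standard supervised setup, where :obj:`criterion`,
    :obj:`pred`, and :obj:`y` are placeholders for the task loss and model
    outputs) is to simply add them to the task loss:

    .. code-block:: python

        loss = criterion(pred, y) + spectral_loss + ortho_loss + cluster_loss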

    .. note::

        For an example of using :class:`DMoNPooling`, see
        `examples/proteins_dmon_pool.py
        <https://github.com/pyg-team/pytorch_geometric/blob
        /master/examples/proteins_dmon_pool.py>`_.

    Args:
        channels (int or List[int]): Size of each input sample. If given as a
            list, will construct an MLP based on the given feature sizes.
        k (int): The number of clusters.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
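
    Example:
        A minimal usage sketch (the shapes below are illustrative
        assumptions, not requirements of the module):

        .. code-block:: python

            import torch
            from torch_geometric.nn import DMoNPooling

            # Two graphs with (up to) 20 nodes and 16 features each:
            x = torch.randn(2, 20, 16)
            adj = torch.rand(2, 20, 20)
            mask = torch.ones(2, 20, dtype=torch.bool)

            pool = DMoNPooling(channels=16, k=4)
            s, out, out_adj, spec_loss, ortho_loss, cluster_loss = pool(
                x, adj, mask)
            # s: [2, 20, 4], out: [2, 4, 16], out_adj: [2, 4, 4]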
"""
    def __init__(self, channels: Union[int, List[int]], k: int,
                 dropout: float = 0.0):
        super().__init__()

        if isinstance(channels, int):
            channels = [channels]

        from torch_geometric.nn.models.mlp import MLP
        self.mlp = MLP(channels + [k], act=None, norm=None)

        self.dropout = dropout

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.mlp.reset_parameters()

    def forward(
        self,
        x: Tensor,
        adj: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
                Note that the cluster assignment matrix
                :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` is
                being created within this method.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (torch.BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)

        :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,
            :class:`torch.Tensor`, :class:`torch.Tensor`,
            :class:`torch.Tensor`, :class:`torch.Tensor`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj

        s = self.mlp(x)
        s = F.dropout(s, self.dropout, training=self.training)
        s = torch.softmax(s, dim=-1)

        (batch_size, num_nodes, _), C = x.size(), s.size(-1)

        if mask is None:
            mask = torch.ones(batch_size, num_nodes, dtype=torch.bool,
                              device=x.device)

        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x, s = x * mask, s * mask

        # Pooled features X' and coarsened adjacency S^T A S:
        out = F.selu(torch.matmul(s.transpose(1, 2), x))
        out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)
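
        # The three auxiliary objectives from the class docstring follow:
        # (1) spectral loss, (2) orthogonality loss, (3) cluster loss.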

        # Spectral loss:
        degrees = torch.einsum('ijk->ij', adj)  # B x N
        degrees = degrees.unsqueeze(-1) * mask  # B x N x 1
        degrees_t = degrees.transpose(1, 2)  # B x 1 x N

        m = torch.einsum('ijk->i', degrees) / 2  # B
        m_expand = m.view(-1, 1, 1).expand(-1, C, C)  # B x C x C

        ca = torch.matmul(s.transpose(1, 2), degrees)  # B x C x 1
        cb = torch.matmul(degrees_t, s)  # B x 1 x C

        normalizer = torch.matmul(ca, cb) / 2 / m_expand  # S^T (d d^T) S / 2m
        decompose = out_adj - normalizer  # S^T B S with B = A - d d^T / 2m
        spectral_loss = -_rank3_trace(decompose) / 2 / m
        spectral_loss = spectral_loss.mean()

        # Orthogonality regularization:
        ss = torch.matmul(s.transpose(1, 2), s)
        i_s = torch.eye(C).type_as(ss)
        ortho_loss = torch.norm(
            ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -
            i_s / torch.norm(i_s), dim=(-1, -2))
        ortho_loss = ortho_loss.mean()

        # Cluster loss:
        cluster_size = torch.einsum('ijk->ik', s)  # B x C
        cluster_loss = torch.norm(input=cluster_size, dim=1)
        cluster_loss = cluster_loss / mask.sum(dim=1) * torch.norm(i_s) - 1
        cluster_loss = cluster_loss.mean()

        # Fix and normalize coarsened adjacency matrix: