Source code for torch_geometric.nn.dense.mincut_pool

from typing import Optional, Tuple

import torch
from torch import Tensor

[docs]def dense_mincut_pool(
x: Tensor,
s: Tensor,
temp: float = 1.0,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
r"""The MinCut pooling operator from the `"Spectral Clustering in Graph
Neural Networks for Graph Pooling" <https://arxiv.org/abs/1907.00481>`_
paper.

.. math::
\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
\mathbf{X}

\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
\mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
\times N \times C}`.
Returns the pooled node feature matrix, the coarsened and symmetrically
normalized adjacency matrix and two auxiliary objectives: (1) The MinCut
loss

.. math::
\mathcal{L}_c = - \frac{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{A}
\mathbf{S})} {\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{D}
\mathbf{S})}

where :math:`\mathbf{D}` is the degree matrix, and (2) the orthogonality
loss

.. math::
\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
{{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}}
\right\|}_F.

Args:
x (torch.Tensor): Node feature tensor
:math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
batch-size :math:`B`, (maximum) number of nodes :math:`N` for
each graph, and feature dimension :math:`F`.
:math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
s (torch.Tensor): Assignment tensor
:math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`
with number of clusters :math:`C`.
The softmax does not have to be applied beforehand, since it is
executed within this method.
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
the valid nodes for each graph. (default: :obj:`None`)
temp (float, optional): Temperature parameter for softmax function.
(default: :obj:`1.0`)

:rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,
:class:`torch.Tensor`, :class:`torch.Tensor`)
"""
x = x.unsqueeze(0) if x.dim() == 2 else x
s = s.unsqueeze(0) if s.dim() == 2 else s

(batch_size, num_nodes, _), k = x.size(), s.size(-1)

s = torch.softmax(s / temp if temp != 1.0 else s, dim=-1)

out = torch.matmul(s.transpose(1, 2), x)

# MinCut regularization.
d = _rank3_diag(d_flat)
mincut_den = _rank3_trace(
torch.matmul(torch.matmul(s.transpose(1, 2), d), s))
mincut_loss = -(mincut_num / mincut_den)
mincut_loss = torch.mean(mincut_loss)

# Orthogonality regularization.
ss = torch.matmul(s.transpose(1, 2), s)
i_s = torch.eye(k).type_as(ss)
ortho_loss = torch.norm(
ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -
i_s / torch.norm(i_s), dim=(-1, -2))
ortho_loss = torch.mean(ortho_loss)

EPS = 1e-15

# Fix and normalize coarsened adjacency matrix.
d = torch.sqrt(d)[:, None] + EPS