from typing import Optional, Tuple
import torch
from torch import Tensor
def dense_diff_pool(
    x: Tensor,
    adj: Tensor,
    s: Tensor,
    mask: Optional[Tensor] = None,
    normalize: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    r"""The differentiable pooling operator from the `"Hierarchical Graph
    Representation Learning with Differentiable Pooling"
    <https://arxiv.org/abs/1806.08804>`_ paper.

    .. math::
        \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
    \times N \times C}`.
    Returns the pooled node feature matrix, the coarsened adjacency matrix and
    two auxiliary objectives: (1) The link prediction loss

    .. math::
        \mathcal{L}_{LP} = {\| \mathbf{A} -
        \mathrm{softmax}(\mathbf{S}) {\mathrm{softmax}(\mathbf{S})}^{\top}
        \|}_F,

    and (2) the entropy regularization

    .. math::
        \mathcal{L}_E = \frac{1}{N} \sum_{n=1}^N H(\mathbf{S}_n).

    Two-dimensional inputs (a single graph without a batch dimension) are
    accepted as well; a leading batch dimension of size one is added to
    :obj:`x`, :obj:`adj` and :obj:`s` before any computation.

    Args:
        x (torch.Tensor): Node feature tensor
            :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
            batch-size :math:`B`, (maximum) number of nodes :math:`N` for
            each graph, and feature dimension :math:`F`.
        adj (torch.Tensor): Adjacency tensor
            :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
        s (torch.Tensor): Assignment tensor
            :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`
            with number of clusters :math:`C`.
            The softmax does not have to be applied before-hand, since it is
            executed within this method.
        mask (torch.Tensor, optional): Mask matrix
            :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
            the valid nodes for each graph. (default: :obj:`None`)
        normalize (bool, optional): If set to :obj:`False`, the link
            prediction loss is not divided by :obj:`adj.numel()`.
            (default: :obj:`True`)

    Returns:
        A tuple :obj:`(out, out_adj, link_loss, ent_loss)` holding the pooled
        node features of shape :math:`B \times C \times F`, the coarsened
        adjacency of shape :math:`B \times C \times C`, the scalar link
        prediction loss (one norm taken over the whole batched residual
        tensor) and the scalar entropy regularization (averaged over all
        :math:`B \cdot N` assignment rows; rows zeroed by :obj:`mask`
        contribute zero entropy but still count in the average).

    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,
        :class:`torch.Tensor`, :class:`torch.Tensor`)
    """
    # Promote un-batched (2D) inputs to a batch of size one.
    x = x.unsqueeze(0) if x.dim() == 2 else x
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
    s = s.unsqueeze(0) if s.dim() == 2 else s

    batch_size, num_nodes, _ = x.size()

    # Turn raw assignment scores into per-node distributions over clusters.
    s = torch.softmax(s, dim=-1)

    if mask is not None:
        # Zero out features and assignments of padded (invalid) nodes so they
        # do not contribute to the pooled outputs or the losses.
        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x, s = x * mask, s * mask

    s_t = s.transpose(1, 2)
    out = torch.matmul(s_t, x)
    out_adj = torch.matmul(torch.matmul(s_t, adj), s)

    # Frobenius norm of the reconstruction residual A - S S^T, taken over the
    # entire (batched) tensor at once.
    link_loss = adj - torch.matmul(s, s_t)
    link_loss = torch.norm(link_loss, p=2)
    # Plain truthiness check: an identity test against ``True`` would silently
    # skip normalization for truthy non-``bool`` flags (e.g. ``numpy.bool_``).
    if normalize:
        link_loss = link_loss / adj.numel()

    # The small epsilon keeps log() finite for zero assignment probabilities.
    ent_loss = (-s * torch.log(s + 1e-15)).sum(dim=-1).mean()

    return out, out_adj, link_loss, ent_loss