# Source code for torch_geometric.nn.dense.diff_pool

from typing import Optional, Tuple

import torch
from torch import Tensor

def dense_diff_pool(
    x: Tensor,
    adj: Tensor,
    s: Tensor,
    mask: Optional[Tensor] = None,
    normalize: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    r"""The differentiable pooling operator from the `"Hierarchical Graph
    Representation Learning with Differentiable Pooling"
    <https://arxiv.org/abs/1806.08804>`_ paper.

    .. math::
        \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
    \times N \times C}`.
    Returns the pooled node feature matrix, the coarsened adjacency matrix and
    two auxiliary objectives: (1) The link prediction loss

    .. math::
        \mathcal{L}_{LP} = {\| \mathbf{A} -
        \mathrm{softmax}(\mathbf{S}) {\mathrm{softmax}(\mathbf{S})}^{\top}
        \|}_F,

    and (2) the entropy regularization

    .. math::
        \mathcal{L}_E = \frac{1}{N} \sum_{n=1}^N H(\mathbf{S}_n).

    Args:
        x (torch.Tensor): Node feature tensor
            :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
            batch-size :math:`B`, (maximum) number of nodes :math:`N` for
            each graph, and feature dimension :math:`F`.
            A 2-dimensional tensor is treated as a batch of size one.
        adj (torch.Tensor): Adjacency tensor
            :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            A 2-dimensional tensor is treated as a batch of size one.
        s (torch.Tensor): Assignment tensor
            :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`
            with number of clusters :math:`C`.
            The softmax does not have to be applied before-hand, since it is
            executed within this method.
            A 2-dimensional tensor is treated as a batch of size one.
        mask (torch.Tensor, optional): Mask matrix
            :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
            the valid nodes for each graph. Features and assignments of
            masked-out nodes are zeroed before pooling.
            (default: :obj:`None`)
        normalize (bool, optional): If set to :obj:`False`, the link
            prediction loss is not divided by :obj:`adj.numel()`.
            Any truthy value enables normalization.
            (default: :obj:`True`)

    Notes:
        * The link prediction loss is the 2-norm taken over all
          :math:`B \cdot N \cdot N` entries of the residual at once, that is,
          the Frobenius norm of the whole batch rather than a per-graph sum.
        * The entropy loss is averaged over all :math:`B \cdot N` rows of
          :math:`\mathrm{softmax}(\mathbf{S})`. When :obj:`mask` is given,
          padded rows are zeroed. They contribute zero entropy but still
          count towards the mean.

    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,
        :class:`torch.Tensor`, :class:`torch.Tensor`)
    """
    # Promote unbatched 2-D inputs to a batch of size one.
    x = x.unsqueeze(0) if x.dim() == 2 else x
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
    s = s.unsqueeze(0) if s.dim() == 2 else s

    batch_size, num_nodes, _ = x.size()

    # Soft cluster assignments: each node's row sums to one over the C clusters.
    s = torch.softmax(s, dim=-1)

    if mask is not None:
        # Zero out features and assignments of padded nodes so they do not
        # leak into the pooled outputs or the auxiliary losses.
        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x, s = x * mask, s * mask

    s_t = s.transpose(1, 2)  # [B, C, N]; a view, no copy
    out = torch.matmul(s_t, x)
    out_adj = torch.matmul(torch.matmul(s_t, adj), s)

    link_loss = adj - torch.matmul(s, s_t)
    # With dim=None, torch.norm flattens the input, so this is the 2-norm
    # over every entry of the batch.
    link_loss = torch.norm(link_loss, p=2)
    # Truthiness (not `is True`) so numpy/torch booleans also normalize.
    if normalize:
        link_loss = link_loss / adj.numel()

    # The 1e-15 guards against log(0) for zero (or masked-out) probabilities.
    ent_loss = (-s * torch.log(s + 1e-15)).sum(dim=-1).mean()

    return out, out_adj, link_loss, ent_loss