torch_geometric.nn.dense.DMoNPooling

class DMoNPooling(channels: Union[int, List[int]], k: int, dropout: float = 0.0)

Bases: Module

The spectral modularity pooling operator from the “Graph Clustering with Graph Neural Networks” paper.

\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]

based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns, in order, the learned cluster assignment matrix, the pooled node feature matrix, the coarsened symmetrically normalized adjacency matrix, and three auxiliary objectives: (1) the spectral loss

\[\mathcal{L}_s = - \frac{1}{2m} \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}\]

where \(\mathbf{B}\) is the modularity matrix and \(m\) is the number of edges, (2) the orthogonality loss

\[\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F\]

where \(C\) is the number of clusters, and (3) the cluster loss

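\[\mathcal{L}_c = \frac{\sqrt{C}}{n} {\left\|\sum_i\mathbf{S}_i^{\top}\right\|}_F - 1,\]

where \(n\) is the number of nodes and \(\mathbf{S}_i\) is the \(i\)-th row of the cluster assignment matrix.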
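The formulas above can be sketched in plain PyTorch for a single graph without a batch dimension. This is a minimal illustration, not the library's implementation: the helper name dmon_sketch is hypothetical, the losses are assumed to be evaluated on the softmax-normalized assignments, the modularity matrix is assumed to be \(\mathbf{B} = \mathbf{A} - \mathbf{d}\mathbf{d}^{\top} / (2m)\) with degree vector \(\mathbf{d}\), and any post-processing (such as the symmetric normalization of the coarsened adjacency matrix) is omitted.

```python
import torch


def dmon_sketch(x, adj, logits):
    """Pool one graph and compute the three auxiliary losses (illustrative only).

    x:      [N, F] node features
    adj:    [N, N] symmetric adjacency matrix
    logits: [N, C] unnormalized cluster assignments S
    """
    n, c = logits.shape
    s = torch.softmax(logits, dim=-1)      # each row sums to 1

    x_pool = s.t() @ x                     # X' = softmax(S)^T X
    adj_pool = s.t() @ adj @ s             # A' = softmax(S)^T A softmax(S)

    # Assumed modularity matrix: B = A - d d^T / (2m), where 2m = sum of degrees.
    deg = adj.sum(dim=1, keepdim=True)     # [N, 1]
    two_m = deg.sum()
    b = adj - (deg @ deg.t()) / two_m
    spectral = -torch.trace(s.t() @ b @ s) / two_m

    # Orthogonality loss: distance of normalized S^T S from I_C / sqrt(C).
    ss = s.t() @ s
    eye = torch.eye(c, dtype=s.dtype, device=s.device)
    ortho = torch.linalg.norm(ss / torch.linalg.norm(ss) - eye / c ** 0.5)

    # Cluster loss: sqrt(C) / n times the norm of the cluster sizes, minus 1.
    cluster = (c ** 0.5) / n * torch.linalg.norm(s.sum(dim=0)) - 1

    return x_pool, adj_pool, spectral, ortho, cluster


# Toy graph: two disconnected edges (0-1 and 2-3), one edge per cluster.
x = torch.eye(4)
adj = torch.tensor([[0., 1., 0., 0.],
                    [1., 0., 0., 0.],
                    [0., 0., 0., 1.],
                    [0., 0., 1., 0.]])
logits = 20.0 * torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
x_pool, adj_pool, spectral, ortho, cluster = dmon_sketch(x, adj, logits)
# With this near-hard, perfectly balanced assignment, spectral is about -0.5
# and both ortho and cluster are about 0.
```

In this toy case the assignment matches the two connected components, so the spectral loss reaches a modularity of 0.5 (negated), while the two regularizers vanish because the clusters are equally sized and non-overlapping.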

Note

For an example of using DMoNPooling, see examples/proteins_dmon_pool.py.

Parameters:
  • channels (int or List[int]) – Size of each input sample. If given as a list, an MLP is constructed from the given feature sizes.

  • k (int) – The number of clusters.

  • dropout (float, optional) – Dropout probability. (default: 0.0)

reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) → Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

Forward pass.

Parameters:
  • x (torch.Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\). Note that the cluster assignment matrix \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) is computed from x inside this method rather than passed in.

  • adj (torch.Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).

  • mask (torch.Tensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

Return type:

(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)