# torch_geometric.nn.dense.DMoNPooling

class DMoNPooling(channels: Union[int, List[int]], k: int, dropout: float = 0.0)

Bases: Module

The spectral modularity pooling operator from the “Graph Clustering with Graph Neural Networks” paper

\begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align}

based on dense learned assignments $$\mathbf{S} \in \mathbb{R}^{B \times N \times C}$$. Returns the learned cluster assignment matrix, the pooled node feature matrix, the coarsened symmetrically normalized adjacency matrix, and three auxiliary objectives: (1) the spectral loss

$\mathcal{L}_s = - \frac{1}{2m} \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}$

where $$\mathbf{B}$$ is the modularity matrix and $$m$$ is the number of edges, (2) the orthogonality loss

$\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F$

where $$C$$ is the number of clusters, and (3) the cluster loss

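$\mathcal{L}_c = \frac{\sqrt{C}}{n} {\left\|\sum_i\mathbf{S}_i^{\top}\right\|}_F - 1,$

where $$\mathbf{S}_i$$ denotes the $$i$$-th row of $$\mathbf{S}$$ and $$n$$ is the number of nodes.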
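To make the formulas concrete, here is a dependency-free sketch (not PyG's implementation) that evaluates them for a single graph (B = 1) using plain Python lists. It assumes the loss terms are computed on softmax(S), that 2m equals the sum of node degrees, and that the modularity matrix is B = A - d d^T / (2m) with degree vector d. The function name `dmon_single_graph` is hypothetical, and the final symmetric normalization of A' is omitted.

```python
import math


def transpose(m):
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def frobenius(m):
    return math.sqrt(sum(v * v for row in m for v in row))


def softmax_rows(logits):
    out = []
    for row in logits:
        mx = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - mx) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def dmon_single_graph(x, adj, logits):
    """Pooling plus the three auxiliary losses for one graph (hypothetical helper)."""
    s = softmax_rows(logits)                  # N x C soft assignments
    st = transpose(s)                         # C x N
    x_pool = matmul(st, x)                    # X' = softmax(S)^T X
    adj_pool = matmul(matmul(st, adj), s)     # A' = softmax(S)^T A softmax(S), not normalized
    n, c = len(s), len(s[0])

    # Spectral loss: -Tr(S^T B S) / (2m) with B = A - d d^T / (2m).
    deg = [sum(row) for row in adj]
    two_m = sum(deg)                          # sum of degrees equals 2m
    b = [[adj[i][j] - deg[i] * deg[j] / two_m for j in range(n)] for i in range(n)]
    stbs = matmul(matmul(st, b), s)
    spectral = -sum(stbs[k][k] for k in range(c)) / two_m

    # Orthogonality loss: || S^T S / ||S^T S||_F - I_C / sqrt(C) ||_F
    sts = matmul(st, s)
    norm_sts = frobenius(sts)
    diff = [[sts[i][j] / norm_sts - (1.0 if i == j else 0.0) / math.sqrt(c)
             for j in range(c)] for i in range(c)]
    ortho = frobenius(diff)

    # Cluster loss: sqrt(C) / n * || sum_i S_i ||_F - 1, where sum_i S_i are cluster sizes.
    sizes = [sum(s[i][k] for i in range(n)) for k in range(c)]
    cluster = math.sqrt(c) / n * math.sqrt(sum(v * v for v in sizes)) - 1.0

    return s, x_pool, adj_pool, spectral, ortho, cluster
```

For two disjoint edges (0-1 and 2-3) with logits that put each edge in its own cluster, the spectral loss comes out close to -0.5 (the negated modularity of that split), while the orthogonality and cluster losses are close to 0 because the clusters are disjoint and equally sized.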

Note

For an example of using DMoNPooling, see examples/proteins_dmon_pool.py.

Parameters
• channels (int or List[int]) – Size of each input sample. If given as a list, an MLP is constructed from the given feature sizes to compute the cluster assignment matrix $$\mathbf{S}$$.

• k (int) – The number of clusters.

• dropout (float, optional) – Dropout probability. (default: 0.0)
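One way to read the `channels` argument is as a list of layer widths. The hypothetical helper `assignment_mlp_shapes` below (not PyG code) lists the (input, output) size of each linear layer, assuming that an integer means a single input size and that a final layer of width `k` producing the cluster logits is appended to the given sizes. Activations and dropout are left out.

```python
from typing import List, Tuple, Union


def assignment_mlp_shapes(channels: Union[int, List[int]], k: int) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each linear layer (hypothetical helper)."""
    # An int is treated as a single input size, giving one linear layer F -> k.
    sizes = [channels] if isinstance(channels, int) else list(channels)
    # Assumption: an output layer of width k (the number of clusters) is appended.
    sizes = sizes + [k]
    return list(zip(sizes[:-1], sizes[1:]))
```

Under these assumptions, `channels=[32, 64]` with `k=4` yields the layers 32 → 64 and 64 → 4, while `channels=32` yields a single layer 32 → 4.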

reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None)

Parameters
• x (torch.Tensor) – Node feature tensor $$\mathbf{X} \in \mathbb{R}^{B \times N \times F}$$, with batch-size $$B$$, (maximum) number of nodes $$N$$ for each graph, and feature dimension $$F$$. Note that the cluster assignment matrix $$\mathbf{S} \in \mathbb{R}^{B \times N \times C}$$ is created within this method.
• adj (torch.Tensor) – Adjacency tensor $$\mathbf{A} \in \mathbb{R}^{B \times N \times N}$$.
• mask (torch.Tensor, optional) – Mask matrix $$\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}$$ indicating the valid nodes for each graph. (default: None)
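Because graphs in a batch are padded to the same N, the mask separates real nodes from padding. The sketch below shows one plausible treatment for a single graph, under the stated assumption (a reading of the description above, not a statement about PyG's code) that padded nodes have their rows in X and S zeroed, so they contribute nothing to X', A', or the cluster sizes, and that the node count n in the cluster loss counts only valid nodes. `apply_node_mask` is a hypothetical name.

```python
def apply_node_mask(x, s, mask=None):
    """Zero the feature and assignment rows of padded nodes (hypothetical helper).

    x: N x F features, s: N x C soft assignments, mask: N booleans or None.
    Returns masked copies of x and s plus the number of valid nodes.
    """
    if mask is None:
        # No mask given: every node is treated as valid.
        mask = [True] * len(x)
    x_masked = [[v if keep else 0.0 for v in row] for row, keep in zip(x, mask)]
    s_masked = [[v if keep else 0.0 for v in row] for row, keep in zip(s, mask)]
    n_valid = sum(1 for keep in mask if keep)
    return x_masked, s_masked, n_valid
```

For a three-node graph padded to N = 4 with mask `[True, True, True, False]`, the fourth rows become zero and the valid node count is 3.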