torch_geometric.nn.dense.DMoNPooling
- class DMoNPooling(channels: Union[int, List[int]], k: int, dropout: float = 0.0)
Bases: Module
The spectral modularity pooling operator from the “Graph Clustering with Graph Neural Networks” paper.
\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns the learned cluster assignment matrix, the pooled node feature matrix, the coarsened symmetrically normalized adjacency matrix, and three auxiliary objectives: (1) the spectral loss
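\[\mathcal{L}_s = - \frac{1}{2m} \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}\]where \(\mathbf{B} = \mathbf{A} - \frac{\mathbf{d}\mathbf{d}^{\top}}{2m}\) is the modularity matrix, with node degree vector \(\mathbf{d}\) and \(2m = \sum_i d_i\), (2) the orthogonality loss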
\[\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F\]where \(C\) is the number of clusters, and (3) the cluster loss
\[\mathcal{L}_c = \frac{\sqrt{C}}{n} {\left\|\sum_i\mathbf{S}_i^{\top}\right\|}_F - 1,\]where \(\mathbf{S}_i\) is the \(i\)-th row of \(\mathbf{S}\) and \(n\) is the number of nodes.
Note
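For an example of using DMoNPooling, see examples/proteins_dmon_pool.py.
- Parameters:
channels (int or List[int]) – Size of each input sample. If given as a list, will construct an MLP based on the given feature sizes.
k (int) – The number of clusters.
dropout (float, optional) – Dropout probability. (default: 0.0)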
- forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) → Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
Forward pass.
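The core of the forward pass, the pooling equations given in the class description, can be sketched for a single unbatched graph (B = 1) in plain Python. This is a minimal illustration, not the library's implementation. The assignment logits S are passed in directly instead of being produced by the module's learned layers. The helper names (softmax_rows, matmul, dense_pool) are hypothetical. Zeroing the assignment and feature rows of masked nodes is an assumption about how the mask is applied. Any further processing of the outputs, such as the symmetric normalization of the coarsened adjacency, is not modeled.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def softmax_rows(logits: Matrix) -> Matrix:
    """Row-wise softmax: turns each node's C logits into a distribution over clusters."""
    out = []
    for row in logits:
        top = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    b_cols = transpose(b)
    return [[sum(p * q for p, q in zip(row, col)) for col in b_cols] for row in a]


def dense_pool(
    x: Matrix, adj: Matrix, logits: Matrix, mask: Optional[List[bool]] = None
) -> Tuple[Matrix, Matrix, Matrix]:
    """Return (softmax(S), X', A') for one graph with N nodes and C clusters."""
    s = softmax_rows(logits)  # N x C soft assignment
    if mask is not None:
        # Assumption: padded nodes contribute nothing, so zero their assignment and feature rows.
        s = [row if keep else [0.0] * len(row) for row, keep in zip(s, mask)]
        x = [row if keep else [0.0] * len(row) for row, keep in zip(x, mask)]
    s_t = transpose(s)                        # C x N
    x_pooled = matmul(s_t, x)                 # X' = softmax(S)^T X, shape C x F
    adj_pooled = matmul(matmul(s_t, adj), s)  # A' = softmax(S)^T A softmax(S), shape C x C
    return s, x_pooled, adj_pooled


# Path graph 0-1-2 with 2-dimensional features, pooled into C = 2 clusters.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
adj = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
s, x_pooled, adj_pooled = dense_pool(x, adj, logits)
print(x_pooled)    # [[1.0, 1.0], [1.0, 1.0]]
print(adj_pooled)  # [[1.0, 1.0], [1.0, 1.0]]
```

With all-zero logits every node is split evenly between the two clusters. As a result, each row of X' is half the column sums of X. Passing mask=[True, True, False] drops the third node, so only the first two nodes contribute to X' and A'.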
- Parameters:
x (torch.Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\). Note that the cluster assignment matrix \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) is created within this method.
adj (torch.Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).
mask (torch.Tensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)
- Return type:
(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)
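The three auxiliary objectives, which are the last three entries of the returned tuple, can likewise be sketched for one graph. The sketch below is an illustration under stated assumptions, not the library's code:

* It takes a soft assignment matrix S whose rows already sum to 1.
* It assumes a symmetric adjacency matrix, so that the degree sum equals 2m.
* It uses the standard modularity matrix B = A - d d^T / (2m).
* It omits batching and masking.
* Its function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    b_cols = transpose(b)
    return [[sum(p * q for p, q in zip(row, col)) for col in b_cols] for row in a]


def frobenius(a: Matrix) -> float:
    return math.sqrt(sum(v * v for row in a for v in row))


def spectral_loss(adj: Matrix, s: Matrix) -> float:
    """L_s = -Tr(S^T B S) / (2m) with modularity matrix B = A - d d^T / (2m)."""
    n = len(adj)
    d = [sum(row) for row in adj]  # node degrees
    two_m = sum(d)                 # for a symmetric A, the degree sum equals 2m
    b = [[adj[i][j] - d[i] * d[j] / two_m for j in range(n)] for i in range(n)]
    stbs = matmul(matmul(transpose(s), b), s)  # C x C
    return -sum(stbs[c][c] for c in range(len(stbs))) / two_m


def orthogonality_loss(s: Matrix) -> float:
    """L_o = || S^T S / ||S^T S||_F - I_C / sqrt(C) ||_F."""
    sts = matmul(transpose(s), s)  # C x C
    c = len(sts)
    norm = frobenius(sts)
    diff = [
        [sts[i][j] / norm - (1.0 if i == j else 0.0) / math.sqrt(c) for j in range(c)]
        for i in range(c)
    ]
    return frobenius(diff)


def cluster_loss(s: Matrix) -> float:
    """L_c = sqrt(C) / n * || sum_i S_i ||_F - 1, where S_i is row i of S."""
    n, c = len(s), len(s[0])
    sizes = [sum(row[k] for row in s) for k in range(c)]  # soft cluster sizes
    return math.sqrt(c) / n * math.sqrt(sum(v * v for v in sizes)) - 1.0


# Two disjoint edges (0-1 and 2-3), either split into C = 2 clusters or collapsed into one.
adj = [
    [0.0, 1.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0],
]
split = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
collapsed = [[1.0, 0.0] for _ in range(4)]
for name, s in [("split", split), ("collapsed", collapsed)]:
    print(name, spectral_loss(adj, s), orthogonality_loss(s), cluster_loss(s))
```

For two disjoint edges split into two clusters, the spectral loss is -0.5, which corresponds to a modularity of 0.5. The orthogonality and cluster losses are 0 up to rounding. Collapsing every node into one cluster raises the spectral loss to 0 and the cluster loss to its maximum value, √C - 1.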