class MemPooling(in_channels: int, out_channels: int, heads: int, num_clusters: int, tau: float = 1.0)[source]

Bases: Module

The memory-based pooling layer from the “Memory-Based Graph Networks” paper, which learns a coarsened graph representation based on soft cluster assignments.

\[ \begin{align}\begin{aligned}S_{i,j}^{(h)} &= \frac{ (1+{\| \mathbf{x}_i-\mathbf{k}^{(h)}_j \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}{ \sum_{k=1}^K (1 + {\| \mathbf{x}_i-\mathbf{k}^{(h)}_k \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}\\\mathbf{S} &= \textrm{softmax}(\textrm{Conv2d} (\Vert_{h=1}^H \mathbf{S}^{(h)})) \in \mathbb{R}^{N \times K}\\\mathbf{X}^{\prime} &= \mathbf{S}^{\top} \mathbf{X} \mathbf{W} \in \mathbb{R}^{K \times F^{\prime}}\end{aligned}\end{align} \]

where \(H\) denotes the number of heads, and \(K\) denotes the number of clusters.

  • in_channels (int) – Size of each input sample \(F\).

  • out_channels (int) – Size of each output sample \(F^{\prime}\).

  • heads (int) – The number of heads \(H\).

  • num_clusters (int) – The number of clusters \(K\) per head.

  • tau (float, optional) – The temperature \(\tau\). (default: 1.0)
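
The first equation above can be checked by hand for a single head. The following is a minimal pure-Python sketch, not the library's implementation: `soft_assign_sketch` is a hypothetical helper, and the keys \(\mathbf{k}^{(h)}_j\) are passed in as plain lists rather than learned parameters.

```python
def soft_assign_sketch(x, keys, tau=1.0):
    """Return one row of S^(h): the soft assignment of node feature `x` to each key.

    Hypothetical helper for illustration only; `keys` stands in for the
    learnable cluster keys k_j^(h) of a single head.
    """
    scores = []
    for k in keys:
        # Squared Euclidean distance ||x - k||^2
        sq_dist = sum((xi - ki) ** 2 for xi, ki in zip(x, k))
        # Student's t-kernel: (1 + d / tau) ^ (-(1 + tau) / 2)
        scores.append((1.0 + sq_dist / tau) ** (-(1.0 + tau) / 2.0))
    total = sum(scores)
    # Normalize over the K clusters so the row sums to 1
    return [s / total for s in scores]


x = [0.0, 0.0]
keys = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]  # K = 3 clusters
row = soft_assign_sketch(x, keys, tau=1.0)
```

With \(\tau = 1\) the kernel reduces to \(1 / (1 + d)\), so the closest key receives the largest share of the assignment.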


reset_parameters()[source]

Resets all learnable parameters of the module.

static kl_loss(S: Tensor) → Tensor[source]

The additional KL divergence-based loss.

\[ \begin{align}\begin{aligned}P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^N S_{n,j}}{\sum_{k=1}^K S_{i,k}^2 / \sum_{n=1}^N S_{n,k}}\\\mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert \mathbf{S})\end{aligned}\end{align} \]
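
To make the two formulas above concrete, here is a minimal pure-Python sketch for a single graph. It is an illustration under stated assumptions, not the library's code: `kl_loss_sketch` is a hypothetical name, the divergence is summed over all entries (the reduction used by the actual method is not stated here), and every cluster is assumed to receive nonzero total assignment.

```python
import math


def kl_loss_sketch(S):
    """Compute KL(P || S) for an N x K soft assignment matrix given as nested lists.

    Hypothetical helper; assumes each column of S has a nonzero sum and sums
    the divergence over all entries (the reduction is an assumption).
    """
    N, K = len(S), len(S[0])
    # Per-cluster mass: sum_n S[n][j]
    col_sum = [sum(S[n][j] for n in range(N)) for j in range(K)]
    loss = 0.0
    for i in range(N):
        # Numerator of P[i][j]: S[i][j]^2 / sum_n S[n][j]
        unnorm = [S[i][j] ** 2 / col_sum[j] for j in range(K)]
        row_total = sum(unnorm)
        for j in range(K):
            p = unnorm[j] / row_total
            if p > 0.0:  # the term 0 * log(0) is taken as 0
                loss += p * math.log(p / S[i][j])
    return loss


uniform = [[0.5, 0.5], [0.5, 0.5]]
peaked = [[0.9, 0.1], [0.2, 0.8]]
loss_uniform = kl_loss_sketch(uniform)
loss_peaked = kl_loss_sketch(peaked)
```

For the uniform matrix the target \(\mathbf{P}\) equals \(\mathbf{S}\), so the loss vanishes; for the peaked matrix \(\mathbf{P}\) is sharper than \(\mathbf{S}\) and the loss is positive.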

forward(x: Tensor, batch: Optional[Tensor] = None, mask: Optional[Tensor] = None, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None) → Tuple[Tensor, Tensor][source]

Forward pass.

  • x (torch.Tensor) – The node feature tensor of shape \(\mathbf{X} \in \mathbb{R}^{N \times F}\) or \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\).

  • batch (torch.Tensor, optional) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. Should not be provided in case node features already have shape \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\). (default: None)

  • mask (torch.Tensor, optional) – A mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\), which indicates valid nodes for each graph when using node features of shape \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\). (default: None)

  • max_num_nodes (int, optional) – The size of the \(N\) node dimension in the dense \(B \times N \times F\) representation. Automatically calculated if not given. (default: None)

  • batch_size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default: None)
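
The relationship between the batch vector, the mask, and the two size arguments described above can be illustrated with a small pure-Python sketch. This is not the library's implementation: `to_dense_sketch` is a hypothetical helper that operates on nested lists, and dropping nodes beyond `max_num_nodes` is a simplification made for this sketch.

```python
def to_dense_sketch(x, batch, max_num_nodes=None, batch_size=None):
    """Convert flat node features [N, F] plus a batch vector into a dense
    [B, N_max, F] nested list and a [B, N_max] boolean mask.

    Hypothetical helper for illustration; inputs are plain Python lists.
    """
    if batch_size is None:
        batch_size = max(batch) + 1  # B inferred from the largest graph index
    counts = [0] * batch_size
    for b in batch:
        counts[b] += 1
    if max_num_nodes is None:
        max_num_nodes = max(counts)  # node dimension inferred from the largest graph
    num_features = len(x[0])
    dense = [[[0.0] * num_features for _ in range(max_num_nodes)]
             for _ in range(batch_size)]
    mask = [[False] * max_num_nodes for _ in range(batch_size)]
    slot = [0] * batch_size  # next free position within each graph
    for feat, b in zip(x, batch):
        if slot[b] < max_num_nodes:  # sketch-only choice: extra nodes are dropped
            dense[b][slot[b]] = list(feat)
            mask[b][slot[b]] = True
        slot[b] += 1
    return dense, mask


x = [[1.0], [2.0], [3.0]]  # three nodes with F = 1
batch = [0, 0, 1]          # nodes 0 and 1 belong to graph 0, node 2 to graph 1
dense, mask = to_dense_sketch(x, batch)
```

Here graph 1 has only one node, so its second slot is zero-padded and marked invalid in the mask.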