# torch_geometric.nn.pool.MemPooling

class MemPooling(in_channels: int, out_channels: int, heads: int, num_clusters: int, tau: float = 1.0)

Bases: Module

Memory-based pooling layer from the “Memory-Based Graph Networks” paper, which learns a coarsened graph representation based on soft cluster assignments.

$$
\begin{aligned}
S_{i,j}^{(h)} &= \frac{\left(1 + \|\mathbf{x}_i - \mathbf{k}^{(h)}_j\|^2 / \tau\right)^{-\frac{1+\tau}{2}}}{\sum_{k=1}^{K} \left(1 + \|\mathbf{x}_i - \mathbf{k}^{(h)}_k\|^2 / \tau\right)^{-\frac{1+\tau}{2}}} \\
\mathbf{S} &= \textrm{softmax}\left(\textrm{Conv2d}\left(\Vert_{h=1}^{H} \mathbf{S}^{(h)}\right)\right) \in \mathbb{R}^{N \times K} \\
\mathbf{X}^{\prime} &= \mathbf{S}^{\top} \mathbf{X} \mathbf{W} \in \mathbb{R}^{K \times F^{\prime}}
\end{aligned}
$$

where $$H$$ denotes the number of heads, and $$K$$ denotes the number of clusters.
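
The three equations can be traced step by step in plain PyTorch. The following is a minimal sketch for a dense batch of B graphs with N nodes each, not the layer's actual implementation: it assumes randomly initialised keys $$\mathbf{k}^{(h)}_j$$, a 1×1 `Conv2d` without bias as the head-mixing step, and a bias-free `Linear` layer standing in for $$\mathbf{W}$$. The names `keys`, `conv` and `lin` are illustrative only.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: B graphs with N nodes each, F input and F_out output
# features, H heads, K clusters per head, temperature tau.
B, N, F, F_out, H, K, tau = 2, 5, 4, 3, 2, 3, 1.0

x = torch.randn(B, N, F)                                  # dense node features X
keys = torch.randn(H, K, F)                               # keys k_j^(h), random here
conv = torch.nn.Conv2d(H, 1, kernel_size=1, bias=False)   # assumed head-mixing step
lin = torch.nn.Linear(F, F_out, bias=False)               # weight W

# Squared distances ||x_i - k_j^(h)||^2 via broadcasting: [B, H, N, K].
diff = x[:, None, :, None, :] - keys[None, :, None, :, :]
sq_dist = diff.pow(2).sum(dim=-1)

# Student's t kernel, normalised over the K clusters of each head: S^(h).
kernel = (1.0 + sq_dist / tau).pow(-(1.0 + tau) / 2.0)
S_heads = kernel / kernel.sum(dim=-1, keepdim=True)

# Heads act as input channels; the 1x1 conv maps [B, H, N, K] -> [B, 1, N, K],
# then a softmax over clusters gives S with shape [B, N, K].
S = conv(S_heads).squeeze(1).softmax(dim=-1)

# Pooled features X' = S^T X W with shape [B, K, F_out].
x_pooled = lin(S.transpose(1, 2) @ x)
print(S.shape, x_pooled.shape)
```

Treating the H heads as input channels is what lets a single `Conv2d` with one output channel collapse $$\Vert_{h=1}^H \mathbf{S}^{(h)}$$ into one N×K matrix per graph.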

Parameters:
• in_channels (int) – Size of each input sample $$F$$.

• out_channels (int) – Size of each output sample $$F^{\prime}$$.

• heads (int) – The number of heads $$H$$.

• num_clusters (int) – The number of clusters $$K$$ per head.

• tau (float, optional) – The temperature $$\tau$$. (default: 1.0)
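
The role of `tau` is easiest to see on the unnormalised kernel in the first equation, which has the form of a Student's t density with τ degrees of freedom. The sketch below evaluates it for a few hypothetical squared distances from one node to the K = 3 keys of a single head; `t_kernel` is an illustrative helper, not part of the library.

```python
import torch


def t_kernel(sq_dist: torch.Tensor, tau: float) -> torch.Tensor:
    """Unnormalised kernel (1 + d^2 / tau) ** (-(1 + tau) / 2) from the S^(h) equation."""
    return (1.0 + sq_dist / tau).pow(-(1.0 + tau) / 2.0)


# Hypothetical squared distances from one node to K = 3 keys of one head.
sq_dist = torch.tensor([0.5, 2.0, 8.0])

probs = {}
for tau in (0.5, 1.0, 10.0):
    k = t_kernel(sq_dist, tau)
    probs[tau] = k / k.sum()  # normalise over the clusters, as in S^(h)
    print(f"tau={tau}: {probs[tau].tolist()}")
```

For τ = 1 the kernel reduces to 1 / (1 + d²), giving probabilities 0.6, 0.3 and 0.1 here. For these particular distances the largest probability grows with τ; as τ → ∞ the kernel approaches the Gaussian-shaped exp(-d²/2), whose lighter tails tend to concentrate mass on the nearest key.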

reset_parameters()

Resets all learnable parameters of the module.

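static kl_loss(S: Tensor) → Tensor

The additional KL divergence-based loss, which pulls the soft assignment matrix $$\mathbf{S}$$ towards the sharpened target distribution $$\mathbf{P}$$: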

$$
\begin{aligned}
P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^{N} S_{n,j}}{\sum_{k=1}^{K} S_{i,k}^2 / \sum_{n=1}^{N} S_{n,k}} \\
\mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert \mathbf{S})
\end{aligned}
$$

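
A plain-PyTorch sketch of this loss for a dense assignment tensor of shape B×N×K follows. `kl_loss_sketch` is a hypothetical helper; the `batchmean` reduction, the clamping constant `eps`, and the handling of all-zero (padded) rows are assumptions rather than documented behaviour.

```python
import torch
import torch.nn.functional as F


def kl_loss_sketch(S: torch.Tensor, eps: float = 1e-15) -> torch.Tensor:
    """KL(P || S) for a dense soft assignment tensor S of shape [B, N, K]."""
    # Numerator S_ij^2 / sum_n S_nj: the column sum runs over the N nodes.
    P = S.pow(2) / S.sum(dim=1, keepdim=True)
    # Denominator: renormalise each row over the K clusters. Rows of padded
    # nodes are all zero, so their denominator is set to 1 to avoid 0 / 0.
    denom = P.sum(dim=2, keepdim=True)
    denom = denom.masked_fill(S.sum(dim=2, keepdim=True) == 0, 1.0)
    P = P / denom
    # F.kl_div takes log-probabilities as input and probabilities as target and
    # sums P * (log P - log S); "batchmean" (assumed) divides the sum by B.
    return F.kl_div(S.clamp(min=eps).log(), P.clamp(min=eps), reduction="batchmean")


torch.manual_seed(0)
S = torch.rand(2, 6, 3).softmax(dim=-1)  # hypothetical soft assignments
print(kl_loss_sketch(S).item())
```

As a sanity check, a uniform S with every entry equal to 1/K gives P = S and therefore a loss of zero.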
forward(x: Tensor, batch: Optional[Tensor] = None, mask: Optional[Tensor] = None, max_num_nodes: Optional[int] = None, batch_size: Optional[int] = None) → Tuple[Tensor, Tensor]

Forward pass.

Parameters:
• x (torch.Tensor) – The node feature tensor of shape $$\mathbf{X} \in \mathbb{R}^{N \times F}$$ or $$\mathbf{X} \in \mathbb{R}^{B \times N \times F}$$.

• batch (torch.Tensor, optional) – The batch vector $$\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N$$, which assigns each node to a specific example. Should not be provided in case node features already have shape $$\mathbf{X} \in \mathbb{R}^{B \times N \times F}$$. (default: None)

• mask (torch.Tensor, optional) – A mask matrix $$\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}$$, which indicates valid nodes for each graph when using node features of shape $$\mathbf{X} \in \mathbb{R}^{B \times N \times F}$$. (default: None)

• max_num_nodes (int, optional) – The size of the node dimension $$N$$ of the dense representation. Automatically calculated if not given. (default: None)

• batch_size (int, optional) – The number of examples $$B$$. Automatically calculated if not given. (default: None)
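
When `x` is passed as a flat N×F matrix together with `batch`, the nodes have to be packed into the dense B×N×F layout, with a mask marking the real (non-padded) nodes, before the assignment can be computed. PyG provides `torch_geometric.utils.to_dense_batch` for this; the sketch below re-implements the idea in plain PyTorch to illustrate one way `max_num_nodes` and `batch_size` can be inferred when omitted. `to_dense_sketch` is a hypothetical helper, and it assumes `batch` is sorted so that each graph's nodes are contiguous.

```python
import torch


def to_dense_sketch(x, batch, max_num_nodes=None, batch_size=None):
    """Pack flat features [N_total, F] into [B, N_max, F] plus a boolean mask [B, N_max].

    Hypothetical helper; assumes `batch` is sorted so each graph's nodes are contiguous.
    """
    if batch_size is None:
        batch_size = int(batch.max()) + 1
    counts = torch.bincount(batch, minlength=batch_size)   # nodes per graph
    if max_num_nodes is None:
        max_num_nodes = int(counts.max())
    offsets = torch.cumsum(counts, dim=0) - counts          # index of each graph's first node
    pos = torch.arange(batch.numel()) - offsets[batch]       # position inside its own graph

    keep = pos < max_num_nodes  # this sketch drops nodes beyond max_num_nodes
    out = x.new_zeros(batch_size, max_num_nodes, x.size(-1))
    mask = torch.zeros(batch_size, max_num_nodes, dtype=torch.bool)
    out[batch[keep], pos[keep]] = x[keep]
    mask[batch[keep], pos[keep]] = True
    return out, mask


x = torch.arange(10.0).view(5, 2)        # 5 nodes with 2 features each
batch = torch.tensor([0, 0, 0, 1, 1])    # graph 0 has 3 nodes, graph 1 has 2
dense, mask = to_dense_sketch(x, batch)
print(dense.shape)     # torch.Size([2, 3, 2])
print(mask.tolist())   # [[True, True, True], [True, True, False]]
```

The returned `mask` corresponds to $$\mathbf{M}$$ in the `mask` parameter above: padded positions hold zero features and are marked `False`.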