# Source code for torch_geometric.nn.pool.max_pool

from torch_scatter import scatter
from torch_geometric.data import Batch

from .consecutive import consecutive_cluster
from .pool import pool_edge, pool_batch, pool_pos


def _max_pool_x(cluster, x, size=None):
    """Max-reduce the entries of ``x`` that share a cluster index.

    Thin wrapper around ``scatter`` called with ``reduce='max'`` along
    dimension 0.

    Args:
        cluster: Index vector passed to ``scatter`` as the index argument.
        x: Values to reduce.
        size (optional): Forwarded to ``scatter`` as ``dim_size``.
            (default: ``None``)

    Returns:
        Whatever ``scatter`` returns for the given arguments.
    """
    scatter_kwargs = dict(dim=0, dim_size=size, reduce='max')
    return scatter(x, cluster, **scatter_kwargs)


def max_pool_x(cluster, x, batch, size=None):
    r"""Max-pools node features according to the clustering given in
    :attr:`cluster`.

    Args:
        cluster (LongTensor): Cluster vector
            :math:`\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N`, assigning each
            node to a specific cluster.
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, assigning each
            node to a specific example.
        size (int, optional): The maximum number of clusters in a single
            example. Useful to obtain a batch-wise dense representation,
            *e.g.* for applying FC layers, but should only be used if the
            maximum number of clusters per example is known in advance.
            (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`LongTensor`) if :attr:`size` is
        :obj:`None`, else :class:`Tensor`
    """
    if size is None:
        # NOTE(review): consecutive_cluster presumably relabels the cluster
        # ids and returns a permutation that pool_batch uses to pick one
        # batch entry per cluster -- confirm against the helper.
        compact_cluster, perm = consecutive_cluster(cluster)
        pooled_x = _max_pool_x(compact_cluster, x)
        pooled_batch = pool_batch(perm, batch)
        return pooled_x, pooled_batch

    # Fixed-size path: request ``size`` output slots for each of the
    # ``batch.max() + 1`` examples; only the pooled features are returned.
    num_examples = batch.max().item() + 1
    return _max_pool_x(cluster, x, num_examples * size)
def max_pool(cluster, data, transform=None):
    r"""Pools and coarsens a graph given by the
    :class:`torch_geometric.data.Data` object according to the clustering
    given in :attr:`cluster`.

    All nodes within the same cluster are represented as one node. Final
    node features are the *maximum* features of all nodes within the same
    cluster, node positions are averaged, and edge indices are the union of
    the edge indices of all nodes within the same cluster.

    Args:
        cluster (LongTensor): Cluster vector
            :math:`\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N`, assigning each
            node to a specific cluster.
        data (Data): Graph data object.
        transform (callable, optional): A function/transform that takes in
            the coarsened and pooled :obj:`torch_geometric.data.Data` object
            and returns a transformed version. (default: :obj:`None`)

    :rtype: :class:`torch_geometric.data.Data`
    """
    compact_cluster, perm = consecutive_cluster(cluster)

    pooled_x = _max_pool_x(compact_cluster, data.x)
    edge_index, edge_attr = pool_edge(
        compact_cluster, data.edge_index, data.edge_attr)

    # ``batch`` and ``pos`` are optional on the input; they stay ``None`` on
    # the output when absent.
    pooled_batch = None
    if data.batch is not None:
        pooled_batch = pool_batch(perm, data.batch)

    pooled_pos = None
    if data.pos is not None:
        pooled_pos = pool_pos(compact_cluster, data.pos)

    coarsened = Batch(
        batch=pooled_batch,
        x=pooled_x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        pos=pooled_pos,
    )

    if transform is None:
        return coarsened
    return transform(coarsened)