# Source code for torch_geometric.nn.pool.max_pool

from typing import Optional

from torch_scatter import scatter
from torch_geometric.data import Batch
from torch_geometric.utils import add_self_loops

from .consecutive import consecutive_cluster
from .pool import pool_edge, pool_batch, pool_pos


def _max_pool_x(cluster, x, size: Optional[int] = None):
    """Max-reduce the rows of ``x`` that share the same ``cluster`` index.

    Delegates to ``scatter`` along dimension 0 with ``reduce='max'``.

    Args:
        cluster: Index vector handed to ``scatter`` as the scatter index.
        x: Values to be reduced.
        size (int, optional): Forwarded to ``scatter`` as ``dim_size``.
            (default: :obj:`None`)

    Returns:
        Whatever ``scatter`` returns for the max-reduction.
    """
    pooled = scatter(x, cluster, dim=0, dim_size=size, reduce='max')
    return pooled


def max_pool_x(cluster, x, batch, size: Optional[int] = None):
    r"""Max-pools node features according to the clustering defined in
    :attr:`cluster`.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific
            cluster.
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,
            B-1\}}^N`, which assigns each node to a specific example.
        size (int, optional): The maximum number of clusters in a single
            example. This property is useful to obtain a batch-wise dense
            representation, *e.g.* for applying FC layers, but should only
            be used if the size of the maximum number of clusters per
            example is known in advance. (default: :obj:`None`)

    Returns:
        A 2-tuple in both modes.

        * :attr:`size` is :obj:`None`: ``(x, batch)``, where :attr:`cluster`
          is first passed through ``consecutive_cluster``, ``x`` is the
          max-pooled feature matrix and ``batch`` is the result of
          ``pool_batch`` on the returned permutation.
        * :attr:`size` is given: ``(x, None)``, where the features are pooled
          into ``batch_size * size`` output rows and
          ``batch_size = batch.max() + 1``.

    :rtype: (:class:`Tensor`, :class:`LongTensor`) if :attr:`size` is
        :obj:`None`, else (:class:`Tensor`, :obj:`None`)
    """
    if size is None:
        cluster, perm = consecutive_cluster(cluster)
        pooled = _max_pool_x(cluster, x)
        return pooled, pool_batch(perm, batch)

    # Dense mode: reserve `size` output slots for every example in the batch.
    num_examples = int(batch.max().item()) + 1
    return _max_pool_x(cluster, x, num_examples * size), None


def max_pool(cluster, data, transform=None):
    r"""Pools and coarsens a graph given by the
    :class:`torch_geometric.data.Data` object according to the clustering
    defined in :attr:`cluster`.

    All nodes within the same cluster will be represented as one node.
    Final node features are defined by the *maximum* features of all nodes
    within the same cluster, node positions are averaged and edge indices
    are defined to be the union of the edge indices of all nodes within the
    same cluster.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific
            cluster.
        data (Data): Graph data object. Its ``x``, ``batch`` and ``pos``
            attributes may each be :obj:`None`, in which case the matching
            attribute of the result is :obj:`None` as well.
        transform (callable, optional): A function/transform that takes in
            the coarsened and pooled :obj:`torch_geometric.data.Data` object
            and returns a transformed version. (default: :obj:`None`)

    Returns:
        A newly built ``Batch`` holding the pooled attributes, or the value
        returned by :attr:`transform` when one is supplied.

    :rtype: :class:`torch_geometric.data.Data`
    """
    cluster, perm = consecutive_cluster(cluster)

    x = data.x
    if x is not None:
        x = _max_pool_x(cluster, x)

    edge_index, edge_attr = pool_edge(cluster, data.edge_index,
                                      data.edge_attr)

    batch = data.batch
    if batch is not None:
        batch = pool_batch(perm, batch)

    pos = data.pos
    if pos is not None:
        pos = pool_pos(cluster, pos)

    pooled = Batch(batch=batch, x=x, edge_index=edge_index,
                   edge_attr=edge_attr, pos=pos)

    if transform is None:
        return pooled
    return transform(pooled)


def max_pool_neighbor_x(data, flow='source_to_target'):
    r"""Max pools neighboring node features, where each feature in
    :obj:`data.x` is replaced by the feature value with the maximum value
    from the central node and its neighbors.

    Args:
        data (Data): Graph data object providing ``x``, ``edge_index`` and
            ``num_nodes``. It is modified in place: ``data.x`` is
            overwritten with the pooled features.
        flow (str, optional): With :obj:`"source_to_target"`, features are
            gathered at the first row of ``edge_index`` and reduced onto the
            second row. Any other value swaps the two roles.
            (default: :obj:`"source_to_target"`)

    Returns:
        The same :attr:`data` object that was passed in.
    """
    num_nodes = data.num_nodes
    features = data.x

    # Self loops make every node part of its own neighborhood.
    looped_index, _ = add_self_loops(data.edge_index, num_nodes=num_nodes)

    gather_idx, reduce_idx = looped_index
    if flow != 'source_to_target':
        gather_idx, reduce_idx = reduce_idx, gather_idx

    data.x = scatter(features[gather_idx], reduce_idx, dim=0,
                     dim_size=num_nodes, reduce='max')
    return data