Source code for torch_geometric.loader.cluster

import copy
import os
import os.path as osp
import sys
from dataclasses import dataclass
from typing import List, Literal, Optional

import torch
import torch.utils.data
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.index import index2ptr, ptr2index
from torch_geometric.io import fs
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort, narrow, select, sort_edge_index
from torch_geometric.utils.map import map_index


@dataclass
class Partition:
    indptr: Tensor
    index: Tensor
    partptr: Tensor
    node_perm: Tensor
    edge_perm: Tensor
    sparse_format: Literal['csr', 'csc']

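# Illustrative sketch (not part of the original module): for a 3-node graph
# with edges (0, 1), (0, 2) and (2, 1) stored in CSR format, and with
# partitions {0, 1} and {2}, the first three fields would look like:
#
#     indptr  = torch.tensor([0, 2, 2, 3])  # Per-node edge offsets.
#     index   = torch.tensor([1, 2, 1])     # Destination of each edge.
#     partptr = torch.tensor([0, 2, 3])     # Per-partition node offsets.
#
# `node_perm` and `edge_perm` record how nodes and edges were re-ordered,
# so that node-level and edge-level attributes can be permuted accordingly.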

class ClusterData(torch.utils.data.Dataset):
    r"""Clusters/partitions a graph data object into multiple subgraphs, as
    motivated by the `"Cluster-GCN: An Efficient Algorithm for Training Deep
    and Large Graph Convolutional Networks"
    <https://arxiv.org/abs/1905.07953>`_ paper.

    .. note::
        The underlying METIS algorithm requires undirected graphs as input.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        num_parts (int): The number of partitions.
        recursive (bool, optional): If set to :obj:`True`, will use multilevel
            recursive bisection instead of multilevel k-way partitioning.
            (default: :obj:`False`)
        save_dir (str, optional): If set, will save the partitioned data to
            the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        filename (str, optional): Name of the stored partitioned file.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            progress. (default: :obj:`True`)
        keep_inter_cluster_edges (bool, optional): If set to :obj:`True`,
            will keep inter-cluster edge connections. (default: :obj:`False`)
        sparse_format (str, optional): The sparse format to use for computing
            partitions. (default: :obj:`"csr"`)
    """
    def __init__(
        self,
        data,
        num_parts: int,
        recursive: bool = False,
        save_dir: Optional[str] = None,
        filename: Optional[str] = None,
        log: bool = True,
        keep_inter_cluster_edges: bool = False,
        sparse_format: Literal['csr', 'csc'] = 'csr',
    ):
        assert data.edge_index is not None
        assert sparse_format in ['csr', 'csc']

        self.num_parts = num_parts
        self.recursive = recursive
        self.keep_inter_cluster_edges = keep_inter_cluster_edges
        self.sparse_format = sparse_format

        recursive_str = '_recursive' if recursive else ''
        root_dir = osp.join(save_dir or '',
                            f'part_{num_parts}{recursive_str}')
        path = osp.join(root_dir, filename or 'metis.pt')

        if save_dir is not None and osp.exists(path):
            self.partition = fs.torch_load(path)
        else:
            if log:  # pragma: no cover
                print('Computing METIS partitioning...', file=sys.stderr)

            cluster = self._metis(data.edge_index, data.num_nodes)
            self.partition = self._partition(data.edge_index, cluster)

            if save_dir is not None:
                os.makedirs(root_dir, exist_ok=True)
                torch.save(self.partition, path)

            if log:  # pragma: no cover
                print('Done!', file=sys.stderr)

        self.data = self._permute_data(data, self.partition)

    def _metis(self, edge_index: Tensor, num_nodes: int) -> Tensor:
        # Computes a node-level partition assignment vector via METIS.
        if self.sparse_format == 'csr':
            # Calculate CSR representation:
            row, index = sort_edge_index(edge_index, num_nodes=num_nodes)
            indptr = index2ptr(row, size=num_nodes)
        else:
            # Calculate CSC representation:
            index, col = sort_edge_index(edge_index, num_nodes=num_nodes,
                                         sort_by_row=False)
            indptr = index2ptr(col, size=num_nodes)

        # Compute METIS partitioning:
        cluster: Optional[Tensor] = None

        if torch_geometric.typing.WITH_TORCH_SPARSE:
            try:
                cluster = torch.ops.torch_sparse.partition(
                    indptr.cpu(),
                    index.cpu(),
                    None,
                    self.num_parts,
                    self.recursive,
                ).to(edge_index.device)
            except (AttributeError, RuntimeError):
                pass

        if cluster is None and torch_geometric.typing.WITH_METIS:
            cluster = pyg_lib.partition.metis(
                indptr.cpu(),
                index.cpu(),
                self.num_parts,
                recursive=self.recursive,
            ).to(edge_index.device)

        if cluster is None:
            raise ImportError(f"'{self.__class__.__name__}' requires either "
                              f"'pyg-lib' or 'torch-sparse'")

        return cluster

    def _partition(self, edge_index: Tensor, cluster: Tensor) -> Partition:
        # Computes node-level and edge-level permutations and permutes the
        # edge connectivity accordingly:

        # Sort `cluster` and compute boundaries `partptr`:
        cluster, node_perm = index_sort(cluster, max_value=self.num_parts)
        partptr = index2ptr(cluster, size=self.num_parts)

        # Permute `edge_index` based on node permutation:
        edge_perm = torch.arange(edge_index.size(1), device=edge_index.device)
        arange = torch.empty_like(node_perm)
        arange[node_perm] = torch.arange(cluster.numel(),
                                         device=cluster.device)
        edge_index = arange[edge_index]

        # Compute final CSR representation:
        (row, col), edge_perm = sort_edge_index(
            edge_index,
            edge_attr=edge_perm,
            num_nodes=cluster.numel(),
            sort_by_row=self.sparse_format == 'csr',
        )

        if self.sparse_format == 'csr':
            indptr, index = index2ptr(row, size=cluster.numel()), col
        else:
            indptr, index = index2ptr(col, size=cluster.numel()), row

        return Partition(indptr, index, partptr, node_perm, edge_perm,
                         self.sparse_format)

    def _permute_data(self, data: Data, partition: Partition) -> Data:
        # Permute node-level and edge-level attributes according to the
        # calculated permutations in `Partition`:
        out = copy.copy(data)
        for key, value in data.items():
            if key == 'edge_index':
                continue
            elif data.is_node_attr(key):
                cat_dim = data.__cat_dim__(key, value)
                out[key] = select(value, partition.node_perm, dim=cat_dim)
            elif data.is_edge_attr(key):
                cat_dim = data.__cat_dim__(key, value)
                out[key] = select(value, partition.edge_perm, dim=cat_dim)

        out.edge_index = None

        return out

    def __len__(self) -> int:
        return self.partition.partptr.numel() - 1

    def __getitem__(self, idx: int) -> Data:
        node_start = int(self.partition.partptr[idx])
        node_end = int(self.partition.partptr[idx + 1])
        node_length = node_end - node_start

        indptr = self.partition.indptr[node_start:node_end + 1]
        edge_start = int(indptr[0])
        edge_end = int(indptr[-1])
        edge_length = edge_end - edge_start
        indptr = indptr - edge_start

        if self.sparse_format == 'csr':
            row = ptr2index(indptr)
            col = self.partition.index[edge_start:edge_end]
            if not self.keep_inter_cluster_edges:
                edge_mask = (col >= node_start) & (col < node_end)
                row = row[edge_mask]
                col = col[edge_mask] - node_start
        else:
            col = ptr2index(indptr)
            row = self.partition.index[edge_start:edge_end]
            if not self.keep_inter_cluster_edges:
                edge_mask = (row >= node_start) & (row < node_end)
                col = col[edge_mask]
                row = row[edge_mask] - node_start

        out = copy.copy(self.data)

        for key, value in self.data.items():
            if key == 'num_nodes':
                out.num_nodes = node_length
            elif self.data.is_node_attr(key):
                cat_dim = self.data.__cat_dim__(key, value)
                out[key] = narrow(value, cat_dim, node_start, node_length)
            elif self.data.is_edge_attr(key):
                cat_dim = self.data.__cat_dim__(key, value)
                out[key] = narrow(value, cat_dim, edge_start, edge_length)
                if not self.keep_inter_cluster_edges:
                    out[key] = out[key][edge_mask]

        out.edge_index = torch.stack([row, col], dim=0)

        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.num_parts})'
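
# A minimal usage sketch (not part of the original module), assuming that
# `pyg-lib` or `torch-sparse` is installed for METIS support; the toy graph
# below is hypothetical:
#
#     import torch
#     from torch_geometric.data import Data
#     from torch_geometric.loader import ClusterData
#
#     # An undirected 4-cycle (METIS requires undirected input):
#     edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
#                                [1, 0, 2, 1, 3, 2, 0, 3]])
#     data = Data(edge_index=edge_index, num_nodes=4)
#
#     cluster_data = ClusterData(data, num_parts=2)
#     assert len(cluster_data) == 2
#     sub_data = cluster_data[0]  # First partition as its own `Data` object.
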
class ClusterLoader(torch.utils.data.DataLoader):
    r"""The data loader scheme from the `"Cluster-GCN: An Efficient
    Algorithm for Training Deep and Large Graph Convolutional Networks"
    <https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned
    subgraphs and their between-cluster links from a large-scale graph data
    object to form a mini-batch.

    .. note::
        Use :class:`~torch_geometric.loader.ClusterData` and
        :class:`~torch_geometric.loader.ClusterLoader` in conjunction to
        form mini-batches of clusters.
        For an example of using Cluster-GCN, see
        `examples/cluster_gcn_reddit.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/cluster_gcn_reddit.py>`_ or
        `examples/cluster_gcn_ppi.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/cluster_gcn_ppi.py>`_.

    Args:
        cluster_data (torch_geometric.loader.ClusterData): The already
            partitioned data object.
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(self, cluster_data, **kwargs):
        self.cluster_data = cluster_data
        iterator = range(len(cluster_data))
        super().__init__(iterator, collate_fn=self._collate, **kwargs)

    def _collate(self, batch: List[int]) -> Data:
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)

        global_indptr = self.cluster_data.partition.indptr
        global_index = self.cluster_data.partition.index

        # Get all node-level and edge-level start and end indices for the
        # current mini-batch:
        node_start = self.cluster_data.partition.partptr[batch]
        node_end = self.cluster_data.partition.partptr[batch + 1]
        edge_start = global_indptr[node_start]
        edge_end = global_indptr[node_end]

        # Iterate over each partition in the batch and calculate new edge
        # connectivity. This is done by slicing the corresponding source and
        # destination indices for each partition and adjusting their indices
        # to start from zero:
        rows, cols, nodes, cumsum = [], [], [], 0
        for i in range(batch.numel()):
            nodes.append(torch.arange(node_start[i], node_end[i]))
            indptr = global_indptr[node_start[i]:node_end[i] + 1]
            indptr = indptr - edge_start[i]
            if self.cluster_data.partition.sparse_format == 'csr':
                row = ptr2index(indptr) + cumsum
                col = global_index[edge_start[i]:edge_end[i]]
            else:
                col = ptr2index(indptr) + cumsum
                row = global_index[edge_start[i]:edge_end[i]]
            rows.append(row)
            cols.append(col)
            cumsum += indptr.numel() - 1

        node = torch.cat(nodes, dim=0)
        row = torch.cat(rows, dim=0)
        col = torch.cat(cols, dim=0)

        # Map the global indices to local ones and remove any entries that
        # do not connect two nodes within the same mini-batch:
        if self.cluster_data.partition.sparse_format == 'csr':
            col, edge_mask = map_index(col, node)
            row = row[edge_mask]
        else:
            row, edge_mask = map_index(row, node)
            col = col[edge_mask]

        out = copy.copy(self.cluster_data.data)

        # Slice node-level and edge-level attributes according to their
        # offsets:
        for key, value in self.cluster_data.data.items():
            if key == 'num_nodes':
                out.num_nodes = cumsum
            elif self.cluster_data.data.is_node_attr(key):
                cat_dim = self.cluster_data.data.__cat_dim__(key, value)
                out[key] = torch.cat([
                    narrow(out[key], cat_dim, s, e - s)
                    for s, e in zip(node_start, node_end)
                ], dim=cat_dim)
            elif self.cluster_data.data.is_edge_attr(key):
                cat_dim = self.cluster_data.data.__cat_dim__(key, value)
                value = torch.cat([
                    narrow(out[key], cat_dim, s, e - s)
                    for s, e in zip(edge_start, edge_end)
                ], dim=cat_dim)
                out[key] = select(value, edge_mask, dim=cat_dim)

        out.edge_index = torch.stack([row, col], dim=0)

        return out
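
# A minimal mini-batching sketch (not part of the original module), reusing
# the hypothetical `cluster_data` object from the `ClusterData` example
# above; `batch_size` controls how many clusters are merged per mini-batch:
#
#     from torch_geometric.loader import ClusterLoader
#
#     loader = ClusterLoader(cluster_data, batch_size=2, shuffle=True)
#     for batch in loader:
#         # Each `batch` is a `Data` object holding the merged clusters,
#         # including the edges that run between them:
#         print(batch.num_nodes, batch.edge_index.shape)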