Source code for torch_geometric.transforms.largest_connected_components

import torch

from import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_scipy_sparse_matrix

class LargestConnectedComponents(BaseTransform):
    r"""Selects the subgraph that corresponds to the largest connected
    components in the graph.

    Connectivity is computed with SciPy's ``connected_components`` on the
    sparse adjacency matrix built from :obj:`data.edge_index`, using SciPy's
    default ``directed``/``connection`` settings.

    Args:
        num_components (int, optional): Number of largest components to keep
            (default: :obj:`1`)
    """
    def __init__(self, num_components: int = 1):
        self.num_components = num_components

    def __call__(self, data: Data) -> Data:
        """Return :obj:`data` restricted to its largest connected components.

        Args:
            data (Data): The graph to filter. Its ``edge_index`` and
                ``num_nodes`` are read to build the adjacency matrix.

        Returns:
            Data: :obj:`data` itself (unchanged) when the graph has no more
            components than :obj:`num_components`; otherwise the result of
            ``data.subgraph`` on a boolean node mask selecting the nodes of
            the kept components.
        """
        # Imported lazily so NumPy/SciPy are only required when the transform
        # is actually applied.
        import numpy as np
        from scipy.sparse.csgraph import connected_components

        adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)
        # `component[i]` is the component label assigned to node `i`.
        num_found, component = connected_components(adj)

        # Nothing to drop: every existing component would be kept anyway.
        if num_found <= self.num_components:
            return data

        # `np.unique` returns the labels in sorted order, so `count[k]` is the
        # size of the k-th smallest label (SciPy documents the labels as
        # 0..num_found-1, which makes `k` the label itself).
        _, count = np.unique(component, return_counts=True)

        # Labels of the `num_components` most populated components (argsort is
        # ascending, so the largest ones sit at the end).
        # NOTE(review): with `num_components == 0` the slice `[-0:]` selects
        # every label, i.e. nothing is removed — confirm this is intended.
        largest = count.argsort()[-self.num_components:]

        # `np.isin` replaces `np.in1d`, which is deprecated since NumPy 2.0;
        # `component` is one-dimensional, so the resulting mask is identical.
        subset = np.isin(component, largest)

        return data.subgraph(torch.from_numpy(subset).to(torch.bool))

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.num_components})'