Source code for torch_geometric.loader.random_node_loader

import math
from typing import Union

import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.hetero_data import to_homogeneous_edge_index

class RandomNodeLoader(torch.utils.data.DataLoader):
    r"""A data loader that randomly samples nodes within a graph and returns
    their induced subgraph.

    The node indices ``range(num_nodes)`` are served in mini-batches of
    ``ceil(num_nodes / num_parts)`` indices each (the last mini-batch may be
    smaller), which yields at most :obj:`num_parts` mini-batches per epoch.
    Each mini-batch is collated into the subgraph induced by its nodes.
    Nodes are only drawn in random order when :obj:`shuffle=True` is passed
    via :obj:`**kwargs` (the :class:`torch.utils.data.DataLoader` default is
    no shuffling).

    .. note::

        For an example of using
        :class:`~torch_geometric.loader.RandomNodeLoader`, see
        `examples/ogbn_proteins_deepgcn.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        ogbn_proteins_deepgcn.py>`_.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
            The :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object.
        num_parts (int): The number of partitions. Must be positive.
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`
            or :obj:`shuffle`.

    Raises:
        ValueError: If :obj:`num_parts` is smaller than ``1``.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_parts: int,
        **kwargs,
    ):
        # Fail early with a clear message instead of a ZeroDivisionError
        # (num_parts == 0) or an opaque DataLoader batch_size error (< 0).
        if num_parts < 1:
            raise ValueError(
                f"'num_parts' must be a positive integer (got {num_parts})")

        self.data = data
        self.num_parts = num_parts

        if isinstance(data, HeteroData):
            # Flatten all node types into one global index space;
            # `node_dict` maps each node type to its (start, end) slice in
            # that space, which `collate_fn` uses to map indices back.
            edge_index, node_dict, edge_dict = to_homogeneous_edge_index(data)
            self.node_dict, self.edge_dict = node_dict, edge_dict
        else:
            edge_index = data.edge_index

        self.edge_index = edge_index
        self.num_nodes = data.num_nodes

        # The "dataset" is simply the list of node indices; batching it with
        # this batch size produces at most `num_parts` node partitions.
        super().__init__(
            range(self.num_nodes),
            batch_size=math.ceil(self.num_nodes / num_parts),
            collate_fn=self.collate_fn,
            **kwargs,
        )

    def collate_fn(self, index) -> Union[Data, HeteroData]:
        r"""Return the subgraph of :obj:`self.data` induced by :obj:`index`.

        Args:
            index (list or torch.Tensor): Node indices of the mini-batch. For
                :class:`HeteroData`, these are indices into the flattened
                (homogeneous) node index space built in :meth:`__init__`.

        Raises:
            TypeError: If :obj:`self.data` is neither :class:`Data` nor
                :class:`HeteroData`.
        """
        if not isinstance(index, Tensor):
            index = torch.tensor(index)

        if isinstance(self.data, Data):
            return self.data.subgraph(index)

        if isinstance(self.data, HeteroData):
            # Split global indices by node type and shift them to be local
            # to each type's [start, end) slice.
            node_dict = {
                key: index[(index >= start) & (index < end)] - start
                for key, (start, end) in self.node_dict.items()
            }
            return self.data.subgraph(node_dict)

        # Previously this fell through and silently returned None.
        raise TypeError(
            f"'{self.__class__.__name__}' expects 'data' to be of type "
            f"'Data' or 'HeteroData' (got '{type(self.data).__name__}')")