Source code for torch_geometric.loader.node_loader

import logging
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import psutil
import torch

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.base import DataLoaderIterator, WorkerInitWrapper
from torch_geometric.loader.utils import (
    filter_custom_store,
    filter_data,
    filter_hetero_data,
    get_input_nodes,
    get_numa_nodes_cores,
)
from torch_geometric.sampler import (
    BaseSampler,
    HeteroSamplerOutput,
    NodeSamplerInput,
    SamplerOutput,
)
from torch_geometric.typing import InputNodes, OptTensor


class NodeLoader(torch.utils.data.DataLoader):
    r"""A data loader that performs mini-batch sampling from node information,
    using a generic :class:`~torch_geometric.sampler.BaseSampler`
    implementation that defines a
    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes` function
    and is supported on the provided input :obj:`data` object.

    Args:
        data (Any): A :class:`~torch_geometric.data.Data`,
            :class:`~torch_geometric.data.HeteroData`, or
            (:class:`~torch_geometric.data.FeatureStore`,
            :class:`~torch_geometric.data.GraphStore`) data object.
        node_sampler (torch_geometric.sampler.BaseSampler): The sampler
            implementation to be used with this loader.
            Needs to implement
            :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes`.
            The sampler implementation must be compatible with the input
            :obj:`data` object.
        input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): The
            indices of seed nodes to start sampling from.
            Needs to be either given as a :obj:`torch.LongTensor` or
            :obj:`torch.BoolTensor`.
            If set to :obj:`None`, all nodes will be considered.
            In heterogeneous graphs, needs to be passed as a tuple that holds
            the node type and node indices. (default: :obj:`None`)
        input_time (torch.Tensor, optional): Optional values to override the
            timestamp for the input nodes given in :obj:`input_nodes`.
            If not set, will use the timestamps in :obj:`time_attr` as default
            (if present).
            The :obj:`time_attr` needs to be set for this to work.
            (default: :obj:`None`)
        transform (Callable, optional): A function/transform that takes in a
            sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        transform_sampler_output (Callable, optional): A function/transform
            that takes in a :class:`torch_geometric.sampler.SamplerOutput` and
            returns a transformed version. (default: :obj:`None`)
        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
            the returned data in each worker's subprocess rather than in the
            main process.
            Setting this to :obj:`True` for in-memory datasets is generally
            not recommended:
            (1) it may result in too many open file handles,
            (2) it may slow down data loading,
            (3) it requires operating on CPU tensors.
            (default: :obj:`False`)
        custom_cls (HeteroData, optional): A custom
            :class:`~torch_geometric.data.HeteroData` class to return for
            mini-batches in case of remote backends.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        node_sampler: BaseSampler,
        input_nodes: InputNodes = None,
        input_time: OptTensor = None,
        transform: Optional[Callable] = None,
        transform_sampler_output: Optional[Callable] = None,
        filter_per_worker: bool = False,
        custom_cls: Optional[HeteroData] = None,
        input_id: OptTensor = None,
        **kwargs,
    ):
        # Remove for PyTorch Lightning:
        if 'dataset' in kwargs:
            del kwargs['dataset']
        if 'collate_fn' in kwargs:
            del kwargs['collate_fn']

        # Get node type (or `None` for homogeneous graphs):
        input_type, input_nodes = get_input_nodes(data, input_nodes)

        self.data = data
        self.node_sampler = node_sampler
        self.transform = transform
        self.transform_sampler_output = transform_sampler_output
        self.filter_per_worker = filter_per_worker
        self.custom_cls = custom_cls

        self.input_data = NodeSamplerInput(
            input_id=input_id,
            node=input_nodes,
            time=input_time,
            input_type=input_type,
        )

        # TODO: Unify DL affinitization in `BaseDataLoader` class.
        # CPU affinitization for loader and compute cores:
        self.num_workers = kwargs.get('num_workers', 0)
        self.is_cuda_available = torch.cuda.is_available()
        self.cpu_affinity_enabled = False
        worker_init_fn = WorkerInitWrapper(kwargs.pop('worker_init_fn', None))

        iterator = range(input_nodes.size(0))
        super().__init__(iterator, collate_fn=self.collate_fn,
                         worker_init_fn=worker_init_fn, **kwargs)
    def collate_fn(self, index: NodeSamplerInput) -> Any:
        r"""Samples a subgraph from a batch of input nodes."""
        input_data: NodeSamplerInput = self.input_data[index]

        out = self.node_sampler.sample_from_nodes(input_data)

        if self.filter_per_worker:
            # Execute `filter_fn` in the worker process:
            out = self.filter_fn(out)

        return out
    def filter_fn(
        self,
        out: Union[SamplerOutput, HeteroSamplerOutput],
    ) -> Union[Data, HeteroData]:
        r"""Joins the sampled nodes with their corresponding features,
        returning the resulting :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` object to be used
        downstream.
        """
        if self.transform_sampler_output:
            out = self.transform_sampler_output(out)

        if isinstance(out, SamplerOutput):
            data = filter_data(self.data, out.node, out.row, out.col,
                               out.edge, self.node_sampler.edge_permutation)

            if 'n_id' not in data:
                data.n_id = out.node
            if out.edge is not None and 'e_id' not in data:
                data.e_id = out.edge

            data.batch = out.batch
            data.input_id = out.metadata[0]
            data.seed_time = out.metadata[1]
            data.batch_size = out.metadata[0].size(0)

        elif isinstance(out, HeteroSamplerOutput):
            if isinstance(self.data, HeteroData):
                data = filter_hetero_data(self.data, out.node, out.row,
                                          out.col, out.edge,
                                          self.node_sampler.edge_permutation)
            else:  # Tuple[FeatureStore, GraphStore]
                data = filter_custom_store(*self.data, out.node, out.row,
                                           out.col, out.edge, self.custom_cls)

            for key, node in out.node.items():
                if 'n_id' not in data[key]:
                    data[key].n_id = node

            if out.edge is not None:
                for key, edge in out.edge.items():
                    if 'e_id' not in data[key]:
                        data[key].e_id = edge

            if out.batch is not None:
                for key, batch in out.batch.items():
                    data[key].batch = batch

            input_type = self.input_data.input_type
            data[input_type].input_id = out.metadata[0]
            data[input_type].seed_time = out.metadata[1]
            data[input_type].batch_size = out.metadata[0].size(0)

        else:
            raise TypeError(f"'{self.__class__.__name__}' found invalid "
                            f"type: '{type(out)}'")

        return data if self.transform is None else self.transform(data)
    def _get_iterator(self) -> Iterator:
        if self.filter_per_worker:
            return super()._get_iterator()

        # if not self.is_cuda_available and not self.cpu_affinity_enabled:
        #     TODO: Add manual page for best CPU practices
        #     link = ...
        #     Warning('Dataloader CPU affinity opt is not enabled, consider '
        #             'switching it on with enable_cpu_affinity() or see CPU '
        #             f'best practices for PyG [{link}])')

        # Execute `filter_fn` in the main process:
        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)

    def __enter__(self):
        return self

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
    @contextmanager
    def enable_cpu_affinity(self, loader_cores: Optional[List[int]] = None):
        r"""A context manager to enable CPU affinity for data loader workers
        (only used when running on CPU devices).

        Affinitization places data loader worker threads on specific CPU
        cores. In effect, it allows for more efficient local memory
        allocation and reduces remote memory calls.
        Every time a process or thread moves from one core to another,
        registers and caches need to be flushed and reloaded.
        This can become very costly if it happens often, and our threads may
        also no longer be close to their data, or be able to share data in a
        cache.

        .. warning::
            If you want to further affinitize compute threads
            (*i.e.* with OMP), please make sure that you exclude
            :obj:`loader_cores` from the list of cores available for compute.
            Failing to do so will cause core oversubscription and degrade
            performance.

        .. code-block:: python

            loader = NeighborLoader(data, num_workers=3)
            with loader.enable_cpu_affinity(loader_cores=[0, 1, 2]):
                for batch in loader:
                    pass

        This will be gradually extended to increase performance on dual
        socket CPUs.

        Args:
            loader_cores ([int], optional): List of CPU cores to which data
                loader workers should be affinitized.
                By default, workers are affinitized to consecutive cores of
                NUMA node 0, starting at :obj:`cpu0`.
                (default: :obj:`node0_cores[:num_workers]`)
        """
        if not self.is_cuda_available:
            if not self.num_workers > 0:
                raise ValueError(
                    f"'enable_cpu_affinity' should be used with at least one "
                    f"worker (got {self.num_workers})")
            if loader_cores and len(loader_cores) != self.num_workers:
                raise ValueError(
                    f"The number of loader cores (got {len(loader_cores)}) "
                    f"in 'enable_cpu_affinity' should match with the number "
                    f"of workers (got {self.num_workers})")

            worker_init_fn_old = self.worker_init_fn
            affinity_old = psutil.Process().cpu_affinity()
            nthreads_old = torch.get_num_threads()
            loader_cores = loader_cores[:] if loader_cores else None

            def init_fn(worker_id):
                try:
                    psutil.Process().cpu_affinity([loader_cores[worker_id]])
                except IndexError:
                    raise ValueError(f"Cannot use CPU affinity for worker ID "
                                     f"{worker_id} on CPU {loader_cores}")

                worker_init_fn_old(worker_id)

            if loader_cores is None:
                numa_info = get_numa_nodes_cores()

                if numa_info and len(numa_info[0]) > self.num_workers:
                    # Take one thread per each node 0 core:
                    node0_cores = [cpus[0] for core_id, cpus in numa_info[0]]
                else:
                    node0_cores = list(range(psutil.cpu_count(logical=False)))

                if len(node0_cores) < self.num_workers:
                    raise ValueError(
                        f"More workers (got {self.num_workers}) than "
                        f"available cores (got {len(node0_cores)})")

                # Set default loader core IDs:
                loader_cores = node0_cores[:self.num_workers]

            try:
                # Set CPU affinity for data loader workers:
                self.worker_init_fn = init_fn
                self.cpu_affinity_enabled = True
                logging.info(f'{self.num_workers} data loader workers '
                             f'are assigned to CPUs {loader_cores}')
                yield
            finally:
                # Restore `omp_num_threads` and CPU affinity:
                psutil.Process().cpu_affinity(affinity_old)
                torch.set_num_threads(nthreads_old)
                self.worker_init_fn = worker_init_fn_old
                self.cpu_affinity_enabled = False
        else:
            yield
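
Example usage (a minimal sketch, not part of this module): the snippet below wires a
:class:`NodeLoader` to a :class:`~torch_geometric.sampler.NeighborSampler` and iterates
over mini-batches. The random toy graph, the :obj:`num_neighbors` values, and the batch
size are illustrative assumptions, and neighbor sampling additionally requires a sampling
backend such as :obj:`pyg-lib` or :obj:`torch-sparse` to be installed.

    # A minimal sketch: any `BaseSampler` implementing `sample_from_nodes`
    # works in place of `NeighborSampler`; the data below is a random stand-in.
    import torch

    from torch_geometric.data import Data
    from torch_geometric.loader import NodeLoader
    from torch_geometric.sampler import NeighborSampler

    # A tiny random graph standing in for a real dataset:
    data = Data(
        x=torch.randn(100, 16),
        edge_index=torch.randint(0, 100, (2, 500)),
    )
    data.train_mask = torch.rand(100) < 0.8

    node_sampler = NeighborSampler(data, num_neighbors=[10, 5])
    loader = NodeLoader(
        data,
        node_sampler=node_sampler,
        input_nodes=data.train_mask,  # Seed nodes as a boolean mask.
        batch_size=32,
        shuffle=True,
    )

    for batch in loader:  # Each `batch` is a `Data` object.
        # The first `batch.batch_size` nodes are the seed nodes, followed by
        # their sampled neighbors:
        print(batch.batch_size, batch.n_id.numel())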