# Source code for torch_geometric.loader.prefetch

import warnings
from contextlib import nullcontext
from functools import partial
from typing import Any, Optional

import torch
from torch.utils.data import DataLoader

from torch_geometric.typing import WITH_IPEX

class DeviceHelper:
    r"""Resolves the device that batches are prefetched onto and manages an
    optional side stream used for the host-to-device copies.

    Args:
        device (torch.device, optional): The requested device. If
            :obj:`None`, ``'cuda'`` is chosen when available, otherwise
            ``'xpu'`` when available, otherwise ``'cpu'``. If ``'cuda'`` or
            ``'xpu'`` is requested but not available, a warning is emitted
            and the CPU is used instead. (default: :obj:`None`)
    """
    def __init__(self, device: Optional[torch.device] = None):
        with_cuda = torch.cuda.is_available()
        # `torch.xpu` is only probed when `WITH_IPEX` is set.
        with_xpu = torch.xpu.is_available() if WITH_IPEX else False

        if device is None:
            if with_cuda:
                device = 'cuda'
            elif with_xpu:
                device = 'xpu'
            else:
                device = 'cpu'

        self.device = torch.device(device)

        if ((self.device.type == 'cuda' and not with_cuda)
                or (self.device.type == 'xpu' and not with_xpu)):
            warnings.warn(f"Requested device '{self.device.type}' is not "
                          f"available, falling back to CPU")
            self.device = torch.device('cpu')

        # Derived from the *final* device: computing this before the CPU
        # fallback would leave `is_gpu=True` (and `module=torch.cpu`) for a
        # device that is not actually a GPU.
        self.is_gpu = self.device.type in ['cuda', 'xpu']

        # Side stream for copies; created lazily by `maybe_init_stream`.
        self.stream: Optional[Any] = None
        # Context manager entered around each transfer; a no-op until a
        # side stream is initialised.
        self.stream_context = nullcontext
        # `torch.cuda` / `torch.xpu`, or `None` when running on CPU.
        self.module = getattr(torch, self.device.type) if self.is_gpu else None

    def maybe_init_stream(self) -> None:
        r"""Creates a side stream on GPU devices and makes
        :obj:`stream_context` activate it. Does nothing on CPU.
        """
        if self.is_gpu:
            self.stream = self.module.Stream()
            self.stream_context = partial(
                self.module.stream,
                stream=self.stream,
            )

    def maybe_wait_stream(self) -> None:
        r"""Makes the current stream wait for work queued on the side stream,
        if one was created by :meth:`maybe_init_stream`.
        """
        if self.stream is not None:
            self.module.current_stream().wait_stream(self.stream)

class PrefetchLoader:
    r"""A GPU prefetcher class for asynchronously transferring data of a
    :class:`torch.utils.data.DataLoader` from host memory to device memory.

    Args:
        loader (torch.utils.data.DataLoader): The data loader.
        device (torch.device, optional): The device to load the data to.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        loader: DataLoader,
        device: Optional[torch.device] = None,
    ):
        self.loader = loader
        self.device_helper = DeviceHelper(device)

    def non_blocking_transfer(self, batch: Any) -> Any:
        r"""Recursively moves :obj:`batch` to the target device.

        If the target device is not a GPU, :obj:`batch` is returned as-is.
        Otherwise lists and tuples are returned as lists, and dictionaries as
        dictionaries, with every element transferred. Any other object is
        pinned and then copied with ``non_blocking=True``. It is assumed to
        expose ``pin_memory(device)`` and ``to(device, non_blocking=...)``
        (presumably a tensor or a PyG data object; confirm for new callers).
        """
        if not self.device_helper.is_gpu:
            return batch
        if isinstance(batch, (list, tuple)):
            return [self.non_blocking_transfer(v) for v in batch]
        if isinstance(batch, dict):
            return {k: self.non_blocking_transfer(v) for k, v in batch.items()}

        batch = batch.pin_memory(self.device_helper.device)
        return batch.to(self.device_helper.device, non_blocking=True)

    def __iter__(self) -> Any:
        r"""Yields the batches of :obj:`loader`, transferred to the device.

        Each batch's transfer is issued one step ahead, before the previous
        batch is yielded. An empty loader yields nothing.
        """
        helper = self.device_helper
        helper.maybe_init_stream()

        first = True
        batch = None
        for next_batch in self.loader:
            # Queue the copy of the upcoming batch (on the side stream when
            # one exists) before handing out the current one.
            with helper.stream_context():
                next_batch = self.non_blocking_transfer(next_batch)

            if first:
                first = False
            else:
                yield batch

            # Synchronise with the transfer stream (no-op without one) before
            # the freshly copied batch becomes the current one.
            helper.maybe_wait_stream()
            batch = next_batch

        # Only flush the final batch if the loader produced anything;
        # otherwise an empty loader would yield a spurious `None`.
        if not first:
            yield batch

    def __len__(self) -> int:
        return len(self.loader)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.loader})'