Source code for torch_geometric.loader.zip_loader

from typing import Any, Iterator, List, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import LinkLoader, NodeLoader
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import infer_filter_per_worker


class ZipLoader(torch.utils.data.DataLoader):
    r"""A loader that returns a tuple of data objects by sampling from
    multiple :class:`NodeLoader` or :class:`LinkLoader` instances.

    Args:
        loaders (List[NodeLoader] or List[LinkLoader]): The loader instances.
        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
            the returned data in each worker's subprocess.
            If set to :obj:`False`, will filter the returned data in the main
            process.
            If set to :obj:`None`, will automatically infer the decision based
            on whether data partially lives on the GPU
            (:obj:`filter_per_worker=True`) or entirely on the CPU
            (:obj:`filter_per_worker=False`).
            There exist different trade-offs for setting this option.
            Specifically, setting this option to :obj:`True` for in-memory
            datasets will move all features to shared memory, which may
            result in too many open file handles. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(
        self,
        loaders: Union[List[NodeLoader], List[LinkLoader]],
        filter_per_worker: Optional[bool] = None,
        **kwargs,
    ):
        if filter_per_worker is None:
            filter_per_worker = infer_filter_per_worker(loaders[0].data)

        # Remove for PyTorch Lightning:
        kwargs.pop('dataset', None)
        kwargs.pop('collate_fn', None)

        for loader in loaders:
            if not callable(getattr(loader, 'collate_fn', None)):
                raise ValueError(f"'{loader.__class__.__name__}' does not "
                                 f"have a 'collate_fn' method")
            if not callable(getattr(loader, 'filter_fn', None)):
                raise ValueError(f"'{loader.__class__.__name__}' does not "
                                 f"have a 'filter_fn' method")
            loader.filter_per_worker = filter_per_worker

        iterator = range(min([len(loader.dataset) for loader in loaders]))
        super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)

        self.loaders = loaders
        self.filter_per_worker = filter_per_worker

    def __call__(
        self,
        index: Union[Tensor, List[int]],
    ) -> Union[Tuple[Data, ...], Tuple[HeteroData, ...]]:
        r"""Samples subgraphs from a batch of input IDs."""
        out = self.collate_fn(index)
        if not self.filter_per_worker:
            out = self.filter_fn(out)
        return out

    def collate_fn(self, index: List[int]) -> Tuple[Any, ...]:
        if not isinstance(index, Tensor):
            index = torch.tensor(index, dtype=torch.long)
        return tuple(loader.collate_fn(index) for loader in self.loaders)

    def filter_fn(
        self,
        outs: Tuple[Any, ...],
    ) -> Tuple[Union[Data, HeteroData], ...]:
        loaders = self.loaders
        return tuple(loader.filter_fn(v) for loader, v in zip(loaders, outs))

    def _get_iterator(self) -> Iterator:
        if self.filter_per_worker:
            return super()._get_iterator()

        # Execute `filter_fn` in the main process:
        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(loaders={self.loaders})'
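
# A minimal usage sketch (not part of the module above). Since every batch
# of indices is passed to all child loaders, the i-th sample from each
# loader stays aligned across the returned tuple. The toy `Data` object,
# the two `NeighborLoader` instances, and all of their arguments below are
# illustrative assumptions, not prescribed by this module:
#
#     import torch
#     from torch_geometric.data import Data
#     from torch_geometric.loader import NeighborLoader, ZipLoader
#
#     data = Data(
#         x=torch.randn(100, 16),
#         edge_index=torch.randint(0, 100, (2, 400)),
#     )
#
#     # Two node loaders over equally sized seed-node sets:
#     loader_a = NeighborLoader(data, num_neighbors=[5],
#                               input_nodes=torch.arange(0, 50))
#     loader_b = NeighborLoader(data, num_neighbors=[5],
#                               input_nodes=torch.arange(50, 100))
#
#     # `ZipLoader` samples both loaders with the same index batch and
#     # yields a tuple of sampled subgraphs per step:
#     loader = ZipLoader([loader_a, loader_b], batch_size=8, shuffle=True)
#     for batch_a, batch_b in loader:
#         ...  # Each element is a sampled `Data` mini-batch.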