# Source code for torch_geometric.data.in_memory_dataset

import copy
import os.path as osp
import warnings
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    MutableSequence,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import torch
from torch import Tensor
from tqdm import tqdm

import torch_geometric
from torch_geometric.data import Batch, Data
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData
from torch_geometric.data.dataset import Dataset, IndexType
from torch_geometric.data.separate import separate
from torch_geometric.io import fs


class InMemoryDataset(Dataset):
    r"""Dataset base class for creating graph datasets which easily fit
    into CPU memory.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
    create_dataset.html#creating-in-memory-datasets>`__ for the accompanying
    tutorial.

    Args:
        root (str, optional): Root directory where the dataset should be saved.
            (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before being saved to disk.
            (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            boolean value, indicating whether the data object should be
            included in the final dataset. (default: :obj:`None`)
        log (bool, optional): Whether to print any console output while
            downloading and processing the dataset. (default: :obj:`True`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
        r"""Raw file name(s) of the dataset. Raises
        :class:`NotImplementedError` here; subclasses are expected to
        override it.
        """
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
        r"""Processed file name(s) of the dataset. Raises
        :class:`NotImplementedError` here; subclasses are expected to
        override it.
        """
        raise NotImplementedError

    def __init__(
        self,
        root: Optional[str] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        log: bool = True,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform, pre_filter, log,
                         force_reload)

        # Internal storage: one collated data object plus the slice
        # boundaries needed to cut individual examples back out of it.
        self._data: Optional[BaseData] = None
        self.slices: Optional[Dict[str, Tensor]] = None
        # Lazily filled per-example cache used by `get()`.
        self._data_list: Optional[MutableSequence[Optional[BaseData]]] = None

    @property
    def num_classes(self) -> int:
        r"""Returns the number of classes in the dataset.

        Without a :obj:`transform`, the value is inferred directly from the
        collated labels :obj:`self._data.y`; otherwise the parent class
        implementation is used.
        """
        if self.transform is None:
            return self._infer_num_classes(self._data.y)
        return super().num_classes

    def len(self) -> int:
        r"""Returns the number of examples held in the internal storage.

        Returns :obj:`1` when no slices are stored (a single data object),
        the length of the first slice entry minus one otherwise, and
        :obj:`0` if the slice dictionary contains no entries at all.
        """
        if self.slices is None:
            return 1
        # Only the first slice entry is inspected.
        for _, value in nested_iter(self.slices):
            return len(value) - 1
        return 0

    def get(self, idx: int) -> BaseData:
        r"""Returns a shallow copy of the data object at index :obj:`idx`.

        Separated objects are cached in :obj:`self._data_list` so that
        subsequent accesses of the same index skip the separation step.
        """
        # TODO (matthias) Avoid unnecessary copy here.
        if self.len() == 1:
            return copy.copy(self._data)

        if not hasattr(self, '_data_list') or self._data_list is None:
            # First access: allocate an empty cache slot per example.
            self._data_list = self.len() * [None]
        elif self._data_list[idx] is not None:
            return copy.copy(self._data_list[idx])

        data = separate(
            cls=self._data.__class__,
            batch=self._data,
            idx=idx,
            slice_dict=self.slices,
            decrement=False,
        )

        # Cache a copy so the caller's object and the cache stay distinct.
        self._data_list[idx] = copy.copy(data)

        return data

    @classmethod
    def save(cls, data_list: Sequence[BaseData], path: str) -> None:
        r"""Saves a list of data objects to the file path :obj:`path`.

        The list is collated first; the stored payload is the tuple
        :obj:`(data.to_dict(), slices, data.__class__)`.
        """
        data, slices = cls.collate(data_list)
        fs.torch_save((data.to_dict(), slices, data.__class__), path)

    def load(self, path: str, data_cls: Type[BaseData] = Data) -> None:
        r"""Loads the dataset from the file path :obj:`path`.

        Args:
            path (str): The file to load from.
            data_cls (type, optional): The data class used to rebuild the
                storage from a dictionary when the file does not carry its
                own class (two-element payloads). (default: :class:`Data`)
        """
        out = fs.torch_load(path)
        assert isinstance(out, tuple)
        assert len(out) == 2 or len(out) == 3
        if len(out) == 2:  # Backward compatibility.
            data, self.slices = out
        else:
            data, self.slices, data_cls = out

        # Assigning through the `data` setter also resets `_data_list`.
        if not isinstance(data, dict):  # Backward compatibility.
            self.data = data
        else:
            self.data = data_cls.from_dict(data)

    @staticmethod
    def collate(
        data_list: Sequence[BaseData],
    ) -> Tuple[BaseData, Optional[Dict[str, Tensor]]]:
        r"""Collates a list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects to the internal
        storage format of :class:`~torch_geometric.data.InMemoryDataset`.

        A single-element list is returned as-is with :obj:`None` slices.
        """
        if len(data_list) == 1:
            return data_list[0], None

        data, slices, _ = collate(
            data_list[0].__class__,
            data_list=data_list,
            increment=False,
            add_batch=False,
        )

        return data, slices

    def copy(self, idx: Optional[IndexType] = None) -> 'InMemoryDataset':
        r"""Performs a deep-copy of the dataset. If :obj:`idx` is not given,
        will clone the full dataset. Otherwise, will only clone a subset of the
        dataset from indices :obj:`idx`.
        Indices can be slices, lists, tuples, and a :obj:`torch.Tensor` or
        :obj:`np.ndarray` of type long or bool.
        """
        if idx is None:
            data_list = [self.get(i) for i in self.indices()]
        else:
            data_list = [self.get(i) for i in self.index_select(idx).indices()]

        dataset = copy.copy(self)
        # The clone owns freshly collated storage, so it references all of
        # its examples and starts with an empty cache.
        dataset._indices = None
        dataset._data_list = None
        dataset.data, dataset.slices = self.collate(data_list)
        return dataset

    def to_on_disk_dataset(
        self,
        root: Optional[str] = None,
        backend: str = 'sqlite',
        log: bool = True,
    ) -> 'torch_geometric.data.OnDiskDataset':
        r"""Converts the :class:`InMemoryDataset` to a :class:`OnDiskDataset`
        variant. Useful for distributed training and hardware instances with
        limited amount of shared memory.

        Args:
            root (str, optional): Root directory where the dataset should be
                saved. If set to :obj:`None`, will save the dataset in
                :obj:`root/on_disk`.
                Note that it is important to specify :obj:`root` to account
                for different dataset splits. (default: :obj:`None`)
            backend (str): The :class:`Database` backend to use.
                (default: :obj:`"sqlite"`)
            log (bool, optional): Whether to print any console output while
                processing the dataset. (default: :obj:`True`)

        Raises:
            ValueError: If :obj:`root` is not given and the dataset's own
                root directory is unset or does not exist.
            NotImplementedError: If the first example is not a
                :class:`~torch_geometric.data.Data` object.
        """
        if root is None and (self.root is None or not osp.exists(self.root)):
            raise ValueError(f"The root directory of "
                             f"'{self.__class__.__name__}' is not specified. "
                             f"Please pass in 'root' when creating on-disk "
                             f"datasets from it.")

        root = root or osp.join(self.root, 'on_disk')

        # Alias captured by the nested class below, where `self` refers to
        # the on-disk dataset instead.
        in_memory_dataset = self
        ref_data = in_memory_dataset.get(0)
        if not isinstance(ref_data, Data):
            raise NotImplementedError(
                f"`{self.__class__.__name__}.to_on_disk_dataset()` is "
                f"currently only supported on homogeneous graphs")

        # Parse the schema ====================================================
        # The schema is derived from the first example only.

        schema: Dict[str, Any] = {}
        for key, value in ref_data.to_dict().items():
            if isinstance(value, (int, float, str)):
                schema[key] = value.__class__
            elif isinstance(value, Tensor) and value.dim() == 0:
                schema[key] = dict(dtype=value.dtype, size=(-1, ))
            elif isinstance(value, Tensor):
                # Mark the concatenation dimension as variable-sized.
                size = list(value.size())
                size[ref_data.__cat_dim__(key, value)] = -1
                schema[key] = dict(dtype=value.dtype, size=tuple(size))
            else:
                schema[key] = object

        # Create the on-disk dataset ==========================================

        class OnDiskDataset(torch_geometric.data.OnDiskDataset):
            def __init__(
                self,
                root: str,
                transform: Optional[Callable] = None,
            ):
                super().__init__(
                    root=root,
                    transform=transform,
                    backend=backend,
                    schema=schema,
                )

            def process(self):
                _iter = [
                    in_memory_dataset.get(i)
                    for i in in_memory_dataset.indices()
                ]
                if log:  # pragma: no cover
                    _iter = tqdm(_iter, desc='Converting to OnDiskDataset')

                # Write in chunks of 1000 examples, plus the final remainder.
                data_list: List[Data] = []
                for i, data in enumerate(_iter):
                    data_list.append(data)
                    if i + 1 == len(in_memory_dataset) or (i + 1) % 1000 == 0:
                        self.extend(data_list)
                        data_list = []

            def serialize(self, data: Data) -> Dict[str, Any]:
                return data.to_dict()

            def deserialize(self, data: Dict[str, Any]) -> Data:
                return Data.from_dict(data)

            def __repr__(self) -> str:
                arg_repr = str(len(self)) if len(self) > 1 else ''
                return (f'OnDisk{in_memory_dataset.__class__.__name__}('
                        f'{arg_repr})')

        return OnDiskDataset(root, transform=in_memory_dataset.transform)

    @property
    def data(self) -> Any:
        r"""Returns the internal collated storage and emits a warning.

        If examples are currently cached in :obj:`_data_list`, the cache is
        cleared as a side effect of this access.
        """
        msg1 = ("It is not recommended to directly access the internal "
                "storage format `data` of an 'InMemoryDataset'.")
        msg2 = ("The given 'InMemoryDataset' only references a subset of "
                "examples of the full dataset, but 'data' will contain "
                "information of the full dataset.")
        msg3 = ("The data of the dataset is already cached, so any "
                "modifications to `data` will not be reflected when accessing "
                "its elements. Clearing the cache now by removing all "
                "elements in `dataset._data_list`.")
        msg4 = ("If you are absolutely certain what you are doing, access the "
                "internal storage via `InMemoryDataset._data` instead to "
                "suppress this warning. Alternatively, you can access stacked "
                "individual attributes of every graph via "
                "`dataset.{attr_name}`.")

        msg = msg1
        if self._indices is not None:
            msg += f' {msg2}'
        if self._data_list is not None:
            msg += f' {msg3}'
            self._data_list = None
        msg += f' {msg4}'

        warnings.warn(msg)

        return self._data

    @data.setter
    def data(self, value: Any):
        r"""Replaces the internal storage and drops any cached examples."""
        self._data = value
        self._data_list = None

    def __getattr__(self, key: str) -> Any:
        r"""Resolves otherwise-missing attributes against the stored data.

        Only applies when the storage is a
        :class:`~torch_geometric.data.Data` object that contains
        :obj:`key`; in every other case :class:`AttributeError` is raised.
        """
        # Read `_data` from `__dict__` directly: a plain `self._data` would
        # re-enter `__getattr__` whenever `_data` has not been set yet.
        data = self.__dict__.get('_data')
        if isinstance(data, Data) and key in data:
            if self._indices is None and data.__inc__(key, data[key]) == 0:
                # Full dataset and `__inc__` reports 0 for this key: the
                # stored value is returned as-is.
                return data[key]
            else:
                # Otherwise gather the attribute by batching the referenced
                # examples.
                data_list = [self.get(i) for i in self.indices()]
                return Batch.from_data_list(data_list)[key]

        raise AttributeError(f"'{self.__class__.__name__}' object has no "
                             f"attribute '{key}'")

    def to(self, device: Union[int, str]) -> 'InMemoryDataset':
        r"""Performs device conversion of the whole dataset.

        Raises:
            ValueError: If the dataset only references a subset of examples,
                or if examples are already cached in :obj:`_data_list`.
        """
        if self._indices is not None:
            raise ValueError("The given 'InMemoryDataset' only references a "
                             "subset of examples of the full dataset")
        if self._data_list is not None:
            raise ValueError("The data of the dataset is already cached")
        # NOTE(review): the return value is discarded, so this relies on
        # `BaseData.to` updating the storage in place - confirm.
        self._data.to(device)
        return self

    def cpu(self, *args: str) -> 'InMemoryDataset':
        r"""Moves the dataset to CPU memory. Positional arguments are
        accepted but not used.
        """
        return self.to(torch.device('cpu'))

    def cuda(
        self,
        device: Optional[Union[int, str]] = None,
    ) -> 'InMemoryDataset':
        r"""Moves the dataset to CUDA memory.

        Args:
            device (int or str, optional): An integer is interpreted as a
                CUDA device index, a string is passed through unchanged, and
                :obj:`None` selects :obj:`"cuda"`. (default: :obj:`None`)
        """
        if isinstance(device, int):
            # Interpolate the given index (not the builtin `int` type).
            device = f'cuda:{device}'
        elif device is None:
            device = 'cuda'
        return self.to(device)
def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
    r"""Lazily yields :obj:`(key, value)` pairs found in :obj:`node`.

    * A mapping is descended into recursively; its own keys are not
      reported, only the pairs produced by its values.
    * A sequence yields :obj:`(position, element)` pairs without descending
      into the elements.
    * Any other object yields the single pair :obj:`(None, node)`.
    """
    if isinstance(node, Mapping):
        for child in node.values():
            yield from nested_iter(child)
        return

    if isinstance(node, Sequence):
        yield from enumerate(node)
        return

    yield None, node