Source code for torch_geometric.data.on_disk_dataset

import os
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union

from torch import Tensor

from torch_geometric.data import Database, RocksDatabase, SQLiteDatabase
from torch_geometric.data.data import BaseData
from torch_geometric.data.database import Schema
from torch_geometric.data.dataset import Dataset

class OnDiskDataset(Dataset):
    r"""Dataset base class for creating large graph datasets which do not
    easily fit into CPU memory at once by leveraging a :class:`Database`
    backend for on-disk storage and access of data objects.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            boolean value, indicating whether the data object should be
            included in the final dataset. (default: :obj:`None`)
        backend (str): The :class:`Database` backend to use
            (one of :obj:`"sqlite"` or :obj:`"rocksdb"`).
            (default: :obj:`"sqlite"`)
        schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of
            the input data.
            Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or
            a dictionary with :obj:`dtype` and :obj:`size` keys (for
            specifying tensor data) as input, and can be nested as a tuple or
            dictionary.
            Specifying the schema will improve efficiency, since by default
            the database will use Python pickling for serializing and
            deserializing.
            If specified to anything different than :obj:`object`,
            implementations of :class:`OnDiskDataset` need to override the
            :meth:`serialize` and :meth:`deserialize` methods.
            (default: :obj:`object`)
        log (bool, optional): Whether to print any console output while
            downloading and processing the dataset. (default: :obj:`True`)
    """
    BACKENDS = {
        'sqlite': SQLiteDatabase,
        'rocksdb': RocksDatabase,
    }

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        backend: str = 'sqlite',
        schema: Schema = object,
        log: bool = True,
    ) -> None:
        if backend not in self.BACKENDS:
            raise ValueError(f"Database backend must be one of "
                             f"{set(self.BACKENDS.keys())} "
                             f"(got '{backend}')")

        self.backend = backend
        self.schema = schema

        self._db: Optional[Database] = None
        self._numel: Optional[int] = None

        super().__init__(root, transform, pre_filter=pre_filter, log=log)

    @property
    def processed_file_names(self) -> str:
        return f'{self.backend}.db'

    @property
    def db(self) -> Database:
        r"""Returns the underlying :class:`Database`."""
        if self._db is not None:
            return self._db

        kwargs = {}
        cls = self.BACKENDS[self.backend]
        if issubclass(cls, SQLiteDatabase):
            kwargs['name'] = self.__class__.__name__

        os.makedirs(self.processed_dir, exist_ok=True)
        path = self.processed_paths[0]
        self._db = cls(path=path, schema=self.schema, **kwargs)
        self._numel = len(self._db)

        return self._db
    def close(self) -> None:
        r"""Closes the connection to the underlying database."""
        if self._db is not None:
            self._db.close()
    def serialize(self, data: BaseData) -> Any:
        r"""Serializes the :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` object into the expected DB
        schema.
        """
        if self.schema == object:
            return data

        raise NotImplementedError(f"`{self.__class__.__name__}.serialize()` "
                                  f"needs to be overridden in case a "
                                  f"non-default schema was passed")
    def deserialize(self, data: Any) -> BaseData:
        r"""Deserializes the DB entry into a
        :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` object.
        """
        if self.schema == object:
            return data

        raise NotImplementedError(f"`{self.__class__.__name__}.deserialize()` "
                                  f"needs to be overridden in case a "
                                  f"non-default schema was passed")
    def append(self, data: BaseData) -> None:
        r"""Appends the data object to the dataset."""
        index = len(self)
        self.db.insert(index, self.serialize(data))
        self._numel += 1
    def extend(
        self,
        data_list: Sequence[BaseData],
        batch_size: Optional[int] = None,
    ) -> None:
        r"""Extends the dataset by a list of data objects."""
        start = len(self)
        end = start + len(data_list)
        data_list = [self.serialize(data) for data in data_list]
        self.db.multi_insert(range(start, end), data_list, batch_size)
        self._numel += (end - start)
    def get(self, idx: int) -> BaseData:
        r"""Gets the data object at index :obj:`idx`."""
        return self.deserialize(self.db.get(idx))
    def multi_get(
        self,
        indices: Union[Iterable[int], Tensor, slice, range],
        batch_size: Optional[int] = None,
    ) -> List[BaseData]:
        r"""Gets a list of data objects from the specified indices."""
        if len(indices) == 1:
            data_list = [self.db.get(indices[0])]
        else:
            data_list = self.db.multi_get(indices, batch_size)

        data_list = [self.deserialize(data) for data in data_list]

        if self.transform is not None:
            data_list = [self.transform(data) for data in data_list]

        return data_list
    def __getitems__(self, indices: List[int]) -> List[BaseData]:
        return self.multi_get(indices)
    def len(self) -> int:
        if self._numel is None:
            self._numel = len(self.db)
        return self._numel
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({len(self)})'
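
The `append()` and `extend()` methods above allow building a dataset without subclassing at all, since the default `schema=object` pickles each entry as-is. A minimal usage sketch (not part of the module): it assumes a recent PyTorch Geometric release (>= 2.4) in which `OnDiskDataset` is exported from `torch_geometric.data`; the root path and the toy graphs are illustrative.

    import torch

    from torch_geometric.data import Data, OnDiskDataset

    # Illustrative root path; the SQLite database is created lazily at
    # '<root>/processed/sqlite.db' on first access of `dataset.db`.
    dataset = OnDiskDataset(root='/tmp/on_disk_demo')

    data_list = [
        Data(x=torch.randn(4, 8), edge_index=torch.randint(0, 4, (2, 6)))
        for _ in range(100)
    ]

    dataset.append(data_list[0])                  # single row insert
    dataset.extend(data_list[1:], batch_size=32)  # batched multi-insert

    print(len(dataset))  # 100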
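Reading back mirrors writing: `get()` returns a single object, while `multi_get()` fetches many rows at once. Note in the source above that `multi_get()` applies `self.transform` itself, whereas `get()` leaves that to `Dataset.__getitem__`. Continuing the hypothetical sketch:

    first = dataset.get(0)                    # single lookup (no transform)
    some = dataset.multi_get([0, 5, 10])      # arbitrary indices
    chunk = dataset.multi_get(range(20, 30))  # contiguous range

    dataset.close()  # close the underlying database connection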
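When a non-default `schema` is passed, the docstring above requires overriding `serialize()` and `deserialize()`. The following sketch follows that contract; the class name, feature sizes, and the `process()` body are hypothetical, and only `Data.from_dict()` and the documented `dtype`/`size` schema keys are taken as given.

    from typing import Any, Dict, List

    import torch

    from torch_geometric.data import Data, OnDiskDataset


    class MoleculeDataset(OnDiskDataset):  # hypothetical subclass
        def __init__(self, root: str):
            # Tensor fields are described by `dtype`/`size` hints (-1 marks a
            # variable dimension), avoiding Python pickling per row:
            schema = {
                'x': dict(dtype=torch.float, size=(-1, 16)),
                'edge_index': dict(dtype=torch.long, size=(2, -1)),
                'y': float,
            }
            super().__init__(root, schema=schema)

        def process(self):
            # Hypothetical processing step: write toy graphs in one batch.
            data_list: List[Data] = [
                Data(
                    x=torch.randn(n, 16),
                    edge_index=torch.randint(0, n, (2, 2 * n)),
                    y=float(n),
                ) for n in range(2, 12)
            ]
            self.extend(data_list)

        def serialize(self, data: Data) -> Dict[str, Any]:
            # Flatten the `Data` object into the dictionary layout of the schema:
            return dict(x=data.x, edge_index=data.edge_index, y=float(data.y))

        def deserialize(self, data: Dict[str, Any]) -> Data:
            # Rebuild a `Data` object from the stored dictionary:
            return Data.from_dict(data)

Instantiating `MoleculeDataset(root='/tmp/molecules')` would then run `process()` once on first use and serve fixed-layout rows from disk afterwards.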