Source code for torch_geometric.data.on_disk_dataset

import os
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union

from torch import Tensor

from torch_geometric.data import Database, RocksDatabase, SQLiteDatabase
from torch_geometric.data.data import BaseData
from torch_geometric.data.database import Schema
from torch_geometric.data.dataset import Dataset


class OnDiskDataset(Dataset):
    r"""Dataset base class for creating large graph datasets which do not
    easily fit into CPU memory at once by leveraging a :class:`Database`
    backend for on-disk storage and access of data objects.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            boolean value, indicating whether the data object should be
            included in the final dataset. (default: :obj:`None`)
        backend (str): The :class:`Database` backend to use
            (one of :obj:`"sqlite"` or :obj:`"rocksdb"`).
            (default: :obj:`"sqlite"`)
        schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of
            the input data.
            Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a
            dictionary with :obj:`dtype` and :obj:`size` keys (for specifying
            tensor data) as input, and can be nested as a tuple or dictionary.
            Specifying the schema will improve efficiency, since by default the
            database will use python pickling for serializing and
            deserializing.
            If specified to anything different than :obj:`object`,
            implementations of :class:`OnDiskDataset` need to override
            :meth:`serialize` and :meth:`deserialize` methods.
            (default: :obj:`object`)
        log (bool, optional): Whether to print any console output while
            downloading and processing the dataset.
            (default: :obj:`True`)

    Raises:
        ValueError: If :obj:`backend` is not a key of :attr:`BACKENDS`.
    """

    # Maps the user-facing backend name to the database class implementing it.
    BACKENDS = {
        'sqlite': SQLiteDatabase,
        'rocksdb': RocksDatabase,
    }

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        backend: str = 'sqlite',
        schema: Schema = object,
        log: bool = True,
    ) -> None:
        if backend not in self.BACKENDS:
            raise ValueError(f"Database backend must be one of "
                             f"{set(self.BACKENDS.keys())} "
                             f"(got '{backend}')")

        self.backend = backend
        self.schema = schema

        # Both are filled in lazily: the database handle on first access of
        # `self.db`, the element count on first `self.db` / `self.len()` call.
        self._db: Optional[Database] = None
        self._numel: Optional[int] = None

        # These attributes are assigned before the base-class constructor
        # runs, so they already exist while it executes.
        super().__init__(root, transform, pre_filter=pre_filter, log=log)

    @property
    def processed_file_names(self) -> str:
        r"""The name of the database file, derived from the backend name."""
        return f'{self.backend}.db'

    @property
    def db(self) -> Database:
        r"""Returns the underlying :class:`Database`.

        The database is opened on first access and cached afterwards; opening
        it also refreshes the cached number of stored elements.
        """
        if self._db is not None:
            return self._db

        kwargs = {}
        cls = self.BACKENDS[self.backend]
        if issubclass(cls, SQLiteDatabase):
            # Only the SQLite backend is given a `name`; the class name of
            # the concrete dataset is used for it.
            kwargs['name'] = self.__class__.__name__

        # Make sure the target directory exists before the backend opens
        # (or creates) its file there.
        os.makedirs(self.processed_dir, exist_ok=True)
        path = self.processed_paths[0]
        self._db = cls(path=path, schema=self.schema, **kwargs)
        self._numel = len(self._db)
        return self._db

    def close(self) -> None:
        r"""Closes the connection to the underlying database.

        This is a no-op if the database has never been opened.
        """
        if self._db is not None:
            self._db.close()

    def serialize(self, data: BaseData) -> Any:
        r"""Serializes the :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` object into the expected DB
        schema.

        With the default :obj:`object` schema the data object is returned
        unchanged.

        Raises:
            NotImplementedError: If a non-default schema was passed and the
                method has not been overridden.
        """
        if self.schema == object:
            return data
        raise NotImplementedError(f"`{self.__class__.__name__}.serialize()` "
                                  f"needs to be overridden in case a "
                                  f"non-default schema was passed")

    def deserialize(self, data: Any) -> BaseData:
        r"""Deserializes the DB entry into a
        :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` object.

        With the default :obj:`object` schema the entry is returned unchanged.

        Raises:
            NotImplementedError: If a non-default schema was passed and the
                method has not been overridden.
        """
        if self.schema == object:
            return data
        raise NotImplementedError(f"`{self.__class__.__name__}.deserialize()` "
                                  f"needs to be overridden in case a "
                                  f"non-default schema was passed")

    def append(self, data: BaseData) -> None:
        r"""Appends the data object to the dataset."""
        # The new entry is stored under the next free index, i.e. the current
        # length of the dataset.
        index = len(self)
        self.db.insert(index, self.serialize(data))
        self._numel += 1

    def extend(
        self,
        data_list: Sequence[BaseData],
        batch_size: Optional[int] = None,
    ) -> None:
        r"""Extends the dataset by a list of data objects.

        Args:
            data_list (Sequence[BaseData]): The data objects to add. They are
                stored under consecutive indices following the current last
                one.
            batch_size (int, optional): Forwarded unchanged to the backend's
                :obj:`multi_insert`. (default: :obj:`None`)
        """
        start = len(self)
        end = start + len(data_list)
        data_list = [self.serialize(data) for data in data_list]
        self.db.multi_insert(range(start, end), data_list, batch_size)
        self._numel += (end - start)

    def get(self, idx: int) -> BaseData:
        r"""Gets the data object at index :obj:`idx`."""
        return self.deserialize(self.db.get(idx))

    def multi_get(
        self,
        indices: Union[Iterable[int], Tensor, slice, range],
        batch_size: Optional[int] = None,
    ) -> List[BaseData]:
        r"""Gets a list of data objects from the specified indices.

        Args:
            indices: The indices of the entries to fetch.
            batch_size (int, optional): Forwarded unchanged to the backend's
                :obj:`multi_get`. (default: :obj:`None`)

        Returns:
            The deserialized data objects, with :obj:`transform` applied to
            each of them if one is set.
        """
        # A `slice` or a plain iterator has no `__len__`, so calling `len()`
        # on it would raise a `TypeError`. Only sized inputs are eligible for
        # the single-element shortcut; everything else is handed to the
        # backend's `multi_get` as-is.
        # NOTE(review): presumably the backend resolves slices itself - verify
        # against the `Database.multi_get` implementation in use.
        if hasattr(indices, '__len__') and len(indices) == 1:
            # Read the lone element via iteration so that sized but
            # non-subscriptable containers work too, and normalize
            # tensor/NumPy scalars to a plain `int` key.
            index = int(next(iter(indices)))
            data_list = [self.db.get(index)]
        else:
            data_list = self.db.multi_get(indices, batch_size)

        data_list = [self.deserialize(data) for data in data_list]
        if self.transform is not None:
            data_list = [self.transform(data) for data in data_list]
        return data_list

    def __getitems__(self, indices: List[int]) -> List[BaseData]:
        r"""Batched item access; delegates to :meth:`multi_get`."""
        return self.multi_get(indices)

    def len(self) -> int:
        r"""Returns the number of entries stored in the database.

        The count is read from the database once and cached afterwards.
        """
        if self._numel is None:
            self._numel = len(self.db)
        return self._numel

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({len(self)})'