torch_geometric.data.OnDiskDataset

class OnDiskDataset(root: str, transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, backend: str = 'sqlite', schema: Union[Any, Dict[str, Any], Tuple[Any], List[Any]] = object, log: bool = True)[source]

Bases: Dataset

Dataset base class for creating large graph datasets which do not easily fit into CPU memory at once by leveraging a Database backend for on-disk storage and access of data objects.

Parameters:
  • root (str) – Root directory where the dataset should be saved.

  • transform (callable, optional) – A function/transform that takes in a Data or HeteroData object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_filter (callable, optional) – A function that takes in a Data or HeteroData object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

  • backend (str) – The Database backend to use (one of "sqlite" or "rocksdb"). (default: "sqlite")

  • schema (Any or Tuple[Any] or Dict[str, Any], optional) – The schema of the input data. Can take int, float, str, object, or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema will improve efficiency, since by default the database uses Python pickling for serializing and deserializing. If set to anything other than object, implementations of OnDiskDataset need to override the serialize() and deserialize() methods (see the sketch under deserialize() below). (default: object)

  • log (bool, optional) – Whether to print any console output while downloading and processing the dataset. (default: True)
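
A minimal subclass sketch, assuming the default pickle-based object schema: the class name and the raw file data.pt (holding a list of Data objects) are hypothetical, and process() streams all graphs into the on-disk database via extend():

import torch

from torch_geometric.data import OnDiskDataset


class MyOnDiskDataset(OnDiskDataset):
    @property
    def raw_file_names(self) -> str:
        return 'data.pt'  # Assumed raw file holding a list of `Data` objects.

    def process(self) -> None:
        data_list = torch.load(self.raw_paths[0])
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        self.extend(data_list)  # Write all graphs into the database.


dataset = MyOnDiskDataset(root='/tmp/my_dataset')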

property processed_file_names: str

The name of the files in the self.processed_dir folder that must be present in order to skip processing.

property db: Database

Returns the underlying Database.

close() → None[source]

Closes the connection to the underlying database.

serialize(data: BaseData) → Any[source]

Serializes the Data or HeteroData object into the expected DB schema.

deserialize(data: Any) → BaseData[source]

Deserializes the DB entry into a Data or HeteroData object.
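
When a non-default schema is given, these two hooks translate between Data objects and the typed database columns. A sketch, assuming graphs with 16-dimensional node features (the class name, feature size, and stored attributes are illustrative):

import torch

from torch_geometric.data import Data, OnDiskDataset


class TensorSchemaDataset(OnDiskDataset):
    def __init__(self, root: str):
        # Store `x` and `edge_index` as typed tensor columns instead of
        # pickled blobs; `-1` denotes a variable-length dimension:
        schema = {
            'x': dict(dtype=torch.float, size=(-1, 16)),
            'edge_index': dict(dtype=torch.long, size=(2, -1)),
        }
        super().__init__(root, schema=schema)

    def serialize(self, data: Data) -> dict:
        # Flatten the `Data` object into the schema columns:
        return dict(x=data.x, edge_index=data.edge_index)

    def deserialize(self, row: dict) -> Data:
        # Re-assemble a `Data` object from a stored row:
        return Data.from_dict(row)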

append(data: BaseData) → None[source]

Appends the data object to the dataset.

extend(data_list: Sequence[BaseData], batch_size: Optional[int] = None) → None[source]

Extends the dataset by a list of data objects.
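
A usage sketch, assuming dataset is an OnDiskDataset instance as above and graphs carry 16-dimensional node features:

import torch

from torch_geometric.data import Data

data = Data(x=torch.randn(4, 16),
            edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))

dataset.append(data)          # Insert a single graph.
dataset.extend([data, data])  # Insert multiple graphs at once.
dataset.close()               # Close the database connection when done.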

get(idx: int) → BaseData[source]

Gets the data object at index idx.

multi_get(indices: Union[Iterable[int], Tensor, slice, range], batch_size: Optional[int] = None) → List[BaseData][source]

Gets a list of data objects from the specified indices.
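
For example, assuming the dataset holds at least 100 graphs; batching lookups via multi_get() avoids issuing one database query per get() call:

data = dataset.get(0)                  # Fetch a single data object.
batch = dataset.multi_get(range(100))  # Fetch the first 100 data objects.
batch = dataset.multi_get([2, 3, 5], batch_size=2)  # Optionally chunk the query.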

len() → int[source]

Returns the number of data objects stored in the dataset.

download() → None

Downloads the dataset to the self.raw_dir folder.

get_summary() → Any

Collects summary statistics for the dataset.

property has_download: bool

Checks whether the dataset defines a download() method.

property has_process: bool

Checks whether the dataset defines a process() method.

index_select(idx: Union[slice, Tensor, ndarray, Sequence]) → Dataset

Creates a subset of the dataset from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool.
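
For example (indexing a dataset with a slice, list, or tensor dispatches to index_select()):

import torch

subset = dataset.index_select(torch.tensor([2, 3, 5]))  # Explicit call.
subset = dataset[2:5]  # Equivalent slicing syntax.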

property num_classes: int

Returns the number of classes in the dataset.

property num_edge_features: int

Returns the number of features per edge in the dataset.

property num_features: int

Returns the number of features per node in the dataset. Alias for num_node_features.

property num_node_features: int

Returns the number of features per node in the dataset.

print_summary() → None

Prints summary statistics of the dataset to the console.

process() → None

Processes the dataset to the self.processed_dir folder.

property processed_paths: List[str]

The absolute filepaths that must be present in order to skip processing.

property raw_file_names: Union[str, List[str], Tuple[str, ...]]

The name of the files in the self.raw_dir folder that must be present in order to skip downloading.

property raw_paths: List[str]

The absolute filepaths that must be present in order to skip downloading.

shuffle(return_perm: bool = False) → Union[Dataset, Tuple[Dataset, Tensor]]

Randomly shuffles the examples in the dataset.

Parameters:

  • return_perm (bool, optional) – If set to True, will also return the random permutation used to shuffle the dataset. (default: False)
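
For example:

shuffled = dataset.shuffle()
shuffled, perm = dataset.shuffle(return_perm=True)  # `perm` holds the permutation.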

to_datapipe() → Any

Converts the dataset into a torch.utils.data.DataPipe.

The returned instance can then be used with built-in DataPipes for batching graphs as follows:

from torch_geometric.datasets import QM9

dp = QM9(root='./data/QM9/').to_datapipe()
dp = dp.batch_graphs(batch_size=2, drop_last=True)

for batch in dp:
    pass

See the PyTorch tutorial for further background on DataPipes.