torch_geometric.data.OnDiskDataset
- class OnDiskDataset(root: str, transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, backend: str = 'sqlite', schema: Union[Any, Dict[str, Any], Tuple[Any], List[Any]] = object, log: bool = True)[source]
Bases: Dataset

Dataset base class for creating large graph datasets which do not easily fit into CPU memory at once, by leveraging a Database backend for on-disk storage and access of data objects.

- Parameters:
  - root (str) – Root directory where the dataset should be saved.
  - transform (callable, optional) – A function/transform that takes in a Data or HeteroData object and returns a transformed version. The data object will be transformed before every access. (default: None)
  - pre_filter (callable, optional) – A function that takes in a Data or HeteroData object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)
  - backend (str) – The Database backend to use (one of "sqlite" or "rocksdb"). (default: "sqlite")
  - schema (Any or Tuple[Any] or Dict[str, Any], optional) – The schema of the input data. Can take int, float, str, object, or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema will improve efficiency, since by default the database will use Python pickling for serializing and deserializing. If set to anything other than object, implementations of OnDiskDataset need to override the serialize() and deserialize() methods. (default: object)
  - log (bool, optional) – Whether to print any console output while downloading and processing the dataset. (default: True)
- property processed_file_names: str
The name of the files in the self.processed_dir folder that must be present in order to skip processing.
- serialize(data: BaseData) → Any [source]
Serializes the Data or HeteroData object into the expected DB schema.
- deserialize(data: Any) → BaseData [source]
Deserializes the DB entry into a Data or HeteroData object.
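When a non-object schema is used, serialize() and deserialize() must map a data object to and from that schema. A hypothetical sketch of such a pair for a flat dict schema, using plain dicts instead of Data objects (the field names "x" and "num_nodes" are illustrative, not part of the API):

```python
# Hypothetical serialize()/deserialize() pair for a dict schema;
# a real implementation would receive a Data/HeteroData object.
def serialize(data: dict) -> dict:
    # Flatten the data object onto the DB schema's fields.
    return {"x": data["x"], "num_nodes": data["num_nodes"]}

def deserialize(row: dict) -> dict:
    # Rebuild the data object from a DB row.
    return {"x": row["x"], "num_nodes": row["num_nodes"]}

obj = {"x": [0.1, 0.2, 0.3], "num_nodes": 3}
assert deserialize(serialize(obj)) == obj  # lossless round trip
```

The two methods must be inverses of each other so that stored entries round-trip losslessly.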
- extend(data_list: Sequence[BaseData], batch_size: Optional[int] = None) → None [source]
Extends the dataset by a list of data objects.
- multi_get(indices: Union[Iterable[int], Tensor, slice, range], batch_size: Optional[int] = None) → List[BaseData] [source]
Gets a list of data objects from the specified indices.
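The batch_size argument bounds how many entries are written or fetched per database transaction. The chunking idea can be sketched in pure Python (chunked is an illustrative helper, not part of the API):

```python
from typing import Iterator, List, Sequence

def chunked(seq: Sequence, batch_size: int) -> Iterator[List]:
    # Yield successive slices of at most batch_size items each.
    for start in range(0, len(seq), batch_size):
        yield list(seq[start:start + batch_size])

batches = list(chunked(list(range(7)), batch_size=3))
assert batches == [[0, 1, 2], [3, 4, 5], [6]]
```

Smaller batches cap peak memory during bulk inserts or reads; larger batches reduce per-transaction overhead.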
- property has_download: bool
Checks whether the dataset defines a download() method.
- index_select(idx: Union[slice, Tensor, ndarray, Sequence]) → Dataset
Creates a subset of the dataset from the specified indices idx. idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool.
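All accepted idx forms reduce to a list of integer positions. A rough sketch of that normalization in pure Python, without tensors (normalize_idx is an illustrative helper, not part of the API):

```python
from typing import List, Sequence, Union

def normalize_idx(idx: Union[slice, Sequence], length: int) -> List[int]:
    # Slices expand against the dataset length.
    if isinstance(idx, slice):
        return list(range(*idx.indices(length)))
    # A sequence of bools acts as a mask (checked before int,
    # since bool is a subclass of int in Python).
    if idx and all(isinstance(i, bool) for i in idx):
        return [pos for pos, keep in enumerate(idx) if keep]
    # Otherwise, treat entries as integer positions.
    return [int(i) for i in idx]

assert normalize_idx(slice(2, 5), length=10) == [2, 3, 4]
assert normalize_idx([True, False, True], length=3) == [0, 2]
assert normalize_idx((1, 3), length=5) == [1, 3]
```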
- property num_features: int
Returns the number of features per node in the dataset. Alias for num_node_features.
- property processed_paths: List[str]
The absolute filepaths that must be present in order to skip processing.
- property raw_file_names: Union[str, List[str], Tuple[str, ...]]
The name of the files in the self.raw_dir folder that must be present in order to skip downloading.
- property raw_paths: List[str]
The absolute filepaths that must be present in order to skip downloading.
- shuffle(return_perm: bool = False) → Union[Dataset, Tuple[Dataset, Tensor]]
Randomly shuffles the examples in the dataset.
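With return_perm=True, shuffle() also returns the permutation that was applied, letting you map shuffled positions back to original indices. The idea in plain Python with lists standing in for the dataset (seeded here only to keep the sketch deterministic):

```python
import random

data = ["a", "b", "c", "d"]
perm = list(range(len(data)))
random.Random(0).shuffle(perm)       # the permutation that would be returned
shuffled = [data[i] for i in perm]   # the shuffled dataset

# perm maps each shuffled position back to its original index.
assert sorted(shuffled) == sorted(data)
assert all(shuffled[pos] == data[orig] for pos, orig in enumerate(perm))
```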
- to_datapipe() → Any
Converts the dataset into a torch.utils.data.DataPipe. The returned instance can then be used with PyG's built-in DataPipes for batching graphs as follows:

```python
from torch_geometric.datasets import QM9

dp = QM9(root='./data/QM9/').to_datapipe()
dp = dp.batch_graphs(batch_size=2, drop_last=True)

for batch in dp:
    pass
```

See the PyTorch tutorial for further background on DataPipes.