torch_geometric.data.InMemoryDataset
- class InMemoryDataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, log: bool = True, force_reload: bool = False)[source]
-
Dataset base class for creating graph datasets which easily fit into CPU memory. See the accompanying tutorial in the PyG documentation for more details.
- Parameters:
  - root (str, optional) – Root directory where the dataset should be saved. (default: None)
  - transform (callable, optional) – A function/transform that takes in a Data or HeteroData object and returns a transformed version. The data object will be transformed before every access. (default: None)
  - pre_transform (callable, optional) – A function/transform that takes in a Data or HeteroData object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)
  - pre_filter (callable, optional) – A function that takes in a Data or HeteroData object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)
  - log (bool, optional) – Whether to print any console output while downloading and processing the dataset. (default: True)
  - force_reload (bool, optional) – Whether to re-process the dataset. (default: False)
- property raw_file_names: Union[str, List[str], Tuple]
The name of the files in the self.raw_dir folder that must be present in order to skip downloading.
- property processed_file_names: Union[str, List[str], Tuple]
The name of the files in the self.processed_dir folder that must be present in order to skip processing.
- classmethod save(data_list: List[BaseData], path: str)[source]
Saves a list of data objects to the file path path.
- load(path: str, data_cls: Type[BaseData] = Data)[source]
Loads the dataset from the file path path.
- static collate(data_list: List[BaseData]) Tuple[BaseData, Optional[Dict[str, Tensor]]] [source]
Collates a list of Data or HeteroData objects to the internal storage format of InMemoryDataset.
- copy(idx: Optional[Union[slice, Tensor, ndarray, Sequence]] = None) InMemoryDataset [source]
Performs a deep-copy of the dataset. If idx is not given, will clone the full dataset. Otherwise, will only clone a subset of the dataset from indices idx. Indices can be slices, lists, tuples, and a torch.Tensor or np.ndarray of type long or bool.
- to_on_disk_dataset(root: Optional[str] = None, backend: str = 'sqlite', log: bool = True) OnDiskDataset [source]
Converts the InMemoryDataset to an OnDiskDataset variant. Useful for distributed training and hardware instances with a limited amount of shared memory.
  - root (str, optional): Root directory where the dataset should be saved. If set to None, will save the dataset in root/on_disk. Note that it is important to specify root to account for different dataset splits. (default: None)
  - backend (str): The Database backend to use. (default: "sqlite")
  - log (bool, optional): Whether to print any console output while processing the dataset. (default: True)
- to(device: Union[int, str]) InMemoryDataset [source]
Performs device conversion of the whole dataset.
- cpu(*args: str) InMemoryDataset [source]
Moves the dataset to CPU memory.
- cuda(device: Optional[Union[int, str]] = None) InMemoryDataset [source]
Moves the dataset to CUDA memory.
- download()
Downloads the dataset to the self.raw_dir folder.
- get_summary()
Collects summary statistics for the dataset.
- property has_download: bool
Checks whether the dataset defines a download() method.
- index_select(idx: Union[slice, Tensor, ndarray, Sequence]) Dataset
Creates a subset of the dataset from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool.
- property num_features: int
Returns the number of features per node in the dataset. Alias for num_node_features.
- print_summary()
Prints summary statistics of the dataset to the console.
- process()
Processes the dataset to the self.processed_dir folder.
- property processed_paths: List[str]
The absolute filepaths that must be present in order to skip processing.
- property raw_paths: List[str]
The absolute filepaths that must be present in order to skip downloading.
- shuffle(return_perm: bool = False) Union[Dataset, Tuple[Dataset, Tensor]]
Randomly shuffles the examples in the dataset.
- to_datapipe()
Converts the dataset into a torch.utils.data.DataPipe. The returned instance can then be used with PyG’s built-in DataPipes for batching graphs as follows:

from torch_geometric.datasets import QM9

dp = QM9(root='./data/QM9/').to_datapipe()
dp = dp.batch_graphs(batch_size=2, drop_last=True)

for batch in dp:
    pass

See the PyTorch tutorial for further background on DataPipes.