torch_geometric.data.InMemoryDataset

class InMemoryDataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, log: bool = True)[source]

Bases: Dataset

Dataset base class for creating graph datasets that easily fit into CPU memory. Inherits from torch_geometric.data.Dataset. See the accompanying tutorial on creating your own datasets for an example.

Parameters
  • root (str, optional) – Root directory where the dataset should be saved. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)

  • pre_filter (callable, optional) – A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

  • log (bool, optional) – Whether to print any console output while downloading and processing the dataset. (default: True)

property raw_file_names: Union[str, List[str], Tuple]

The names of the files in the self.raw_dir folder that must be present in order to skip downloading.

property processed_file_names: Union[str, List[str], Tuple]

The names of the files in the self.processed_dir folder that must be present in order to skip processing.

property num_classes: int

Returns the number of classes in the dataset.

len() → int[source]

Returns the number of graphs stored in the dataset.

get(idx: int) → Data[source]

Gets the data object at index idx.

static collate(data_list: List[Data]) → Tuple[Data, Optional[Dict[str, Tensor]]][source]

Collates a Python list of torch_geometric.data.Data objects to the internal storage format of InMemoryDataset.

copy(idx: Optional[Union[slice, Tensor, ndarray, Sequence]] = None) → InMemoryDataset[source]

Performs a deep-copy of the dataset. If idx is not given, will clone the full dataset. Otherwise, will only clone a subset of the dataset from indices idx. Indices can be slices, lists, tuples, or a torch.Tensor or np.ndarray of type long or bool.

download()

Downloads the dataset to the self.raw_dir folder.

get_summary()

Collects summary statistics for the dataset.

property has_download: bool

Checks whether the dataset defines a download() method.

property has_process: bool

Checks whether the dataset defines a process() method.

index_select(idx: Union[slice, Tensor, ndarray, Sequence]) → Dataset

Creates a subset of the dataset from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool.

property num_edge_features: int

Returns the number of features per edge in the dataset.

property num_features: int

Returns the number of features per node in the dataset. Alias for num_node_features.

property num_node_features: int

Returns the number of features per node in the dataset.

print_summary()

Prints summary statistics of the dataset to the console.

process()

Processes the dataset to the self.processed_dir folder.

property processed_paths: List[str]

The absolute filepaths that must be present in order to skip processing.

property raw_paths: List[str]

The absolute filepaths that must be present in order to skip downloading.

shuffle(return_perm: bool = False) → Union[Dataset, Tuple[Dataset, Tensor]]

Randomly shuffles the examples in the dataset.

Parameters

return_perm (bool, optional) – If set to True, will also return the random permutation used to shuffle the dataset. (default: False)

to_datapipe()

Converts the dataset into a torch.utils.data.DataPipe.

The returned instance can then be used with PyG's built-in DataPipes for batching graphs as follows:

from torch_geometric.datasets import QM9

dp = QM9(root='./data/QM9/').to_datapipe()
dp = dp.batch_graphs(batch_size=2, drop_last=True)

for batch in dp:
    pass

See the PyTorch tutorial for further background on DataPipes.