Creating Graph Datasets

Although PyG already contains a lot of useful datasets, you may wish to create your own dataset with self-recorded or non-publicly available data.

Implementing datasets yourself is straightforward, and you may want to take a look at the source code to find out how the various datasets are implemented. Nevertheless, we give a brief introduction to what is needed to set up your own dataset.

We provide two abstract classes for datasets: torch_geometric.data.Dataset and torch_geometric.data.InMemoryDataset. torch_geometric.data.InMemoryDataset inherits from torch_geometric.data.Dataset and should be used if the whole dataset fits into CPU memory.

Following the torchvision convention, each dataset gets passed a root folder which indicates where the dataset should be stored. We split the root folder into two folders: the raw_dir, where the dataset gets downloaded to, and the processed_dir, where the processed dataset is saved.
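The folder split can be pictured with a few lines of plain Python. This is only a sketch of the layout, not PyG's code: split_root and resolve are hypothetical helpers introduced for illustration, and the subfolder names raw and processed are assumed to match what the dataset classes use.

```python
import os.path as osp


def split_root(root):
    # Hypothetical helper: derive the two sub-folders from the root folder.
    raw_dir = osp.join(root, 'raw')
    processed_dir = osp.join(root, 'processed')
    return raw_dir, processed_dir


def resolve(folder, file_names):
    # Hypothetical helper: turn bare file names into paths inside a folder.
    return [osp.join(folder, name) for name in file_names]


raw_dir, processed_dir = split_root(osp.join('data', 'MyOwnDataset'))
processed_paths = resolve(processed_dir, ['data.pt'])
print(raw_dir)
print(processed_paths[0])
```

File names declared by a dataset are interpreted relative to these two folders, which is why the examples on this page only ever list bare file names.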

In addition, each dataset can be passed a transform, a pre_transform and a pre_filter function, which are None by default. The transform function dynamically transforms the data object on every access (so it is best used for data augmentation). The pre_transform function applies the transformation before saving the data objects to disk (so it is best used for heavy precomputation which only needs to be done once). The pre_filter function can filter out data objects before saving. A typical use case is restricting the dataset to data objects of a specific class.
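To make the timing of the three hooks concrete, here is a minimal stand-in written in plain Python. It is a sketch under stated assumptions rather than PyG code: TinyDataset, keep_label_one, add_degree and shift_features are hypothetical names, and dictionaries stand in for data objects.

```python
class TinyDataset:
    """Hypothetical stand-in for a dataset; dictionaries replace data objects."""

    def __init__(self, raw_items, transform=None, pre_transform=None,
                 pre_filter=None):
        self.transform = transform
        items = list(raw_items)
        # Processing time: filter and pre-transform exactly once.
        if pre_filter is not None:
            items = [item for item in items if pre_filter(item)]
        if pre_transform is not None:
            items = [pre_transform(item) for item in items]
        self._stored = items  # What a real dataset would write to disk.

    def __len__(self):
        return len(self._stored)

    def __getitem__(self, idx):
        item = self._stored[idx]
        # Access time: the transform runs again on every lookup.
        return item if self.transform is None else self.transform(item)


calls = {'pre_transform': 0, 'transform': 0}


def keep_label_one(item):
    return item['y'] == 1


def add_degree(item):
    calls['pre_transform'] += 1
    return {**item, 'deg': len(item['edges'])}


def shift_features(item):
    calls['transform'] += 1
    return {**item, 'x': item['x'] + 0.5}


raw = [
    {'x': 1.0, 'y': 1, 'edges': [(0, 1)]},
    {'x': 2.0, 'y': 0, 'edges': []},
    {'x': 3.0, 'y': 1, 'edges': [(0, 1), (1, 2)]},
]
ds = TinyDataset(raw, transform=shift_features, pre_transform=add_degree,
                 pre_filter=keep_label_one)
first = ds[0]
again = ds[0]
print(len(ds), calls)
```

The pre_transform counter stops at the number of kept items, while the transform counter grows with every access, which is the reason expensive work belongs in pre_transform.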

Creating “In Memory Datasets”

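In order to create a torch_geometric.data.InMemoryDataset, you need to implement four fundamental methods: raw_file_names(), which lists the files in raw_dir that must be found in order to skip the download; processed_file_names(), which lists the files in processed_dir that must be found in order to skip processing; download(), which downloads the raw data into raw_dir; and process(), which processes the raw data and saves it into processed_dir.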

You can find helpful methods to download and extract data in torch_geometric.data.

The real magic happens in the body of process(). Here, we need to read and create a list of Data objects and save it into the processed_dir. Because saving a huge Python list is quite slow, we collate the list into one huge Data object via torch_geometric.data.InMemoryDataset.collate() before saving. The collated data object concatenates all examples into one big data object and, in addition, returns a slices dictionary to reconstruct single examples from this object. Finally, we need to load these two objects in the constructor into the properties self.data and self.slices.
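The idea behind collating with a slices dictionary can be illustrated with plain Python lists. This is a conceptual sketch only, not the implementation of collate(): the collate and get_example functions below are hypothetical, and lists stand in for tensors.

```python
from itertools import accumulate


def collate(examples):
    # Concatenate every attribute across all examples into one flat list.
    keys = examples[0].keys()
    data = {key: [v for ex in examples for v in ex[key]] for key in keys}
    # Record cumulative boundaries so each example can be cut out again.
    slices = {
        key: [0] + list(accumulate(len(ex[key]) for ex in examples))
        for key in keys
    }
    return data, slices


def get_example(data, slices, idx):
    # Slice every attribute between the boundaries of example `idx`.
    return {
        key: data[key][slices[key][idx]:slices[key][idx + 1]]
        for key in data
    }


examples = [
    {'x': [0.1, 0.2, 0.3], 'edge_index': [(0, 1), (1, 2)]},
    {'x': [0.4, 0.5], 'edge_index': [(0, 1)]},
]
data, slices = collate(examples)
restored = get_example(data, slices, 1)
print(slices)
print(restored)
```

Storing one large object plus a small table of boundaries avoids serializing many small objects, and slicing between two neighbouring boundaries recovers any single example.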

Note

From PyG >= 2.4, the functionalities of torch.save() and torch_geometric.data.InMemoryDataset.collate() are unified and implemented behind torch_geometric.data.InMemoryDataset.save(). Additionally, self.data and self.slices are implicitly loaded via torch_geometric.data.InMemoryDataset.load().

Let’s see this process in a simplified example:

import torch
from torch_geometric.data import InMemoryDataset, download_url


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.load(self.processed_paths[0])
        # For PyG<2.4:
        # self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        self.save(data_list, self.processed_paths[0])
        # For PyG<2.4:
        # torch.save(self.collate(data_list), self.processed_paths[0])

Creating “Larger” Datasets

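For creating datasets which do not fit into memory, the torch_geometric.data.Dataset can be used, which closely follows the concepts of the torchvision datasets. In addition to the methods above, it expects two further methods to be implemented: len(), which returns the number of examples in the dataset, and get(), which implements the logic to load a single graph.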

Internally, torch_geometric.data.Dataset.__getitem__() gets data objects from torch_geometric.data.Dataset.get() and optionally transforms them according to transform.
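The division of labour between __getitem__(), get() and len() can be sketched in plain Python. The classes below are hypothetical stand-ins, not PyG's implementation, and the sketch only covers integer indexing.

```python
class LazyDatasetSketch:
    """Hypothetical base class mirroring the len()/get() contract."""

    def __init__(self, transform=None):
        self.transform = transform

    def len(self):
        raise NotImplementedError

    def get(self, idx):
        raise NotImplementedError

    def __len__(self):
        return self.len()

    def __getitem__(self, idx):
        # Fetch the stored example first, then optionally transform it.
        data = self.get(idx)
        return data if self.transform is None else self.transform(data)


class Squares(LazyDatasetSketch):
    """Toy subclass: example `idx` is computed on demand instead of loaded."""

    def len(self):
        return 4

    def get(self, idx):
        return idx * idx


plain = Squares()
shifted = Squares(transform=lambda value: value + 1)
values = [shifted[i] for i in range(len(shifted))]
print(plain[3], values)
```

Subclasses only describe how many examples exist and how to fetch one of them; applying the transform stays in the base class, so it behaves the same for every dataset.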

Let’s see this process in a simplified example:

import os.path as osp

import torch
from torch_geometric.data import Data, Dataset, download_url


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_0.pt', 'data_1.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        idx = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

Here, each graph data object gets saved individually in process(), and is manually loaded in get().

Frequently Asked Questions

  1. How can I skip the execution of download() and/or process()?

    You can skip downloading and/or processing by just not overriding the download() and process() methods:

    class MyOwnDataset(Dataset):
        def __init__(self, transform=None, pre_transform=None):
            super().__init__(None, transform, pre_transform)
    
  2. Do I really need to use these dataset interfaces?

    No! Just as in regular PyTorch, you do not have to use datasets, e.g., when you want to create synthetic data on the fly without saving them explicitly to disk. In this case, simply create a regular Python list holding torch_geometric.data.Data objects and pass it to torch_geometric.loader.DataLoader:

    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader
    
    data_list = [Data(...), ..., Data(...)]
    loader = DataLoader(data_list, batch_size=32)
    

Exercises

Consider the following InMemoryDataset constructed from a list of Data objects:

class MyDataset(InMemoryDataset):
    def __init__(self, root, data_list, transform=None):
        self.data_list = data_list
        super().__init__(root, transform)
        self.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return 'data.pt'

    def process(self):
        self.save(self.data_list, self.processed_paths[0])

  1. What is the output of self.processed_paths[0]?

  2. What does save() do?