Source code for torch_geometric.data.dataset

import collections
import os.path as osp
import warnings
import re

import torch.utils.data

from .makedirs import makedirs


def to_list(x):
    r"""Wraps :obj:`x` in a list unless it already is a non-string iterable.

    Strings are iterable but are treated as a single item, so a lone file
    name such as ``'data.pt'`` becomes ``['data.pt']``.

    Args:
        x: A single object or an iterable of objects.

    Returns:
        :obj:`x` itself (not a copy) if it is a non-string iterable,
        otherwise the one-element list ``[x]``.
    """
    # The ``collections.Iterable`` alias was removed in Python 3.10; the
    # abstract base class has to be taken from ``collections.abc``.
    from collections.abc import Iterable

    if not isinstance(x, Iterable) or isinstance(x, str):
        x = [x]
    return x


def files_exist(files):
    r"""Returns :obj:`True` if every path in :obj:`files` exists.

    Args:
        files (iterable): Paths to check with :func:`os.path.exists`.

    Returns:
        bool: :obj:`True` when all paths exist (also for an empty iterable),
        :obj:`False` as soon as one path is missing.
    """
    # A generator lets ``all`` stop at the first missing path instead of
    # touching the file system for every remaining entry.
    return all(osp.exists(f) for f in files)


def __repr__(obj):
    r"""Returns a shortened textual representation of :obj:`obj`.

    :obj:`None` maps to the string ``'None'``. For any other object the
    result of ``obj.__repr__()`` is used, where an angle-bracketed section
    is cut down to the text before its first whitespace, e.g.
    ``'<object object at 0x7f...>'`` becomes ``'<object>'``. Representations
    without such a section are returned unchanged.
    """
    if obj is not None:
        # Group 1 lazily captures ``<`` up to the first whitespace, group 2
        # the last ``>`` on the line; everything in between is dropped.
        bracket_tail = re.compile('(<.*?)\\s.*(>)')
        return bracket_tail.sub(r'\1\2', obj.__repr__())
    return 'None'


class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_dataset.html>`__ for the accompanying tutorial.

    Subclasses have to implement :meth:`raw_file_names`,
    :meth:`processed_file_names`, :meth:`download`, :meth:`process`,
    :meth:`__len__` and :meth:`get`; all of them raise
    :class:`NotImplementedError` here.

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
    """

    @property
    def raw_file_names(self):
        r"""The name of the files to find in the :obj:`self.raw_dir` folder in
        order to skip the download."""
        raise NotImplementedError

    @property
    def processed_file_names(self):
        r"""The name of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def __len__(self):
        r"""The number of examples in the dataset."""
        raise NotImplementedError

    def get(self, idx):
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None):
        r"""Sets up the directory layout, warns about stale pre-processing
        settings and then triggers download and processing.

        Raw files are expected in ``<root>/raw`` and processed files in
        ``<root>/processed``. A :class:`UserWarning` is emitted when the
        stored representation of :obj:`pre_transform` or :obj:`pre_filter`
        in the processed folder differs from the one passed in.
        """
        super(Dataset, self).__init__()

        self.root = osp.expanduser(osp.normpath(root))
        self.raw_dir = osp.join(self.root, 'raw')
        self.processed_dir = osp.join(self.root, 'processed')
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        # The two marker files below are written by `_process` and hold the
        # shortened repr of the arguments used for the existing processed
        # data. A mismatch only warns; the data is not re-processed.
        path = osp.join(self.processed_dir, 'pre_transform.pt')
        if osp.exists(path) and torch.load(path) != __repr__(pre_transform):
            warnings.warn(
                'The `pre_transform` argument differs from the one used in '
                'the pre-processed version of this dataset. If you really '
                'want to make use of another pre-processing technique, make '
                'sure to delete `{}` first.'.format(self.processed_dir))

        path = osp.join(self.processed_dir, 'pre_filter.pt')
        if osp.exists(path) and torch.load(path) != __repr__(pre_filter):
            warnings.warn(
                'The `pre_filter` argument differs from the one used in the '
                'pre-processed version of this dataset. If you really want to '
                'make use of another pre-filtering technique, make sure to '
                'delete `{}` first.'.format(self.processed_dir))

        self._download()
        self._process()

    @property
    def num_node_features(self):
        r"""Returns the number of features per node in the dataset."""
        return self[0].num_node_features

    @property
    def num_features(self):
        r"""Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features

    @property
    def num_edge_features(self):
        r"""Returns the number of features per edge in the dataset."""
        return self[0].num_edge_features

    @property
    def raw_paths(self):
        r"""The filepaths to find in order to skip the download."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self):
        r"""The filepaths to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        r"""Calls :meth:`download` unless all :obj:`self.raw_paths` already
        exist; the :obj:`self.raw_dir` folder is created beforehand."""
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        r"""Calls :meth:`process` unless all :obj:`self.processed_paths`
        already exist, then stores the shortened repr of
        :obj:`self.pre_transform` and :obj:`self.pre_filter` next to the
        processed files."""
        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print('Processing...')

        makedirs(self.processed_dir)
        self.process()

        # Remember which pre-processing arguments produced the files so that
        # `__init__` can warn when they change on a later run.
        path = osp.join(self.processed_dir, 'pre_transform.pt')
        torch.save(__repr__(self.pre_transform), path)
        path = osp.join(self.processed_dir, 'pre_filter.pt')
        torch.save(__repr__(self.pre_filter), path)

        print('Done!')

    def __getitem__(self, idx):  # pragma: no cover
        r"""Gets the data object at index :obj:`idx` and transforms it (in case
        a :obj:`self.transform` is given)."""
        data = self.get(idx)
        data = data if self.transform is None else self.transform(data)
        return data

    def __repr__(self):  # pragma: no cover
        r"""Returns ``ClassName(length)`` using :meth:`__len__`."""
        return '{}({})'.format(self.__class__.__name__, len(self))