Source code for torch_geometric.data.dataset

import collections
import collections.abc
import copy
import os.path as osp
import re
import warnings

import torch.utils.data

from .makedirs import makedirs


def to_list(x):
    """Wrap ``x`` in a list unless it is already a non-string iterable.

    Strings are iterable but are treated as single values here, so a lone
    file name such as ``'data.pt'`` becomes ``['data.pt']`` rather than being
    left as a sequence of characters.

    Args:
        x: A single value or an iterable of values.

    Returns:
        ``x`` itself when it is a non-string iterable, otherwise ``[x]``.
    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives in
    # `collections.abc`.
    if isinstance(x, str) or not isinstance(x, collections.abc.Iterable):
        return [x]
    return x


def files_exist(files):
    """Report whether every path in ``files`` exists on disk.

    Every path is checked. An empty ``files`` yields :obj:`True`.
    """
    missing = [path for path in files if not osp.exists(path)]
    return not missing


def __repr__(obj):
    """Return a representation of ``obj`` with ``<...>`` details stripped.

    :obj:`None` maps to the string ``'None'``. For any other object, the
    result of ``obj.__repr__()`` is taken and, where it contains an
    angle-bracketed section with whitespace in it, everything from the first
    whitespace up to the last ``>`` is removed (e.g. ``<object object at
    0x...>`` becomes ``<object>``).
    """
    if obj is not None:
        text = obj.__repr__()
        pattern = re.compile('(<.*?)\\s.*(>)')
        return pattern.sub(r'\1\2', text)
    return 'None'


class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_dataset.html>`__ for the accompanying tutorial.

    Args:
        root (string, optional): Root directory where the dataset should be
            saved. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    @property
    def raw_file_names(self):
        r"""The name of the files to find in the :obj:`self.raw_dir` folder in
        order to skip the download."""
        raise NotImplementedError

    @property
    def processed_file_names(self):
        r"""The name of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def len(self):
        r"""Returns the number of stored examples; must be implemented by
        subclasses."""
        raise NotImplementedError

    def get(self, idx):
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(self, root=None, transform=None, pre_transform=None,
                 pre_filter=None):
        super().__init__()

        if isinstance(root, str):
            root = osp.expanduser(osp.normpath(root))

        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        # `None` means "all examples"; `index_select` stores a subset here.
        self.__indices__ = None

        # Downloading / processing is only triggered when the concrete class
        # itself defines the corresponding hook (inherited definitions are not
        # looked up).
        if 'download' in self.__class__.__dict__:
            self._download()

        if 'process' in self.__class__.__dict__:
            self._process()

    def indices(self):
        r"""Returns the indices into the underlying storage that make up this
        (possibly sub-selected) dataset."""
        if self.__indices__ is not None:
            return self.__indices__
        return range(len(self))

    @property
    def raw_dir(self):
        r"""The ``raw`` sub-folder of :obj:`self.root`."""
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        r"""The ``processed`` sub-folder of :obj:`self.root`."""
        return osp.join(self.root, 'processed')

    @property
    def num_node_features(self):
        r"""Returns the number of features per node in the dataset."""
        return self[0].num_node_features

    @property
    def num_features(self):
        r"""Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features

    @property
    def num_edge_features(self):
        r"""Returns the number of features per edge in the dataset."""
        return self[0].num_edge_features

    @property
    def raw_paths(self):
        r"""The filepaths to find in order to skip the download."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self):
        r"""The filepaths to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        # Skip the download when all raw files are already present.
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        # Warn when the stored `pre_transform` / `pre_filter` representation
        # differs from the one passed in now. Processing is NOT re-run
        # because of such a mismatch.
        f = osp.join(self.processed_dir, 'pre_transform.pt')
        if osp.exists(f) and torch.load(f) != __repr__(self.pre_transform):
            warnings.warn(
                'The `pre_transform` argument differs from the one used in '
                'the pre-processed version of this dataset. If you really '
                'want to make use of another pre-processing technique, make '
                'sure to delete `{}` first.'.format(self.processed_dir))

        f = osp.join(self.processed_dir, 'pre_filter.pt')
        if osp.exists(f) and torch.load(f) != __repr__(self.pre_filter):
            warnings.warn(
                'The `pre_filter` argument differs from the one used in the '
                'pre-processed version of this dataset. If you really want to '
                'make use of another pre-filtering technique, make sure to '
                'delete `{}` first.'.format(self.processed_dir))

        # Skip the processing when all processed files are already present.
        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print('Processing...')

        makedirs(self.processed_dir)
        self.process()

        # Record which pre-processing was used so later runs can compare.
        path = osp.join(self.processed_dir, 'pre_transform.pt')
        torch.save(__repr__(self.pre_transform), path)
        path = osp.join(self.processed_dir, 'pre_filter.pt')
        torch.save(__repr__(self.pre_filter), path)

        print('Done!')

    def __len__(self):
        r"""The number of examples in the dataset."""
        if self.__indices__ is not None:
            return len(self.__indices__)
        return self.len()

    def __getitem__(self, idx):
        r"""Gets the data object at index :obj:`idx` and transforms it (in case
        a :obj:`self.transform` is given).
        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
        tuple, a LongTensor or a BoolTensor, will return a subset of the
        dataset at the specified indices."""
        if isinstance(idx, int):
            data = self.get(self.indices()[idx])
            data = data if self.transform is None else self.transform(data)
            return data
        return self.index_select(idx)

    def index_select(self, idx):
        r"""Returns a shallow copy of the dataset restricted to :obj:`idx`.

        Args:
            idx (slice, list, tuple or Tensor): The examples to keep. Tensors
                must be of dtype :obj:`torch.long` (positions) or
                :obj:`torch.bool` / :obj:`torch.uint8` (mask).

        Raises:
            IndexError: If :obj:`idx` is of an unsupported type, or a tensor
                of an unsupported dtype.
        """
        indices = self.indices()

        if isinstance(idx, slice):
            indices = indices[idx]
        elif torch.is_tensor(idx):
            if idx.dtype == torch.long:
                # Promote a scalar tensor so that `tolist()` yields a list.
                if len(idx.shape) == 0:
                    idx = idx.unsqueeze(0)
                return self.index_select(idx.tolist())
            if idx.dtype == torch.bool or idx.dtype == torch.uint8:
                return self.index_select(idx.nonzero().flatten().tolist())
            # Any other dtype used to fall through and silently select the
            # whole dataset; reject it explicitly instead.
            raise IndexError(
                'Only long or bool tensors are valid tensor indices '
                '(got tensor of dtype {}).'.format(idx.dtype))
        elif isinstance(idx, (list, tuple)):
            indices = [indices[i] for i in idx]
        else:
            raise IndexError(
                'Only integers, slices (`:`), list, tuples, and long or bool '
                'tensors are valid indices (got {}).'.format(
                    type(idx).__name__))

        # Shallow copy: the subset shares everything with the original and
        # only differs in its index list.
        dataset = copy.copy(self)
        dataset.__indices__ = indices
        return dataset

    def shuffle(self, return_perm=False):
        r"""Randomly shuffles the examples in the dataset.

        Args:
            return_perm (bool, optional): If set to :obj:`True`, will
                additionally return the random permutation used to shuffle the
                dataset. (default: :obj:`False`)
        """
        perm = torch.randperm(len(self))
        dataset = self.index_select(perm)
        return (dataset, perm) if return_perm is True else dataset

    def __repr__(self):  # pragma: no cover
        return f'{self.__class__.__name__}({len(self)})'