Source code for torch_geometric.data.dataset

import copy
import os.path as osp
import re
import sys
import warnings
from collections.abc import Sequence
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch.utils.data
from torch import Tensor

from torch_geometric.data.data import BaseData
from torch_geometric.io import fs

IndexType = Union[slice, Tensor, np.ndarray, Sequence]
MISSING = '???'


class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
    create_dataset.html>`__ for the accompanying tutorial.

    Args:
        root (str, optional): Root directory where the dataset should be
            saved. (optional: :obj:`None`)
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before being saved to disk.
            (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            boolean value, indicating whether the data object should be
            included in the final dataset. (default: :obj:`None`)
        log (bool, optional): Whether to print any console output while
            downloading and processing the dataset. (default: :obj:`True`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
        r"""The name of the files in the :obj:`self.raw_dir` folder that must
        be present in order to skip downloading.
        """
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
        r"""The name of the files in the :obj:`self.processed_dir` folder
        that must be present in order to skip processing.
        """
        raise NotImplementedError
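    # A subclass would typically return plain file name lists from the two
    # properties above. A minimal sketch (the file names are hypothetical):
    #
    #     @property
    #     def raw_file_names(self) -> List[str]:
    #         return ['graphs.json']
    #
    #     @property
    #     def processed_file_names(self) -> List[str]:
    #         return ['data_0.pt', 'data_1.pt']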
    def download(self) -> None:
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError
    def process(self) -> None:
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError
    def len(self) -> int:
        r"""Returns the number of data objects stored in the dataset."""
        raise NotImplementedError
    def get(self, idx: int) -> BaseData:
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError
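    # Together, `download`, `process`, `len` and `get` form the contract a
    # concrete dataset fulfills. A minimal sketch, assuming one `.pt` file
    # per example (class and file names are illustrative only):
    #
    #     class MyDataset(Dataset):
    #         def download(self) -> None:
    #             ...  # fetch raw files into `self.raw_dir`
    #
    #         def process(self) -> None:
    #             ...  # write one file per example into `self.processed_dir`
    #
    #         def len(self) -> int:
    #             return len(self.processed_file_names)
    #
    #         def get(self, idx: int) -> BaseData:
    #             return torch.load(self.processed_paths[idx])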
    def __init__(
        self,
        root: Optional[str] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        log: bool = True,
        force_reload: bool = False,
    ) -> None:
        super().__init__()

        if isinstance(root, str):
            root = osp.expanduser(fs.normpath(root))

        self.root = root or MISSING
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        self.log = log
        self._indices: Optional[Sequence] = None
        self.force_reload = force_reload

        if self.has_download:
            self._download()

        if self.has_process:
            self._process()

    def indices(self) -> Sequence:
        return range(self.len()) if self._indices is None else self._indices

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, 'processed')

    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the dataset."""
        data = self[0]
        # Do not fill cache for `InMemoryDataset`:
        if hasattr(self, '_data_list') and self._data_list is not None:
            self._data_list[0] = None
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_node_features'):
            return data.num_node_features
        raise AttributeError(f"'{data.__class__.__name__}' object has no "
                             f"attribute 'num_node_features'")

    @property
    def num_features(self) -> int:
        r"""Returns the number of features per node in the dataset.
        Alias for :py:attr:`~num_node_features`.
        """
        return self.num_node_features

    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the dataset."""
        data = self[0]
        # Do not fill cache for `InMemoryDataset`:
        if hasattr(self, '_data_list') and self._data_list is not None:
            self._data_list[0] = None
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_edge_features'):
            return data.num_edge_features
        raise AttributeError(f"'{data.__class__.__name__}' object has no "
                             f"attribute 'num_edge_features'")

    def _infer_num_classes(self, y: Optional[Tensor]) -> int:
        if y is None:
            return 0
        elif y.numel() == y.size(0) and not torch.is_floating_point(y):
            return int(y.max()) + 1
        elif y.numel() == y.size(0) and torch.is_floating_point(y):
            num_classes = torch.unique(y).numel()
            if num_classes > 2:
                warnings.warn("Found floating-point labels while calling "
                              "`dataset.num_classes`. Returning the number "
                              "of unique elements. Please make sure that "
                              "this is expected before proceeding.")
            return num_classes
        else:
            return y.size(-1)

    @property
    def num_classes(self) -> int:
        r"""Returns the number of classes in the dataset."""
        # We iterate over the dataset and collect all labels to determine the
        # maximum number of classes. Importantly, in rare cases, `__getitem__`
        # may produce a tuple of data objects (e.g., when used in combination
        # with `RandomLinkSplit`), so we take care of this case here as well:
        data_list = _get_flattened_data_list([data for data in self])
        if 'y' in data_list[0] and isinstance(data_list[0].y, Tensor):
            y = torch.cat([data.y for data in data_list if 'y' in data],
                          dim=0)
        else:
            y = torch.as_tensor([data.y for data in data_list if 'y' in data])

        # Do not fill cache for `InMemoryDataset`:
        if hasattr(self, '_data_list') and self._data_list is not None:
            self._data_list = self.len() * [None]

        return self._infer_num_classes(y)

    @property
    def raw_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        downloading.
        """
        files = self.raw_file_names
        # Prevent a common source of error in which `file_names` are not
        # defined as a property.
        if isinstance(files, Callable):
            files = files()
        return [osp.join(self.raw_dir, f) for f in to_list(files)]

    @property
    def processed_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        processing.
        """
        files = self.processed_file_names
        # Prevent a common source of error in which `file_names` are not
        # defined as a property.
        if isinstance(files, Callable):
            files = files()
        return [osp.join(self.processed_dir, f) for f in to_list(files)]

    @property
    def has_download(self) -> bool:
        r"""Checks whether the dataset defines a :meth:`download` method."""
        return overrides_method(self.__class__, 'download')

    def _download(self):
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        fs.makedirs(self.raw_dir, exist_ok=True)
        self.download()

    @property
    def has_process(self) -> bool:
        r"""Checks whether the dataset defines a :meth:`process` method."""
        return overrides_method(self.__class__, 'process')

    def _process(self):
        f = osp.join(self.processed_dir, 'pre_transform.pt')
        if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
            warnings.warn(
                "The `pre_transform` argument differs from the one used in "
                "the pre-processed version of this dataset. If you want to "
                "make use of another pre-processing technique, pass "
                "`force_reload=True` explicitly to reload the dataset.")

        f = osp.join(self.processed_dir, 'pre_filter.pt')
        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
            warnings.warn(
                "The `pre_filter` argument differs from the one used in "
                "the pre-processed version of this dataset. If you want to "
                "make use of another pre-filtering technique, pass "
                "`force_reload=True` explicitly to reload the dataset.")

        if not self.force_reload and files_exist(self.processed_paths):
            return

        if self.log and 'pytest' not in sys.modules:
            print('Processing...', file=sys.stderr)

        fs.makedirs(self.processed_dir, exist_ok=True)
        self.process()

        path = osp.join(self.processed_dir, 'pre_transform.pt')
        fs.torch_save(_repr(self.pre_transform), path)
        path = osp.join(self.processed_dir, 'pre_filter.pt')
        fs.torch_save(_repr(self.pre_filter), path)

        if self.log and 'pytest' not in sys.modules:
            print('Done!', file=sys.stderr)

    def __len__(self) -> int:
        r"""The number of examples in the dataset."""
        return len(self.indices())

    def __getitem__(
        self,
        idx: Union[int, np.integer, IndexType],
    ) -> Union['Dataset', BaseData]:
        r"""In case :obj:`idx` is of type integer, will return the data
        object at index :obj:`idx` (and transforms it in case
        :obj:`transform` is present).
        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list,
        a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long
        or bool, will return a subset of the dataset at the specified
        indices.
        """
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)
                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):

            data = self.get(self.indices()[idx])
            data = data if self.transform is None else self.transform(data)
            return data

        else:
            return self.index_select(idx)

    def __iter__(self) -> Iterator[BaseData]:
        for i in range(len(self)):
            yield self[i]
    def index_select(self, idx: IndexType) -> 'Dataset':
        r"""Creates a subset of the dataset from specified indices
        :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
        long or bool.
        """
        indices = self.indices()

        if isinstance(idx, slice):
            start, stop, step = idx.start, idx.stop, idx.step
            # Allow floating-point slicing, e.g., dataset[:0.9]
            if isinstance(start, float):
                start = round(start * len(self))
            if isinstance(stop, float):
                stop = round(stop * len(self))
            idx = slice(start, stop, step)

            indices = indices[idx]

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            idx = idx.flatten().nonzero(as_tuple=False)
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, np.ndarray) and idx.dtype == bool:
            idx = idx.flatten().nonzero()[0]
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            indices = [indices[i] for i in idx]

        else:
            raise IndexError(
                f"Only slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')")

        dataset = copy.copy(self)
        dataset._indices = indices
        return dataset
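    # Illustrative usage of `index_select`; the mask below is hypothetical
    # and must contain one entry per example:
    #
    #     mask = torch.tensor([True, False, True])
    #     subset = dataset.index_select(mask)  # keeps examples 0 and 2
    #     train = dataset[:0.8]                # float bounds scale by `len(self)`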
    def shuffle(
        self,
        return_perm: bool = False,
    ) -> Union['Dataset', Tuple['Dataset', Tensor]]:
        r"""Randomly shuffles the examples in the dataset.

        Args:
            return_perm (bool, optional): If set to :obj:`True`, will also
                return the random permutation used to shuffle the dataset.
                (default: :obj:`False`)
        """
        perm = torch.randperm(len(self))
        dataset = self.index_select(perm)
        return (dataset, perm) if return_perm is True else dataset
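    # Illustrative usage: `shuffle` returns a shuffled view of the dataset;
    # with `return_perm=True`, the permutation can be reused, e.g., to align
    # labels stored outside the dataset:
    #
    #     dataset, perm = dataset.shuffle(return_perm=True)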
    def __repr__(self) -> str:
        arg_repr = str(len(self)) if len(self) > 1 else ''
        return f'{self.__class__.__name__}({arg_repr})'
    def get_summary(self) -> Any:
        r"""Collects summary statistics for the dataset."""
        from torch_geometric.data.summary import Summary
        return Summary.from_dataset(self)
    def print_summary(self) -> None:
        r"""Prints summary statistics of the dataset to the console."""
        print(str(self.get_summary()))
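    # Illustrative usage (the printed format is defined by
    # `torch_geometric.data.summary.Summary`):
    #
    #     dataset.print_summary()  # same as `print(dataset.get_summary())`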
    def to_datapipe(self) -> Any:
        r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
        The returned instance can then be used with :pyg:`PyG's` built-in
        :class:`DataPipes` for batching graphs as follows:

        .. code-block:: python

            from torch_geometric.datasets import QM9

            dp = QM9(root='./data/QM9/').to_datapipe()
            dp = dp.batch_graphs(batch_size=2, drop_last=True)

            for batch in dp:
                pass

        See the `PyTorch tutorial
        <https://pytorch.org/data/main/tutorial.html>`_ for further
        background on DataPipes.
        """
        from torch_geometric.data.datapipes import DatasetAdapter
        return DatasetAdapter(self)
def overrides_method(cls, method_name: str) -> bool:
    from torch_geometric.data import InMemoryDataset

    if method_name in cls.__dict__:
        return True

    out = False
    for base in cls.__bases__:
        if base != Dataset and base != InMemoryDataset:
            out |= overrides_method(base, method_name)
    return out


def to_list(value: Any) -> Sequence:
    if isinstance(value, Sequence) and not isinstance(value, str):
        return value
    else:
        return [value]


def files_exist(files: List[str]) -> bool:
    # NOTE: We return `False` in case `files` is empty, leading to a
    # re-processing of files on every instantiation.
    return len(files) != 0 and all([fs.exists(f) for f in files])


def _repr(obj: Any) -> str:
    if obj is None:
        return 'None'
    return re.sub('(<.*?)\\s.*(>)', r'\1\2', str(obj))


def _get_flattened_data_list(data_list: Iterable[Any]) -> List[BaseData]:
    outs: List[BaseData] = []
    for data in data_list:
        if isinstance(data, BaseData):
            outs.append(data)
        elif isinstance(data, (tuple, list)):
            outs.extend(_get_flattened_data_list(data))
        elif isinstance(data, dict):
            outs.extend(_get_flattened_data_list(data.values()))
    return outs
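# Illustrative behavior of the helpers above (a sketch, not a doctest):
#
#     to_list('file.pt')         # -> ['file.pt']
#     to_list(('a.pt', 'b.pt'))  # -> ('a.pt', 'b.pt'), sequences pass through
#     files_exist([])            # -> False, which forces re-processing
#     _repr(None)                # -> 'None'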