# Source code for torch_geometric.data.in_memory_dataset

import copy
import warnings
from abc import ABC
from collections.abc import Mapping, Sequence
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

import torch
from torch import Tensor

from torch_geometric.data import Batch, Data
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData
from torch_geometric.data.dataset import Dataset, IndexType
from torch_geometric.data.separate import separate


class InMemoryDataset(Dataset, ABC):
    r"""Dataset base class for creating graph datasets which easily fit
    into CPU memory.
    Inherits from :class:`torch_geometric.data.Dataset`.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
    create_dataset.html#creating-in-memory-datasets>`__ for the accompanying
    tutorial.

    All graphs are kept collated into a single storage object
    (:obj:`_data`) together with slice bookkeeping (:obj:`slices`);
    individual graphs are separated out on demand in :meth:`get`.

    Args:
        root (str, optional): Root directory where the dataset should be
            saved. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before being saved to disk.
            (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            boolean value, indicating whether the data object should be
            included in the final dataset. (default: :obj:`None`)
        log (bool, optional): Whether to print any console output while
            downloading and processing the dataset. (default: :obj:`True`)
    """
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        raise NotImplementedError

    def __init__(
        self,
        root: Optional[str] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        log: bool = True,
    ):
        super().__init__(root, transform, pre_transform, pre_filter, log)
        # Collated storage of every graph plus the slice bookkeeping returned
        # by `collate()` (`None` when only a single graph is stored). Both are
        # filled in later, e.g. by `load()` or `copy()`.
        self._data = None
        self.slices = None
        # Lazily-populated cache of separated graphs, indexed like `get()`.
        self._data_list: Optional[List[BaseData]] = None

    @property
    def num_classes(self) -> int:
        # Read labels straight from the collated storage only when no
        # transform is set (presumably because a transform may alter `y`).
        if self.transform is None:
            return self._infer_num_classes(self._data.y)
        return super().num_classes

    def len(self) -> int:
        r"""Returns the number of graphs in the full storage.

        Without :obj:`slices` the storage holds exactly one graph. Otherwise
        the count is taken from the first slice vector found in the (possibly
        nested) :obj:`slices` structure, which has one entry more than there
        are graphs. Returns :obj:`0` if :obj:`slices` has no entries.
        """
        if self.slices is None:
            return 1
        for _, value in nested_iter(self.slices):
            return len(value) - 1
        return 0

    def get(self, idx: int) -> BaseData:
        r"""Returns the graph at position :obj:`idx` of the full storage.

        A graph is separated out of the collated storage on first access and
        a copy of it is cached in :obj:`_data_list`; later accesses return a
        shallow copy of that cached object.
        """
        # TODO (matthias) Avoid unnecessary copy here.
        if self.len() == 1:
            return copy.copy(self._data)

        if not hasattr(self, '_data_list') or self._data_list is None:
            self._data_list = self.len() * [None]
        elif self._data_list[idx] is not None:
            return copy.copy(self._data_list[idx])

        data = separate(
            cls=self._data.__class__,
            batch=self._data,
            idx=idx,
            slice_dict=self.slices,
            decrement=False,
        )

        # Cache a (shallow) copy so that reassigning attributes on the
        # returned object does not modify the cached entry.
        self._data_list[idx] = copy.copy(data)

        return data

    @classmethod
    def save(cls, data_list: List[BaseData], path: str):
        r"""Saves a list of data objects to the file path :obj:`path`.

        The list is collated first; the file stores the tuple
        :obj:`(data.to_dict(), slices)`.
        """
        data, slices = cls.collate(data_list)
        torch.save((data.to_dict(), slices), path)

    def load(self, path: str, data_cls: Type[BaseData] = Data):
        r"""Loads the dataset from the file path :obj:`path`.

        Args:
            path (str): File holding a :obj:`(data, slices)` tuple, such as
                one written by :meth:`save`.
            data_cls (type, optional): Class used to rebuild the data object
                when it is stored as a dictionary.
                (default: :class:`~torch_geometric.data.Data`)
        """
        data, self.slices = torch.load(path)
        # `save()` writes `data.to_dict()`, so rebuild the object from the
        # dictionary; any non-dict payload (e.g. a data object that was
        # pickled directly) is used as-is.
        if isinstance(data, dict):
            data = data_cls.from_dict(data)
        # Assign via the `data` setter, which also invalidates the cache.
        self.data = data

    @staticmethod
    def collate(
        data_list: List[BaseData],
    ) -> Tuple[BaseData, Optional[Dict[str, Tensor]]]:
        r"""Collates a Python list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects to the internal
        storage format of :class:`~torch_geometric.data.InMemoryDataset`.

        A single-element list is stored as-is, with :obj:`None` slices.
        """
        if len(data_list) == 1:
            return data_list[0], None

        data, slices, _ = collate(
            data_list[0].__class__,
            data_list=data_list,
            increment=False,
            add_batch=False,
        )

        return data, slices

    def copy(self, idx: Optional[IndexType] = None) -> 'InMemoryDataset':
        r"""Performs a deep-copy of the dataset. If :obj:`idx` is not given,
        will clone the full dataset. Otherwise, will only clone a subset of the
        dataset from indices :obj:`idx`.
        Indices can be slices, lists, tuples, and a :obj:`torch.Tensor` or
        :obj:`np.ndarray` of type long or bool.
        """
        if idx is None:
            data_list = [self.get(i) for i in self.indices()]
        else:
            data_list = [self.get(i) for i in self.index_select(idx).indices()]

        dataset = copy.copy(self)
        # The copy owns freshly collated storage, so it must neither keep the
        # subset view nor share the per-graph cache with the original.
        dataset._indices = None
        dataset._data_list = None
        dataset.data, dataset.slices = self.collate(data_list)
        return dataset

    @property
    def data(self) -> Any:
        r"""Direct (discouraged) access to the collated internal storage.

        Always emits a warning. If separated graphs are cached, the cache is
        cleared, because edits to the returned storage would otherwise not be
        reflected by :meth:`get`.
        """
        msg1 = ("It is not recommended to directly access the internal "
                "storage format `data` of an 'InMemoryDataset'.")
        msg2 = ("The given 'InMemoryDataset' only references a subset of "
                "examples of the full dataset, but 'data' will contain "
                "information of the full dataset.")
        msg3 = ("The data of the dataset is already cached, so any "
                "modifications to `data` will not be reflected when accessing "
                "its elements. Clearing the cache now by removing all "
                "elements in `dataset._data_list`.")
        msg4 = ("If you are absolutely certain what you are doing, access the "
                "internal storage via `InMemoryDataset._data` instead to "
                "suppress this warning. Alternatively, you can access stacked "
                "individual attributes of every graph via "
                "`dataset.{attr_name}`.")

        msg = msg1
        if self._indices is not None:
            msg += f' {msg2}'
        if self._data_list is not None:
            msg += f' {msg3}'
            self._data_list = None
        msg += f' {msg4}'

        # `stacklevel=2` attributes the warning to the code that accessed
        # `dataset.data` rather than to this property.
        warnings.warn(msg, stacklevel=2)

        return self._data

    @data.setter
    def data(self, value: Any):
        # Replacing the storage invalidates every cached separated graph.
        self._data = value
        self._data_list = None

    def __getattr__(self, key: str) -> Any:
        r"""Gives access to stacked attributes of the stored graphs.

        Only handles keys present in a :class:`~torch_geometric.data.Data`
        storage; anything else raises :class:`AttributeError`.
        """
        # Read `_data` through `__dict__` so a missing `_data` cannot
        # re-enter `__getattr__` recursively.
        data = self.__dict__.get('_data')
        if isinstance(data, Data) and key in data:
            # The storage is collated with `increment=False`, so attributes
            # whose `__inc__` is 0 presumably match the batched form already
            # and can be returned directly when no subset is selected.
            if self._indices is None and data.__inc__(key, data[key]) == 0:
                return data[key]
            else:
                data_list = [self.get(i) for i in self.indices()]
                return Batch.from_data_list(data_list)[key]

        raise AttributeError(f"'{self.__class__.__name__}' object has no "
                             f"attribute '{key}'")
def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
    r"""Lazily walks a (possibly nested) container, yielding pairs.

    * Mappings are descended into recursively; their own keys are dropped,
      so only pairs produced by the nested values come out.
    * Sequences yield :obj:`(position, element)` pairs without recursing
      into the elements.
    * Any other object is a leaf and is yielded as :obj:`(None, node)`.
    """
    if isinstance(node, Mapping):
        for _key, child in node.items():
            yield from nested_iter(child)
    elif isinstance(node, Sequence):
        yield from enumerate(node)
    else:
        yield None, node