# Source: torch_geometric/data/in_memory_dataset.py (documentation export)

from typing import Optional, Callable, List, Union, Tuple, Dict

import copy
from itertools import repeat, product

import torch
from torch import Tensor

from torch_geometric.data.data import Data
from torch_geometric.data.dataset import Dataset, IndexType


class InMemoryDataset(Dataset):
    r"""Dataset base class for creating graph datasets which fit completely
    into CPU memory.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_dataset.html#creating-in-memory-datasets>`__ for the accompanying
    tutorial.

    Args:
        root (string, optional): Root directory where the dataset should be
            saved. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
    """
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files to find in the :obj:`self.raw_dir` folder in
        order to skip the download."""
        # Abstract: concrete datasets must override this.
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        # Abstract: concrete datasets must override this.
        raise NotImplementedError
[docs] def download(self): r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" raise NotImplementedError
[docs] def process(self): r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" raise NotImplementedError
def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None): super().__init__(root, transform, pre_transform, pre_filter) self.data = None self.slices = None self._data_list: Optional[List[Data]] = None @property def num_classes(self) -> int: r"""The number of classes in the dataset.""" y = self.data.y if y is None: return 0 elif y.numel() == y.size(0) and not torch.is_floating_point(y): return int(self.data.y.max()) + 1 elif y.numel() == y.size(0) and torch.is_floating_point(y): return torch.unique(y).numel() else: return self.data.y.size(-1) def len(self) -> int: for item in self.slices.values(): return len(item) - 1 return 0
    def get(self, idx: int) -> Data:
        r"""Reconstructs the :obj:`idx`-th example from the collated storage,
        serving from an in-memory cache when possible."""
        # Serve from / lazily initialize the per-example cache.
        if hasattr(self, '_data_list'):
            if self._data_list is None:
                self._data_list = self.len() * [None]
            else:
                data = self._data_list[idx]
                if data is not None:
                    # Shallow copy so callers cannot mutate the cached entry's
                    # attribute dict in place.
                    return copy.copy(data)

        data = self.data.__class__()
        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        # Slice every attribute of the big collated object back down to the
        # portion belonging to example `idx`, using the stored offsets.
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            start, end = slices[idx].item(), slices[idx + 1].item()
            if torch.is_tensor(item):
                # Build a full-dimensional slice and narrow only along the
                # concatenation dimension reported by `__cat_dim__`.
                s = list(repeat(slice(None), item.dim()))
                cat_dim = self.data.__cat_dim__(key, item)
                if cat_dim is None:
                    cat_dim = 0
                s[cat_dim] = slice(start, end)
            elif start + 1 == end:
                # Single non-tensor element: pick it by index.
                # NOTE(review): indexes with `slices[start]` (a tensor scalar)
                # rather than `start`; for non-tensor attributes each example
                # occupies one slot, so the two presumably coincide — confirm.
                s = slices[start]
            else:
                s = slice(start, end)
            data[key] = item[s]

        # Populate the cache so subsequent accesses skip the slicing above.
        if hasattr(self, '_data_list'):
            self._data_list[idx] = copy.copy(data)

        return data
[docs] @staticmethod def collate(data_list: List[Data]) -> Tuple[Data, Dict[str, Tensor]]: r"""Collates a python list of data objects to the internal storage format of :class:`torch_geometric.data.InMemoryDataset`.""" keys = data_list[0].keys data = data_list[0].__class__() for key in keys: data[key] = [] slices = {key: [0] for key in keys} for item, key in product(data_list, keys): data[key].append(item[key]) if isinstance(item[key], Tensor) and item[key].dim() > 0: cat_dim = item.__cat_dim__(key, item[key]) cat_dim = 0 if cat_dim is None else cat_dim s = slices[key][-1] + item[key].size(cat_dim) else: s = slices[key][-1] + 1 slices[key].append(s) if hasattr(data_list[0], '__num_nodes__'): data.__num_nodes__ = [] for item in data_list: data.__num_nodes__.append(item.num_nodes) for key in keys: item = data_list[0][key] if isinstance(item, Tensor) and len(data_list) > 1: if item.dim() > 0: cat_dim = data.__cat_dim__(key, item) cat_dim = 0 if cat_dim is None else cat_dim data[key] = torch.cat(data[key], dim=cat_dim) else: data[key] = torch.stack(data[key]) elif isinstance(item, Tensor): # Don't duplicate attributes... data[key] = data[key][0] elif isinstance(item, int) or isinstance(item, float): data[key] = torch.tensor(data[key]) slices[key] = torch.tensor(slices[key], dtype=torch.long) return data, slices
def copy(self, idx: Optional[IndexType] = None): if idx is None: data_list = [self.get(i) for i in range(len(self))] else: data_list = [self.get(i) for i in self.index_select(idx).indices()] dataset = copy.copy(self) dataset._indices = None dataset._data_list = data_list dataset.data, dataset.slices = self.collate(data_list) return dataset