Source code for torch_geometric.datasets.gnn_benchmark_dataset

from typing import Optional, Callable, List

import os
import os.path as osp
import pickle
import logging

import torch
from torch_geometric.data import (InMemoryDataset, download_url, extract_zip,
                                  Data)
from torch_geometric.utils import remove_self_loops


[docs]class GNNBenchmarkDataset(InMemoryDataset): r"""A variety of artificially and semi-artificially generated graph datasets from the `"Benchmarking Graph Neural Networks" <https://arxiv.org/abs/2003.00982>`_ paper. .. note:: The ZINC dataset is provided via :class:`torch_geometric.datasets.ZINC`. Args: root (string): Root directory where the dataset should be saved. name (string): The name of the dataset (one of :obj:`"PATTERN"`, :obj:`"CLUSTER"`, :obj:`"MNIST"`, :obj:`"CIFAR10"`, :obj:`"TSP"`, :obj:`"CSL"`) split (string, optional): If :obj:`"train"`, loads the training dataset. If :obj:`"val"`, loads the validation dataset. If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) """ names = ['PATTERN', 'CLUSTER', 'MNIST', 'CIFAR10', 'TSP', 'CSL'] url = 'https://pytorch-geometric.com/datasets/benchmarking-gnns' csl_url = 'https://www.dropbox.com/s/rnbkp5ubgk82ocu/CSL.zip?dl=1' def __init__(self, root: str, name: str, split: str = "train", transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None): self.name = name assert self.name in self.names if self.name == 'CSL' and split != 'train': split = 'train' logging.warning( ('Dataset `CSL` does not provide a standardized splitting. ' 'Instead, it is recommended to perform 5-fold cross ' 'validation with stratifed sampling.')) super().__init__(root, transform, pre_transform, pre_filter) if split == 'train': path = self.processed_paths[0] elif split == 'val': path = self.processed_paths[1] elif split == 'test': path = self.processed_paths[2] else: raise ValueError(f"Split '{split}' found, but expected either " f"'train', 'val', or 'test'") self.data, self.slices = torch.load(path) @property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> List[str]: name = self.name if name == 'CSL': return [ 'graphs_Kary_Deterministic_Graphs.pkl', 'y_Kary_Deterministic_Graphs.pt' ] else: return [f'{name}_train.pt', f'{name}_val.pt', f'{name}_test.pt'] @property def processed_file_names(self) -> List[str]: if self.name == 'CSL': return ['data.pt'] else: return ['train_data.pt', 'val_data.pt', 'test_data.pt'] def download(self): url = self.csl_url url = f'{self.url}/{self.name}.zip' if self.name != 'CSL' else url path = download_url(url, self.raw_dir) extract_zip(path, self.raw_dir) os.unlink(path) def process(self): if self.name == 'CSL': data_list = self.process_CSL() torch.save(self.collate(data_list), self.processed_paths[0]) else: for i in range(3): self.data, self.slices = torch.load(self.raw_paths[i]) data_list = [self.get(i) for i in range(len(self))] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] torch.save(self.collate(data_list), self.processed_paths[i]) def process_CSL(self) -> List[Data]: path = osp.join(self.raw_dir, 'graphs_Kary_Deterministic_Graphs.pkl') with open(path, 'rb') as f: adjs = pickle.load(f) path = osp.join(self.raw_dir, 'y_Kary_Deterministic_Graphs.pt') ys = torch.load(path).tolist() data_list = [] for adj, y in zip(adjs, ys): row, col = torch.from_numpy(adj.row), torch.from_numpy(adj.col) edge_index = torch.stack([row, col], dim=0).to(torch.long) edge_index, _ = remove_self_loops(edge_index) data = Data(edge_index=edge_index, y=y, num_nodes=adj.shape[0]) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) return data_list def __repr__(self) -> str: return f'{self.name}({len(self)})'