import os
import os.path as osp
import pickle
import logging
import torch
from torch_geometric.data import (InMemoryDataset, download_url, extract_zip,
Data)
from torch_geometric.utils import remove_self_loops
class GNNBenchmarkDataset(InMemoryDataset):
    r"""A variety of artificially and semi-artificially generated graph
    datasets from the `"Benchmarking Graph Neural Networks"
    <https://arxiv.org/abs/2003.00982>`_ paper.

    .. note::
        The ZINC dataset is provided via
        :class:`torch_geometric.datasets.ZINC`.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (one of :obj:`"PATTERN"`,
            :obj:`"CLUSTER"`, :obj:`"MNIST"`, :obj:`"CIFAR10"`,
            :obj:`"TSP"`, :obj:`"CSL"`)
        split (string, optional): If :obj:`"train"`, loads the training
            dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset.
            (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        use_node_attr (bool, optional): Unused; accepted for interface
            compatibility only. (default: :obj:`False`)
        use_edge_attr (bool, optional): Unused; accepted for interface
            compatibility only. (default: :obj:`False`)
        cleaned (bool, optional): Unused; accepted for interface
            compatibility only. (default: :obj:`False`)
    """

    names = ['PATTERN', 'CLUSTER', 'MNIST', 'CIFAR10', 'TSP', 'CSL']

    url = 'https://pytorch-geometric.com/datasets/benchmarking-gnns'
    csl_url = 'https://www.dropbox.com/s/rnbkp5ubgk82ocu/CSL.zip?dl=1'

    def __init__(self, root, name, split='train', transform=None,
                 pre_transform=None, pre_filter=None, use_node_attr=False,
                 use_edge_attr=False, cleaned=False):
        # NOTE(review): `use_node_attr`, `use_edge_attr` and `cleaned` are
        # never read anywhere in this class; they are kept only so existing
        # call sites do not break.
        self.name = name
        assert self.name in self.names

        # CSL ships as a single set of graphs without a canonical split, so
        # fall back to 'train' and tell the user how to evaluate properly.
        if self.name == 'CSL' and split != 'train':
            split = 'train'
            logging.warning(
                ('Dataset `CSL` does not provide a standardized splitting. '
                 'Instead, it is recommended to perform 5-fold cross '
                 'validation with stratified sampling.'))

        super().__init__(root, transform, pre_transform, pre_filter)

        # Map the requested split onto the processed file written by
        # `process()`. For CSL there is only one processed file (index 0),
        # which is consistent with the forced split='train' above.
        if split == 'train':
            path = self.processed_paths[0]
        elif split == 'val':
            path = self.processed_paths[1]
        elif split == 'test':
            path = self.processed_paths[2]
        else:
            # Only the three splits handled above are valid; the previous
            # message wrongly advertised an unsupported 'trainval' option.
            raise ValueError((f'Split {split} found, but expected either '
                              'train, val or test'))

        self.data, self.slices = torch.load(path)

    @property
    def raw_dir(self):
        """Directory holding the downloaded, unprocessed files."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        """Directory holding the collated, processed tensors."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        """Files expected in `raw_dir`; their absence triggers `download()`."""
        name = self.name
        if name == 'CSL':
            return [
                'graphs_Kary_Deterministic_Graphs.pkl',
                'y_Kary_Deterministic_Graphs.pt'
            ]
        else:
            return [f'{name}_train.pt', f'{name}_val.pt', f'{name}_test.pt']

    @property
    def processed_file_names(self):
        """Files expected in `processed_dir`; their absence triggers `process()`."""
        if self.name == 'CSL':
            return ['data.pt']
        else:
            return ['train_data.pt', 'val_data.pt', 'test_data.pt']

    def download(self):
        """Fetch the dataset archive, unpack it into `raw_dir` and delete it."""
        url = self.csl_url if self.name == 'CSL' else f'{self.url}/{self.name}.zip'
        path = download_url(url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        """Convert the raw files into collated tensors on disk.

        For CSL a single processed file is written; for all other datasets the
        three raw split files are filtered/transformed and saved separately.
        """
        if self.name == 'CSL':
            data_list = self.process_CSL()
            torch.save(self.collate(data_list), self.processed_paths[0])
        else:
            for i in range(3):
                # The raw files are already stored in collated form; load
                # them into this dataset and re-expand into Data objects.
                self.data, self.slices = torch.load(self.raw_paths[i])
                data_list = [self.get(j) for j in range(len(self))]

                if self.pre_filter is not None:
                    data_list = [d for d in data_list if self.pre_filter(d)]

                if self.pre_transform is not None:
                    data_list = [self.pre_transform(d) for d in data_list]

                torch.save(self.collate(data_list), self.processed_paths[i])

    def process_CSL(self):
        """Build the list of CSL `Data` graphs from the raw pickle/label files.

        Returns:
            list[Data]: One graph per adjacency matrix, self-loops removed,
            after applying `pre_filter`/`pre_transform` if set.
        """
        path = osp.join(self.raw_dir, 'graphs_Kary_Deterministic_Graphs.pkl')
        with open(path, 'rb') as f:
            # SECURITY: pickle.load executes arbitrary code on malicious
            # input; this is only acceptable because the archive comes from
            # the trusted `csl_url` above.
            adjs = pickle.load(f)

        path = osp.join(self.raw_dir, 'y_Kary_Deterministic_Graphs.pt')
        ys = torch.load(path).tolist()

        data_list = []
        for adj, y in zip(adjs, ys):
            # `adj` is a scipy COO sparse matrix: row/col give the edges.
            row, col = torch.from_numpy(adj.row), torch.from_numpy(adj.col)
            edge_index = torch.stack([row, col], dim=0).to(torch.long)
            edge_index, _ = remove_self_loops(edge_index)
            data = Data(edge_index=edge_index, y=y, num_nodes=adj.shape[0])
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        return data_list

    def __repr__(self):
        return f'{self.name}({len(self)})'