import os
import os.path as osp
import pickle
from typing import Callable, List, Optional
import torch
from tqdm import tqdm
from torch_geometric.data import (
Data,
InMemoryDataset,
download_url,
extract_zip,
)
from torch_geometric.io import fs
class ZINC(InMemoryDataset):
    r"""The ZINC dataset from the `ZINC database
    <https://pubs.acs.org/doi/abs/10.1021/acs.jcim.5b00559>`_ and the
    `"Automatic Chemical Design Using a Data-Driven Continuous Representation
    of Molecules" <https://arxiv.org/abs/1610.02415>`_ paper, containing about
    250,000 molecular graphs with up to 38 heavy atoms.
    The task is to regress the penalized :obj:`logP` (also called constrained
    solubility in some works), given by :obj:`y = logP - SAS - cycles`, where
    :obj:`logP` is the water-octanol partition coefficient, :obj:`SAS` is the
    synthetic accessibility score, and :obj:`cycles` denotes the number of
    cycles with more than six atoms.
    Penalized :obj:`logP` is a score commonly used for training molecular
    generation models, see, *e.g.*, the
    `"Junction Tree Variational Autoencoder for Molecular Graph Generation"
    <https://proceedings.mlr.press/v80/jin18a.html>`_ and
    `"Grammar Variational Autoencoder"
    <https://proceedings.mlr.press/v70/kusner17a.html>`_ papers.

    Args:
        root (str): Root directory where the dataset should be saved.
        subset (bool, optional): If set to :obj:`True`, will only load a
            subset of the dataset (12,000 molecular graphs), following the
            `"Benchmarking Graph Neural Networks"
            <https://arxiv.org/abs/2003.00982>`_ paper. (default: :obj:`False`)
        split (str, optional): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset.
            (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 20 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - ZINC Full
          - 249,456
          - ~23.2
          - ~49.8
          - 1
          - 1
        * - ZINC Subset
          - 12,000
          - ~23.2
          - ~49.8
          - 1
          - 1
    """

    # Archive holding the pickled molecules of all three splits.
    url = 'https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1'
    # Template for the per-split index files that select the subset;
    # formatted with the split name in `download`.
    split_url = ('https://raw.githubusercontent.com/graphdeeplearning/'
                 'benchmarking-gnns/master/data/molecules/{}.index')

    def __init__(
        self,
        root: str,
        subset: bool = False,
        split: str = 'train',
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # `processed_dir` reads `self.subset`, so it is assigned before the
        # base constructor runs (which presumably resolves processed paths).
        self.subset = subset
        assert split in ['train', 'val', 'test'], (
            f"Invalid split '{split}' (expected 'train', 'val' or 'test')")
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        path = osp.join(self.processed_dir, f'{split}.pt')
        self.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the raw files: one pickle and one index file per split."""
        return [
            'train.pickle', 'val.pickle', 'test.pickle', 'train.index',
            'val.index', 'test.index'
        ]

    @property
    def processed_dir(self) -> str:
        """Processed-data directory, kept separate for subset and full."""
        name = 'subset' if self.subset else 'full'
        return osp.join(self.root, name, 'processed')

    @property
    def processed_file_names(self) -> List[str]:
        """Names of the processed files, one per split."""
        return ['train.pt', 'val.pt', 'test.pt']

    def download(self) -> None:
        """Fetch the molecule archive and the split index files.

        Removes any existing raw directory, downloads and extracts the zip
        archive into :obj:`root`, renames the extracted ``molecules`` folder
        to the raw directory, deletes the archive, and finally downloads the
        three ``{split}.index`` files into the raw directory.
        """
        fs.rm(self.raw_dir)
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.rename(osp.join(self.root, 'molecules'), self.raw_dir)
        os.unlink(path)

        for split in ['train', 'val', 'test']:
            download_url(self.split_url.format(split), self.raw_dir)

    def process(self) -> None:
        """Convert the raw pickles into ``{split}.pt`` files.

        For every split, each selected molecule becomes a
        :obj:`torch_geometric.data.Data` object that is passed through
        :obj:`pre_filter` and :obj:`pre_transform` (when given) before the
        resulting list is saved into the processed directory.
        """
        for split in ['train', 'val', 'test']:
            # NOTE(security): unpickling can execute arbitrary code. The file
            # comes from the archive downloaded from `url`; only point `url`
            # at a trusted source.
            with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f:
                mols = pickle.load(f)

            if self.subset:
                # The index file is one comma-separated line of integers.
                # Empty tokens are skipped, so a trailing newline, a trailing
                # comma, or neither all parse without losing the last index.
                with open(osp.join(self.raw_dir, f'{split}.index')) as f:
                    indices = [
                        int(token) for token in f.read().split(',')
                        if token.strip()
                    ]
            else:
                indices = list(range(len(mols)))

            data_list = []
            # Iterating through tqdm advances the bar for every molecule,
            # including those rejected by `pre_filter`, and closes the bar
            # once the loop is exhausted.
            for idx in tqdm(indices, desc=f'Processing {split} dataset'):
                mol = mols[idx]

                # One integer atom type per node, as a column vector.
                x = mol['atom_type'].to(torch.long).view(-1, 1)
                y = mol['logP_SA_cycle_normalized'].to(torch.float)

                # Presumably a dense (num_atoms, num_atoms) bond-type matrix:
                # its non-zero positions become the edges, and the values at
                # those positions become the edge attributes.
                adj = mol['bond_type']
                edge_index = adj.nonzero(as_tuple=False).t().contiguous()
                edge_attr = adj[edge_index[0], edge_index[1]].to(torch.long)

                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                            y=y)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            self.save(data_list, osp.join(self.processed_dir, f'{split}.pt'))