Source code for torch_geometric.datasets.md17

import os
import os.path as osp
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)


[docs]class MD17(InMemoryDataset): r"""A variety of ab-initio molecular dynamics trajectories from the authors of `sGDML <http://quantum-machine.org/gdml>`__. This class provides access to the original MD17 datasets as well as all other datasets released by sGDML since then (15 in total). For every trajectory, the dataset contains the Cartesian positions of atoms (in Angstrom), their atomic numbers, as well as the total energy (in kcal/mol) and forces (kcal/mol/Angstrom) on each atom. The latter two are the regression targets for this collection. .. note:: Data objects contain no edge indices as these are most commonly constructed via the :obj:`torch_geometric.transforms.RadiusGraph` transform, with its cut-off being a hyperparameter. Some of the trajectories were computed at different levels of theory, and for most molecules there exists two versions: a long trajectory on DFT level of theory and a short trajectory on coupled cluster level of theory. Check the table below for detailed information on the molecule, level of theory and number of data points contained in each dataset. Which trajectory is loaded is determined by the :attr:`name` argument. For the coupled cluster trajectories, the dataset comes with pre-defined training and testing splits which are loaded separately via the :attr:`train` argument. When using these datasets, make sure to cite the appropriate publications listed on the `sGDML <http://quantum-machine.org/gdml/>`__ website. +----------------+-----------------+------------------------------+-----------+ | Molecule | Level of Theory | Name | #Examples | +================+=================+==============================+===========+ | Benzene | DFT | :obj:`benzene` | 49,863 | +----------------+-----------------+------------------------------+-----------+ | Benzene | DFT FHI-aims | :obj:`benzene FHI-aims` | 627,983 | +----------------+-----------------+------------------------------+-----------+ | Benzene | CCSD(T) | :obj:`benzene CCSD(T)` | 1,500 | +----------------+-----------------+------------------------------+-----------+ | Uracil | DFT | :obj:`uracil` | 133,770 | +----------------+-----------------+------------------------------+-----------+ | Naphthalene | DFT | :obj:`napthalene` | 326,250 | +----------------+-----------------+------------------------------+-----------+ | Aspirin | DFT | :obj:`aspirin` | 211,762 | +----------------+-----------------+------------------------------+-----------+ | Aspirin | CCSD | :obj:`aspirin CCSD` | 1,500 | +----------------+-----------------+------------------------------+-----------+ | Salicylic acid | DFT | :obj:`salicylic acid` | 320,231 | +----------------+-----------------+------------------------------+-----------+ | Malonaldehyde | DFT | :obj:`malonaldehyde` | 993,237 | +----------------+-----------------+------------------------------+-----------+ | Malonaldehyde | CCSD(T) | :obj:`malonaldehyde CCSD(T)` | 1,500 | +----------------+-----------------+------------------------------+-----------+ | Ethanol | DFT | :obj:`ethanol` | 555,092 | +----------------+-----------------+------------------------------+-----------+ | Ethanol | CCSD(T) | :obj:`ethanol CCSD(T)` | 2,000 | +----------------+-----------------+------------------------------+-----------+ | Toluene | DFT | :obj:`toluene` | 442,790 | +----------------+-----------------+------------------------------+-----------+ | Toluene | CCSD(T) | :obj:`toluene CCSD(T)` | 1,501 | +----------------+-----------------+------------------------------+-----------+ | Paracetamol | DFT | :obj:`paracetamol` | 106,490 | +----------------+-----------------+------------------------------+-----------+ | Azobenzene | DFT | :obj:`azobenzene` | 99,999 | +----------------+-----------------+------------------------------+-----------+ Args: root (string): Root directory where the dataset should be saved. name (string): Keyword of the trajectory that should be loaded. train (bool, optional): Determines whether the train or test split gets loaded for the coupled cluster trajectories. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) """ # noqa: E501 url = 'http://quantum-machine.org/gdml/data/npz' file_names = { 'benzene FHI-aims': 'benzene2018_dft.npz', 'benzene': 'benzene2017_dft.npz', 'benzene CCSD(T)': 'benzene_ccsd_t.zip', 'uracil': 'uracil_dft.npz', 'napthalene': 'naphthalene_dft.npz', 'aspirin': 'aspirin_dft.npz', 'aspirin CCSD': 'aspirin_ccsd.zip', 'salicylic acid': 'salicylic_dft.npz', 'malonaldehyde': 'malonaldehyde_dft.npz', 'malonaldehyde CCSD(T)': 'malonaldehyde_ccsd_t.zip', 'ethanol': 'ethanol_dft.npz', 'ethanol CCSD(T)': 'ethanol_ccsd_t.zip', 'toluene': 'toluene_dft.npz', 'toluene CCSD(T)': 'toluene_ccsd_t.zip', 'paracetamol': 'paracetamol_dft.npz', 'azobenzene': 'azobenzene_dft.npz', } def __init__(self, root: str, name: str, train: Optional[bool] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None): self.name = name assert name in self.file_names super().__init__(root, transform, pre_transform, pre_filter) if len(self.processed_file_names) == 1 and train is not None: raise ValueError( f"'{self.name}' dataset does not provide pre-defined splits " f"but 'train' argument is set to '{train}'") elif len(self.processed_file_names) == 2 and train is None: raise ValueError( f"'{self.name}' dataset does provide pre-defined splits but " f"'train' argument was not specified") idx = 0 if train is None or train else 1 self.data, self.slices = torch.load(self.processed_paths[idx]) def mean(self) -> float: return float(self.data.energy.mean()) @property def raw_dir(self) -> str: name = self.file_names[self.name].split('.')[0] return osp.join(self.root, name, 'raw') @property def processed_dir(self) -> str: name = self.file_names[self.name].split('.')[0] return osp.join(self.root, name, 'processed') @property def raw_file_names(self) -> List[str]: name = self.file_names[self.name] if name[-4:] == '.zip': return [name[:-4] + '-train.npz', name[:-4] + '-test.npz'] else: return [name] @property def processed_file_names(self) -> List[str]: name = self.file_names[self.name] return ['train.pt', 'test.pt'] if name[-4:] == '.zip' else ['data.pt'] def download(self): url = f'{self.url}/{self.file_names[self.name]}' path = download_url(url, self.raw_dir) if url[-4:] == '.zip': extract_zip(path, self.raw_dir) os.unlink(path) def process(self): it = zip(self.raw_paths, self.processed_paths) for raw_path, processed_path in it: raw_data = np.load(raw_path) z = torch.from_numpy(raw_data['z']).to(torch.long) pos = torch.from_numpy(raw_data['R']).to(torch.float) energy = torch.from_numpy(raw_data['E']).to(torch.float) force = torch.from_numpy(raw_data['F']).to(torch.float) data_list = [] for i in range(pos.size(0)): data = Data(z=z, pos=pos[i], energy=energy[i], force=force[i]) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) torch.save(self.collate(data_list), processed_path) def __repr__(self) -> str: return f"{self.__class__.__name__}({len(self)}, name='{self.name}')"