Source code for torch_geometric.datasets.hydro_net

import copy
import os
import os.path as osp
from dataclasses import dataclass
from functools import cached_property
from glob import glob
from pathlib import Path
from typing import Callable, List, MutableSequence, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import ConcatDataset, Subset

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.data.data import BaseData


class HydroNet(InMemoryDataset):
    r"""The HydroNet dataset from the `"HydroNet: Benchmark Tasks for
    Preserving Intermolecular Interactions and Structural Motifs in
    Predictive and Generative Models for Molecular Data"
    <https://arxiv.org/abs/2012.00131>`_ paper, consisting of 5 million water
    clusters held together by hydrogen bonding networks.
    This dataset provides atomic coordinates and the total energy in kcal/mol
    for each cluster.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str, optional): Name of the subset of the full dataset to use:
            :obj:`"small"` uses 500k graphs sampled from the :obj:`"medium"`
            dataset, :obj:`"medium"` uses 2.7m graphs with a maximum size of
            75 nodes.
            Mutually exclusive with the :obj:`clusters` argument.
            (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        num_workers (int): Number of multiprocessing workers to use for
            pre-processing the dataset. (default: :obj:`8`)
        clusters (int or List[int], optional): Select a subset of clusters
            from the full dataset. If set to :obj:`None`, will select all.
            (default: :obj:`None`)
        use_processed (bool): Option to use a pre-processed version of the
            original :obj:`xyz` dataset. (default: :obj:`True`)
    """
    def __init__(
        self,
        root: str,
        name: Optional[str] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
        num_workers: int = 8,
        clusters: Optional[Union[int, List[int]]] = None,
        use_processed: bool = True,
    ) -> None:
        self.name = name
        self.num_workers = num_workers
        self.use_processed = use_processed
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.select_clusters(clusters)

    @property
    def raw_file_names(self) -> List[str]:
        return [f'W{c}_geoms_all.zip' for c in range(3, 31)]

    @property
    def processed_file_names(self) -> List[str]:
        return [f'W{c}_geoms_all.npz' for c in range(3, 31)]

    def download(self) -> None:
        # A token file marks that the pre-processed archives have already
        # been unpacked, so the download can be skipped entirely:
        token_file = Path(osp.join(self.raw_dir, 'use_processed'))
        if self.use_processed and token_file.exists():
            return

        file = RemoteFile.hydronet_splits()
        file.unpack_to(self.raw_dir)

        if self.use_processed:
            file = RemoteFile.processed_dataset()
            file.unpack_to(self.raw_dir)
            token_file.touch()
            return

        # Otherwise, fetch the raw archive and move its per-cluster zip files
        # into the raw directory:
        file = RemoteFile.raw_dataset()
        file.unpack_to(self.raw_dir)
        folder_name, _ = osp.splitext(file.name)
        files = glob(osp.join(self.raw_dir, folder_name, '*.zip'))

        for f in files:
            dst = osp.join(self.raw_dir, osp.basename(f))
            os.rename(f, dst)

        os.rmdir(osp.join(self.raw_dir, folder_name))

    def process(self) -> None:
        if self.use_processed:
            return self._unpack_processed()

        # Process each cluster size in a separate worker:
        from tqdm.contrib.concurrent import process_map
        self._partitions = process_map(
            self._create_partitions,
            self.raw_paths,
            max_workers=self.num_workers,
            position=0,
            leave=True,
        )

    def _unpack_processed(self) -> None:
        files = glob(osp.join(self.raw_dir, '*.npz'))
        for f in files:
            dst = osp.join(self.processed_dir, osp.basename(f))
            os.rename(f, dst)

    def _create_partitions(self, file: str) -> 'Partition':
        name = osp.basename(file)
        name, _ = osp.splitext(name)
        return Partition(self.root, name, self.transform, self.pre_transform)

    def select_clusters(
        self,
        clusters: Optional[Union[int, List[int]]],
    ) -> None:
        if self.name is not None:
            clusters = self._validate_name(clusters)

        self._partitions = [
            self._create_partitions(f) for f in self.raw_paths
        ]

        if clusters is None:
            return

        clusters = [clusters] if isinstance(clusters, int) else clusters

        def is_valid_cluster(x: Union[int, List[int]]) -> bool:
            return isinstance(x, int) and x >= 3 and x <= 30

        if not all([is_valid_cluster(x) for x in clusters]):
            raise ValueError(
                "Selected clusters must be an integer in the range [3, 30]")

        # Partitions are ordered by cluster size, starting at W3:
        self._partitions = [self._partitions[c - 3] for c in clusters]

    def _validate_name(
        self,
        clusters: Optional[Union[int, List[int]]],
    ) -> List[int]:
        if clusters is not None:
            raise ValueError("'name' and 'clusters' are mutually exclusive")

        if self.name not in ['small', 'medium']:
            raise ValueError(f"Invalid subset name '{self.name}'. "
                             f"Must be either 'small' or 'medium'")

        return list(range(3, 26))

    @cached_property
    def _dataset(self) -> Union[ConcatDataset, Subset]:
        dataset: ConcatDataset = ConcatDataset(self._partitions)

        if self.name == "small":
            return self._load_small_split(dataset)

        return dataset

    def _load_small_split(self, dataset: ConcatDataset) -> Subset:
        split_file = osp.join(self.processed_dir, 'split_00_small.npz')
        with np.load(split_file) as split:
            train_idx = split['train_idx']
            val_idx = split['val_idx']
            all_idx = np.concatenate([train_idx, val_idx])
        return Subset(dataset, all_idx)

    def len(self) -> int:
        return len(self._dataset)

    def get(self, idx: int) -> Data:
        return self._dataset[idx]
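
# Raw xyz layout (schematic, placeholder values only): each raw
# ``W{n}_geoms_all`` file is parsed below in fixed-size chunks of
# ``chunk_size = 3 * n + 2`` lines per cluster -- an atom-count line, a
# comment line carrying a label field and the total energy, and one
# ``<atom> <x> <y> <z>`` line per atom, e.g. for W3:
#
#     9
#     <label> <total energy in kcal/mol>
#     O <x> <y> <z>
#     H <x> <y> <z>
#     ... (9 atom lines in total)
#
# ``read_energy`` keeps only the comment line of each chunk, while
# ``read_atoms`` skips the first two lines and keeps the atom lines.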
def get_num_clusters(filepath: str) -> int:
    name = osp.basename(filepath)
    return int(name[1:name.find('_')])


def read_energy(file: str, chunk_size: int) -> np.ndarray:
    import pandas as pd

    def skipatoms(i: int) -> bool:
        return (i - 1) % chunk_size != 0

    if chunk_size - 2 == 11 * 3:
        # Manually handle bad lines in W11:
        df = pd.read_table(file, header=None, dtype="string",
                           skiprows=skipatoms)
        df = df[0].str.split().str[-1].astype(np.float32)
    else:
        df = pd.read_table(file, sep=r'\s+', names=["label", "energy"],
                           dtype=np.float32, skiprows=skipatoms,
                           usecols=['energy'], memory_map=True)

    return df.to_numpy().squeeze()


def read_atoms(file: str, chunk_size: int) -> Tuple[np.ndarray, np.ndarray]:
    import pandas as pd

    def skipheaders(i: int) -> bool:
        return i % chunk_size == 0 or (i - 1) % chunk_size == 0

    dtypes = {
        'atom': 'string',
        'x': np.float16,
        'y': np.float16,
        'z': np.float16,
    }
    df = pd.read_table(file, sep=r'\s+', names=list(dtypes.keys()),
                       dtype=dtypes, skiprows=skipheaders, memory_map=True)

    z = np.ones(len(df), dtype=np.int8)
    z[(df.atom == 'O').to_numpy(dtype=np.bool_)] = 8
    pos = df.iloc[:, 1:4].to_numpy()

    num_nodes = (chunk_size - 2)
    num_graphs = z.shape[0] // num_nodes

    z.shape = (num_graphs, num_nodes)
    pos.shape = (num_graphs, num_nodes, 3)
    return (z, pos)


@dataclass
class RemoteFile:
    url: str
    name: str

    def unpack_to(self, dest_folder: str) -> None:
        file = download_url(self.url, dest_folder, filename=self.name)
        extract_zip(file, dest_folder)
        os.unlink(file)

    @staticmethod
    def raw_dataset() -> 'RemoteFile':
        return RemoteFile(
            url='https://figshare.com/ndownloader/files/38063847',
            name='W3-W30_all_geoms_TTM2.1-F.zip')

    @staticmethod
    def processed_dataset() -> 'RemoteFile':
        return RemoteFile(
            url='https://figshare.com/ndownloader/files/38075781',
            name='W3-W30_pyg_processed.zip')

    @staticmethod
    def hydronet_splits() -> 'RemoteFile':
        return RemoteFile(
            url="https://figshare.com/ndownloader/files/38075904",
            name="hydronet_splits.zip")


class Partition(InMemoryDataset):
    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
    ) -> None:
        self.name = name
        self.num_clusters = get_num_clusters(name)
        super().__init__(root, transform, pre_transform, pre_filter=None,
                         log=False)
        self.is_loaded = False

    @property
    def raw_file_names(self) -> List[str]:
        return [self.name + ".zip"]

    @property
    def processed_file_names(self) -> List[str]:
        return [self.name + '.npz']

    def process(self) -> None:
        num_nodes = self.num_clusters * 3
        chunk_size = num_nodes + 2
        z, pos = read_atoms(self.raw_paths[0], chunk_size)
        y = read_energy(self.raw_paths[0], chunk_size)
        np.savez(self.processed_paths[0], z=z, pos=pos, y=y,
                 num_graphs=z.shape[0])

    def _load(self) -> None:
        if self.is_loaded:
            return None

        with np.load(self.processed_paths[0]) as npzfile:
            self.z = npzfile['z']
            self.pos = npzfile['pos']
            self.y = npzfile['y']
            numel = int(npzfile['num_graphs'])

        self._data_list: MutableSequence[Optional[BaseData]] = [None] * numel
        self.is_loaded = True

    @cached_property
    def num_graphs(self) -> int:
        with np.load(self.processed_paths[0]) as npzfile:
            return int(npzfile['num_graphs'])

    def len(self) -> int:
        return self.num_graphs

    def get(self, idx: int) -> Data:
        self._load()

        if self._data_list[idx] is not None:
            cached_data = self._data_list[idx]
            assert isinstance(cached_data, Data)
            return copy.copy(cached_data)

        data = Data(
            z=torch.from_numpy(self.z[idx, :]),
            pos=torch.from_numpy(self.pos[idx, :, :]),
            y=torch.tensor(self.y[idx]),
        )

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self._data_list[idx] = copy.copy(data)
        return data
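

# A minimal usage sketch, not part of the library module: the root path is a
# placeholder, and the pre-processed archives are downloaded from figshare on
# first instantiation.
if __name__ == '__main__':
    from torch_geometric.datasets import HydroNet

    dataset = HydroNet(root='/tmp/HydroNet', clusters=[3, 4])
    print(len(dataset))
    print(dataset[0])  # First graph is a W3 cluster with 9 atoms.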