Source code for torch_geometric.datasets.hydro_net

import copy
import os
import os.path as osp
from dataclasses import dataclass
from functools import lru_cache
from glob import glob
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import ConcatDataset, Subset

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)

try:
    from functools import cached_property
except ImportError:  # Python 3.7 support.

    def cached_property(func):
        return property(fget=lru_cache(maxsize=1)(func))


class HydroNet(InMemoryDataset):
    r"""The HydroNet dataset from the `"HydroNet: Benchmark Tasks for
    Preserving Intermolecular Interactions and Structural Motifs in
    Predictive and Generative Models for Molecular Data"
    <https://arxiv.org/abs/2012.00131>`_ paper, consisting of 5 million water
    clusters held together by hydrogen bonding networks.  This dataset
    provides atomic coordinates and the total energy in kcal/mol for each
    cluster.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string, optional): Name of the subset of the full dataset to
            use: :obj:`"small"` uses 500k graphs sampled from the
            :obj:`"medium"` dataset, and :obj:`"medium"` uses 2.7m graphs
            with a maximum size of 75 nodes.
            Mutually exclusive with the :obj:`clusters` argument.
            (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        num_workers (int): Number of multiprocessing workers to use for
            pre-processing the dataset. (default: :obj:`8`)
        clusters (int or List[int], optional): Select a subset of clusters
            from the full dataset. If set to :obj:`None`, will select all.
            (default: :obj:`None`)
        use_processed (bool): Option to use a pre-processed version of the
            original :obj:`xyz` dataset. (default: :obj:`True`)
    """
    def __init__(
        self,
        root: str,
        name: Optional[str] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        num_workers: int = 8,
        clusters: Optional[Union[int, List[int]]] = None,
        use_processed: bool = True,
    ):
        self.name = name
        self.num_workers = num_workers
        self.use_processed = use_processed
        super().__init__(root, transform, pre_transform)
        self.select_clusters(clusters)

    @property
    def raw_file_names(self) -> List[str]:
        return [f'W{c}_geoms_all.zip' for c in range(3, 31)]

    @property
    def processed_file_names(self) -> List[str]:
        return [f'W{c}_geoms_all.npz' for c in range(3, 31)]

    def download(self):
        # Nothing to do if the pre-processed archives were already fetched:
        token_file = Path(osp.join(self.raw_dir, 'use_processed'))
        if self.use_processed and token_file.exists():
            return

        file = RemoteFile.hydronet_splits()
        file.unpack_to(self.raw_dir)

        if self.use_processed:
            file = RemoteFile.processed_dataset()
            file.unpack_to(self.raw_dir)
            token_file.touch()
            return

        # Download the raw `xyz` archives and move the per-cluster zip files
        # into the top-level raw directory:
        file = RemoteFile.raw_dataset()
        file.unpack_to(self.raw_dir)
        folder_name, _ = osp.splitext(file.name)
        files = glob(osp.join(self.raw_dir, folder_name, '*.zip'))

        for f in files:
            dst = osp.join(self.raw_dir, osp.basename(f))
            os.rename(f, dst)

        os.rmdir(osp.join(self.raw_dir, folder_name))

    def process(self):
        if self.use_processed:
            return self._unpack_processed()

        # Convert every raw cluster archive into its own `Partition`:
        from tqdm.contrib.concurrent import process_map
        self._partitions = process_map(
            self._create_partitions,
            self.raw_paths,
            max_workers=self.num_workers,
            position=0,
            leave=True,
        )

    def _unpack_processed(self):
        files = glob(osp.join(self.raw_dir, '*.npz'))
        for f in files:
            dst = osp.join(self.processed_dir, osp.basename(f))
            os.rename(f, dst)

    def _create_partitions(self, file) -> 'Partition':
        name = osp.basename(file)
        name, _ = osp.splitext(name)
        return Partition(self.root, name, self.transform, self.pre_transform)

    def select_clusters(
        self,
        clusters: Optional[Union[int, List[int]]],
    ) -> 'Optional[List[Partition]]':
        if self.name is not None:
            clusters = self._validate_name(clusters)

        self._partitions = [
            self._create_partitions(f) for f in self.raw_paths
        ]

        if clusters is None:
            return

        clusters = [clusters] if isinstance(clusters, int) else clusters

        def is_valid_cluster(x):
            return isinstance(x, int) and x >= 3 and x <= 30

        if not all([is_valid_cluster(x) for x in clusters]):
            raise ValueError(
                "Selected clusters must be an integer in the range [3, 30]")

        # Cluster sizes start at W3, so cluster `c` lives at index `c - 3`:
        self._partitions = [self._partitions[c - 3] for c in clusters]

    def _validate_name(self, clusters) -> List[int]:
        if clusters is not None:
            raise ValueError("'name' and 'clusters' are mutually exclusive")

        if self.name not in ['small', 'medium']:
            raise ValueError(f"Invalid subset name '{self.name}'. "
                             f"Must be either 'small' or 'medium'")

        return range(3, 26)

    @cached_property
    def _dataset(self) -> ConcatDataset:
        dataset = ConcatDataset(self._partitions)

        if self.name == "small":
            dataset = self._load_small_split(dataset)

        return dataset

    def _load_small_split(self, dataset: ConcatDataset) -> Subset:
        # The "small" subset is defined by the train/validation indices
        # stored in the downloaded split file:
        split_file = osp.join(self.processed_dir, 'split_00_small.npz')
        with np.load(split_file) as split:
            train_idx = split['train_idx']
            val_idx = split['val_idx']
            all_idx = np.concatenate([train_idx, val_idx])
            dataset = Subset(dataset, all_idx)

        return dataset

    def len(self) -> int:
        return len(self._dataset)

    def get(self, idx: int) -> Data:
        return self._dataset[idx]


def get_num_clusters(filepath) -> int:
    # Extract the cluster size from file names such as 'W12_geoms_all':
    name = osp.basename(filepath)
    return int(name[1:name.find('_')])


def read_energy(file: str, chunk_size: int) -> np.ndarray:
    import pandas as pd

    def skipatoms(i: int):
        # Keep only the second line of every block, which holds the energy:
        return (i - 1) % chunk_size != 0

    if chunk_size - 2 == 11 * 3:
        # Manually handle bad lines in W11:
        df = pd.read_table(file, header=None, dtype="string",
                           skiprows=skipatoms)
        df = df[0].str.split().str[-1].astype(np.float32)
    else:
        df = pd.read_table(file, sep=r'\s+', names=["label", "energy"],
                           dtype=np.float32, skiprows=skipatoms,
                           usecols=['energy'], memory_map=True)

    return df.to_numpy().squeeze()


def read_atoms(file: str, chunk_size: int) -> Tuple[np.ndarray, np.ndarray]:
    import pandas as pd

    def skipheaders(i: int):
        # Skip the two header lines (atom count and energy) of every block:
        return i % chunk_size == 0 or (i - 1) % chunk_size == 0

    dtypes = {
        'atom': 'string',
        'x': np.float16,
        'y': np.float16,
        'z': np.float16,
    }
    df = pd.read_table(file, sep=r'\s+', names=list(dtypes.keys()),
                       dtype=dtypes, skiprows=skipheaders, memory_map=True)

    # Atomic numbers default to hydrogen (1); oxygen entries are set to 8:
    z = np.ones(len(df), dtype=np.int8)
    z[(df.atom == 'O').to_numpy(dtype=np.bool_)] = 8
    pos = df.iloc[:, 1:4].to_numpy()

    num_nodes = (chunk_size - 2)
    num_graphs = z.shape[0] // num_nodes

    z.shape = (num_graphs, num_nodes)
    pos.shape = (num_graphs, num_nodes, 3)
    return (z, pos)


@dataclass
class RemoteFile:
    url: str
    name: str

    def unpack_to(self, dest_folder: str):
        file = download_url(self.url, dest_folder, filename=self.name)
        extract_zip(file, dest_folder)
        os.unlink(file)

    @staticmethod
    def raw_dataset() -> 'RemoteFile':
        return RemoteFile(
            url='https://figshare.com/ndownloader/files/38063847',
            name='W3-W30_all_geoms_TTM2.1-F.zip')

    @staticmethod
    def processed_dataset() -> 'RemoteFile':
        return RemoteFile(
            url='https://figshare.com/ndownloader/files/38075781',
            name='W3-W30_pyg_processed.zip')

    @staticmethod
    def hydronet_splits() -> 'RemoteFile':
        return RemoteFile(
            url="https://figshare.com/ndownloader/files/38075904",
            name="hydronet_splits.zip")


class Partition(InMemoryDataset):
    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
    ):
        self.name = name
        self.num_clusters = get_num_clusters(name)
        super().__init__(root, transform, pre_transform, pre_filter=None,
                         log=False)
        self.is_loaded = False

    @property
    def raw_file_names(self) -> List[str]:
        return [self.name + ".zip"]

    @property
    def processed_file_names(self) -> List[str]:
        return [self.name + '.npz']

    def process(self):
        # Each xyz block holds two header lines plus one line per atom:
        num_nodes = self.num_clusters * 3
        chunk_size = num_nodes + 2
        z, pos = read_atoms(self.raw_paths[0], chunk_size)
        y = read_energy(self.raw_paths[0], chunk_size)
        np.savez(self.processed_paths[0], z=z, pos=pos, y=y,
                 num_graphs=z.shape[0])

    def load(self):
        if self.is_loaded:
            return

        with np.load(self.processed_paths[0]) as npzfile:
            self.z = npzfile['z']
            self.pos = npzfile['pos']
            self.y = npzfile['y']
            num_graphs = int(npzfile['num_graphs'])
            self._data_list = num_graphs * [None]

        self.is_loaded = True

    @cached_property
    def num_graphs(self) -> int:
        with np.load(self.processed_paths[0]) as npzfile:
            return int(npzfile['num_graphs'])

    def len(self) -> int:
        return self.num_graphs

    def get(self, idx: int) -> Data:
        self.load()
        if self._data_list[idx] is not None:
            return copy.copy(self._data_list[idx])

        data = Data(
            z=torch.from_numpy(self.z[idx, :]),
            pos=torch.from_numpy(self.pos[idx, :, :]),
            y=torch.tensor(self.y[idx]),
        )

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self._data_list[idx] = copy.copy(data)
        return data
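
A minimal usage sketch (not part of the module above), based only on the constructor signature shown in this file; the root path and cluster choices are placeholders, and it assumes the class is exposed as torch_geometric.datasets.HydroNet:

    from torch_geometric.datasets import HydroNet

    # Restrict the dataset to the 3- and 4-molecule water clusters:
    dataset = HydroNet(root='/tmp/HydroNet', clusters=[3, 4])

    # Alternatively, load the 500k-graph "small" subset:
    small = HydroNet(root='/tmp/HydroNet', name='small')
    data = small[0]  # `Data` object with `z`, `pos` and energy target `y`.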