from typing import Optional, Callable, List

import os
import os.path as osp
import glob
import pickle

import torch
import torch.nn.functional as F
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_zip, extract_tar)
from torch_geometric.utils import to_undirected

class GEDDataset(InMemoryDataset):
    r"""The GED datasets from the `"Graph Edit Distance Computation via Graph
    Neural Networks" <https://arxiv.org/abs/1808.05689>`_ paper.
    GEDs can be accessed via the global attributes :obj:`ged` and
    :obj:`norm_ged` for all train/train graph pairs and all train/test graph
    pairs:

    .. code-block:: python

        dataset = GEDDataset(root, name="LINUX")
        data1, data2 = dataset[0], dataset[1]
        ged = dataset.ged[data1.i, data2.i]  # GED between `data1` and `data2`.

    Note that GEDs are not available if both graphs are from the test set.
    For evaluation, it is recommended to pair up each graph from the test set
    with each graph in the training set.

    .. note::

        :obj:`ALKANE` is missing GEDs for train/test graph pairs since they are
        not provided in the `official datasets
        <https://github.com/yunshengb/SimGNN>`_.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (one of :obj:`"AIDS700nef"`,
            :obj:`"LINUX"`, :obj:`"ALKANE"`, :obj:`"IMDBMulti"`).
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    # NOTE(review): the URL template was stripped down to '{}' in this copy of
    # the file. It is restored here to the Google Drive direct-download form,
    # which matches the shape of the file ids below -- confirm upstream.
    url = 'https://drive.google.com/uc?export=download&id={}'

    # Per-dataset download metadata: the id of the graph archive, the helper
    # used to unpack that archive, and the id of the pickled GED table.
    datasets = {
        'AIDS700nef': {
            'id': '10czBPJDEzEDI2tq7Z7mkBjLhj55F-a2z',
            'extract': extract_zip,
            'pickle': '1OpV4bCHjBkdpqI6H5Mg0-BqlA2ee2eBW',
        },
        'LINUX': {
            'id': '1nw0RRVgyLpit4V4XFQyDy0pI6wUEXSOI',
            'extract': extract_tar,
            'pickle': '14FDm3NSnrBvB7eNpLeGy5Bz6FjuCSF5v',
        },
        'ALKANE': {
            'id': '1-LmxaWW3KulLh00YqscVEflbqr0g4cXt',
            'extract': extract_tar,
            'pickle': '15BpvMuHx77-yUGYgM27_sQett02HQNYu',
        },
        'IMDBMulti': {
            'id': '12QxZ7EhYA7pJiF4cO-HuE8szhSOWcfST',
            'extract': extract_zip,
            'pickle': '1wy9VbZvZodkixxVIOuRllC-Lp-0zdoYZ',
        },
    }

    # List of atoms contained in the AIDS700nef dataset:
    types = [
        'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
        'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
        'Sb', 'Se', 'Ni', 'Te'
    ]

    def __init__(self, root: str, name: str, train: bool = True,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        # `name` must be set before `super().__init__()` because the file-name
        # properties and `download()`/`process()` defined below all read it.
        self.name = name
        assert self.name in self.datasets.keys()
        super().__init__(root, transform, pre_transform, pre_filter)

        # Load the requested split (index 0: training, index 1: test).
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

        # The GED matrices cover both splits and are loaded regardless of
        # `train`; they are indexed by the global graph index `data.i`.
        path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
        self.ged = torch.load(path)
        path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
        self.norm_ged = torch.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        """Raw sub-directories: ``<name>/train`` and ``<name>/test``."""
        return [osp.join(self.name, s) for s in ['train', 'test']]

    @property
    def processed_file_names(self) -> List[str]:
        """Processed files: ``<name>_training.pt`` and ``<name>_test.pt``."""
        return [f'{self.name}_{s}.pt' for s in ['training', 'test']]

    def download(self):
        """Download and unpack the graph archive, then fetch the GED pickle.

        The archive is removed after extraction; the pickle is stored as
        ``<raw_dir>/<name>/ged.pickle``.
        """
        name = self.datasets[self.name]['id']
        path = download_url(self.url.format(name), self.raw_dir)
        self.datasets[self.name]['extract'](path, self.raw_dir)
        os.unlink(path)

        name = self.datasets[self.name]['pickle']
        path = download_url(self.url.format(name), self.raw_dir)
        os.rename(path, osp.join(self.raw_dir, self.name, 'ged.pickle'))

    def process(self):
        """Convert the raw ``.gexf`` graphs and GED table to tensors on disk.

        Writes one collated file per split plus the dense GED matrix and its
        normalized variant into :obj:`processed_dir`.
        """
        import networkx as nx

        # `ids[k]` holds the sorted graph ids of split k; `Ns` holds the node
        # count of every graph in global-index order (train first, then test).
        ids, Ns = [], []
        for r_path, p_path in zip(self.raw_paths, self.processed_paths):
            names = glob.glob(osp.join(r_path, '*.gexf'))
            # Get the graph IDs given by the file name:
            ids.append(sorted(
                int(osp.splitext(osp.basename(fname))[0])
                for fname in names))

            data_list = []
            # Convert graphs in .gexf format to a NetworkX Graph:
            for i, idx in enumerate(ids[-1]):
                # Test graphs are indexed after all training graphs so that
                # `i` is a global row/column index into the GED matrices.
                i = i if len(ids) == 1 else i + len(ids[0])
                G = nx.read_gexf(osp.join(r_path, f'{idx}.gexf'))
                # Relabel nodes to consecutive integers 0..N-1.
                mapping = {name: j for j, name in enumerate(G.nodes())}
                G = nx.relabel_nodes(G, mapping)
                Ns.append(G.number_of_nodes())

                edge_index = torch.tensor(list(G.edges)).t().contiguous()
                if edge_index.numel() == 0:
                    # Edge-less graph: keep the expected (2, 0) long layout.
                    edge_index = torch.empty((2, 0), dtype=torch.long)
                edge_index = to_undirected(edge_index, num_nodes=Ns[-1])

                data = Data(edge_index=edge_index, i=i)
                data.num_nodes = Ns[-1]

                # Create a one-hot encoded feature matrix denoting the atom
                # type for the AIDS700nef dataset:
                if self.name == 'AIDS700nef':
                    x = torch.zeros(data.num_nodes, dtype=torch.long)
                    for node, info in G.nodes(data=True):
                        x[int(node)] = self.types.index(info['type'])
                    data.x = F.one_hot(x, num_classes=len(self.types)).to(
                        torch.float)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            torch.save(self.collate(data_list), p_path)

        # Map raw graph ids to global indices (train first, then test).
        assoc = {idx: i for i, idx in enumerate(ids[0])}
        assoc.update({idx: i + len(ids[0]) for i, idx in enumerate(ids[1])})

        path = osp.join(self.raw_dir, self.name, 'ged.pickle')
        # Pairs that are absent from the pickle keep a GED of +inf.
        mat = torch.full((len(assoc), len(assoc)), float('inf'))
        with open(path, 'rb') as f:
            # NOTE(security): unpickling executes arbitrary code. The file is
            # fetched by `download()`; only use it with a trusted source.
            obj = pickle.load(f)
            xs, ys, gs = [], [], []
            for (x, y), g in obj.items():
                xs += [assoc[x]]
                ys += [assoc[y]]
                gs += [g]
            x, y = torch.tensor(xs), torch.tensor(ys)
            g = torch.tensor(gs, dtype=torch.float)
            # Fill both triangles so the matrix is symmetric.
            mat[x, y], mat[y, x] = g, g

        path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
        torch.save(mat, path)

        # Calculate the normalized GEDs:
        N = torch.tensor(Ns, dtype=torch.float)
        norm_mat = mat / (0.5 * (N.view(-1, 1) + N.view(1, -1)))

        path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
        torch.save(norm_mat, path)

    def __repr__(self) -> str:
        return f'{self.name}({len(self)})'