Source code for torch_geometric.datasets.ged_dataset

import os
import os.path as osp
import glob
import pickle

import torch
import torch.nn.functional as F
import networkx as nx
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_zip, extract_tar)
from torch_geometric.utils import to_undirected


class GEDDataset(InMemoryDataset):
    r"""The GED datasets from the `"Graph Edit Distance Computation via Graph
    Neural Networks" <https://arxiv.org/abs/1808.05689>`_ paper.

    Graph edit distances are exposed through the dataset-wide attributes
    :obj:`ged` and :obj:`norm_ged`, which cover all train/train graph pairs
    and all train/test graph pairs:

    .. code-block:: python

        dataset = GEDDataset(root, name="LINUX")
        data1, data2 = dataset[0], dataset[1]
        ged = dataset.ged[data1.i, data2.i]  # GED between `data1` and `data2`.

    .. note::

        :obj:`ALKANE` is missing GEDs for train/test graph pairs since they
        are not provided in the
        `official datasets <https://github.com/yunshengb/SimGNN>`_.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (one of :obj:`"AIDS700nef"`,
            :obj:`"LINUX"`, :obj:`"ALKANE"`, :obj:`"IMDBMulti"`).
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
    """

    # URL template; the placeholder is filled with one of the ids below.
    url = 'https://drive.google.com/uc?export=download&id={}'

    # Per dataset: id of the graph archive, the function used to unpack that
    # archive, and id of the pickle file holding the GEDs.
    datasets = {
        'AIDS700nef': {
            'id': '10czBPJDEzEDI2tq7Z7mkBjLhj55F-a2z',
            'extract': extract_zip,
            'pickle': '1OpV4bCHjBkdpqI6H5Mg0-BqlA2ee2eBW',
        },
        'LINUX': {
            'id': '1nw0RRVgyLpit4V4XFQyDy0pI6wUEXSOI',
            'extract': extract_tar,
            'pickle': '14FDm3NSnrBvB7eNpLeGy5Bz6FjuCSF5v',
        },
        'ALKANE': {
            'id': '1-LmxaWW3KulLh00YqscVEflbqr0g4cXt',
            'extract': extract_tar,
            'pickle': '15BpvMuHx77-yUGYgM27_sQett02HQNYu',
        },
        'IMDBMulti': {
            'id': '12QxZ7EhYA7pJiF4cO-HuE8szhSOWcfST',
            'extract': extract_zip,
            'pickle': '1wy9VbZvZodkixxVIOuRllC-Lp-0zdoYZ',
        },
    }

    # Node type labels; the position in this list is the one-hot class index
    # used for the 'AIDS700nef' node features in `process`.
    types = [
        'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
        'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
        'Sb', 'Se', 'Ni', 'Te'
    ]

    def __init__(self, root, name, train=True, transform=None,
                 pre_transform=None, pre_filter=None):
        # `name` has to be set before the base initializer runs, because the
        # file-name properties below are built from it.
        self.name = name
        assert self.name in self.datasets
        super(GEDDataset, self).__init__(root, transform, pre_transform,
                                         pre_filter)

        # processed_paths[0] is the training split, processed_paths[1] the
        # test split.
        split_pos = 0 if train else 1
        self.data, self.slices = torch.load(self.processed_paths[split_pos])

        ged_path = osp.join(self.processed_dir,
                            '{}_ged.pt'.format(self.name))
        self.ged = torch.load(ged_path)

        norm_ged_path = osp.join(self.processed_dir,
                                 '{}_norm_ged.pt'.format(self.name))
        self.norm_ged = torch.load(norm_ged_path)

    @property
    def raw_file_names(self):
        """Raw entries, relative to the raw directory: one per split."""
        return [osp.join(self.name, split) for split in ('train', 'test')]

    @property
    def processed_file_names(self):
        """Processed file names: training split first, test split second."""
        return [
            '{}_{}.pt'.format(self.name, split)
            for split in ('training', 'test')
        ]

    def download(self):
        """Fetch and unpack the graph archive, then fetch the GED pickle."""
        entry = self.datasets[self.name]

        # Graph archive: download, unpack into the raw directory, then remove
        # the archive file itself.
        archive = download_url(self.url.format(entry['id']), self.raw_dir)
        entry['extract'](archive, self.raw_dir)
        os.unlink(archive)

        # GED pickle: download and move it to the location `process` reads.
        pickle_file = download_url(self.url.format(entry['pickle']),
                                   self.raw_dir)
        os.rename(pickle_file,
                  osp.join(self.raw_dir, self.name, 'ged.pickle'))

    def process(self):
        """Convert the raw graphs and GEDs into the processed files.

        Writes one collated file per split, plus a matrix of GEDs and a matrix
        of normalized GEDs. Both matrices are indexed by the :obj:`i`
        attribute stored on every graph.
        """
        graph_ids, node_counts = [], []

        for raw_split_dir, processed_path in zip(self.raw_paths,
                                                 self.processed_paths):
            gexf_files = glob.glob(osp.join(raw_split_dir, '*.gexf'))
            # File names are '<id>.gexf'; dropping the last five characters
            # leaves the integer id.
            split_ids = sorted(
                [int(f.split(os.sep)[-1][:-5]) for f in gexf_files])
            graph_ids.append(split_ids)

            # Graphs of every split after the first are numbered after the
            # graphs of the first split, giving one global index space.
            offset = 0 if len(graph_ids) == 1 else len(graph_ids[0])

            data_list = []
            for local_pos, graph_id in enumerate(split_ids):
                global_pos = local_pos + offset

                G = nx.read_gexf(
                    osp.join(raw_split_dir, '{}.gexf'.format(graph_id)))
                # Relabel the nodes to 0..N-1 in iteration order.
                relabel = {
                    label: pos for pos, label in enumerate(G.nodes())
                }
                G = nx.relabel_nodes(G, relabel)
                # Recorded before filtering, so this list stays aligned with
                # the global index even when a graph is dropped below.
                node_counts.append(G.number_of_nodes())

                edge_index = torch.tensor(list(G.edges)).t().contiguous()
                if edge_index.numel() == 0:
                    # No edges: use an explicitly shaped empty long tensor.
                    edge_index = torch.empty((2, 0), dtype=torch.long)
                edge_index = to_undirected(edge_index,
                                           num_nodes=node_counts[-1])

                data = Data(edge_index=edge_index, i=global_pos)
                data.num_nodes = node_counts[-1]

                if self.name == 'AIDS700nef':
                    # One-hot node features built from each node's 'type'.
                    type_index = torch.zeros(data.num_nodes, dtype=torch.long)
                    for node, attrs in G.nodes(data=True):
                        type_index[int(node)] = self.types.index(
                            attrs['type'])
                    one_hot = F.one_hot(type_index,
                                        num_classes=len(self.types))
                    data.x = one_hot.to(torch.float)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            torch.save(self.collate(data_list), processed_path)

        # Map raw graph id -> global index (first split, then second split).
        first_split_size = len(graph_ids[0])
        assoc = {gid: pos for pos, gid in enumerate(graph_ids[0])}
        for pos, gid in enumerate(graph_ids[1]):
            assoc[gid] = pos + first_split_size

        # Pairs that do not appear in the pickle keep the value inf.
        mat = torch.full((len(assoc), len(assoc)), float('inf'))

        # NOTE(security): pickle.load can execute arbitrary code. The file
        # read here is the one `download` fetched from the remote URL.
        ged_pickle = osp.join(self.raw_dir, self.name, 'ged.pickle')
        with open(ged_pickle, 'rb') as f:
            pair_to_ged = pickle.load(f)

        rows, cols, values = [], [], []
        for (src, dst), value in pair_to_ged.items():
            rows.append(assoc[src])
            cols.append(assoc[dst])
            values.append(value)

        row_idx, col_idx = torch.tensor(rows), torch.tensor(cols)
        ged_values = torch.tensor(values, dtype=torch.float)
        # Fill both (row, col) and (col, row) so the matrix is symmetric.
        mat[row_idx, col_idx] = ged_values
        mat[col_idx, row_idx] = ged_values

        torch.save(
            mat, osp.join(self.processed_dir, '{}_ged.pt'.format(self.name)))

        # Normalize each GED by the mean node count of the two graphs.
        N = torch.tensor(node_counts, dtype=torch.float)
        mean_nodes = 0.5 * (N.view(-1, 1) + N.view(1, -1))
        norm_mat = mat / mean_nodes

        torch.save(
            norm_mat,
            osp.join(self.processed_dir, '{}_norm_ged.pt'.format(self.name)))

    def __repr__(self):
        return '{}({})'.format(self.name, len(self))