import glob
import os
import os.path as osp
import pickle
from typing import Callable, List, Optional
import torch
from torch_geometric.data import (
Data,
InMemoryDataset,
download_google_url,
extract_tar,
extract_zip,
)
from torch_geometric.io import fs
from torch_geometric.utils import one_hot, to_undirected
[docs]class GEDDataset(InMemoryDataset):
r"""The GED datasets from the `"Graph Edit Distance Computation via Graph
Neural Networks" <https://arxiv.org/abs/1808.05689>`_ paper.
GEDs can be accessed via the global attributes :obj:`ged` and
:obj:`norm_ged` for all train/train graph pairs and all train/test graph
pairs:
.. code-block:: python
dataset = GEDDataset(root, name="LINUX")
data1, data2 = dataset[0], dataset[1]
ged = dataset.ged[data1.i, data2.i] # GED between `data1` and `data2`.
Note that GEDs are not available if both graphs are from the test set.
For evaluation, it is recommended to pair up each graph from the test set
with each graph in the training set.
.. note::
:obj:`ALKANE` is missing GEDs for train/test graph pairs since they are
not provided in the `official datasets
<https://github.com/yunshengb/SimGNN>`_.
Args:
root (str): Root directory where the dataset should be saved.
name (str): The name of the dataset (one of :obj:`"AIDS700nef"`,
:obj:`"LINUX"`, :obj:`"ALKANE"`, :obj:`"IMDBMulti"`).
train (bool, optional): If :obj:`True`, loads the training dataset,
otherwise the test dataset. (default: :obj:`True`)
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the
final dataset. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
.. list-table::
:widths: 20 10 10 10 10 10
:header-rows: 1
* - Name
- #graphs
- #nodes
- #edges
- #features
- #classes
* - AIDS700nef
- 700
- ~8.9
- ~17.6
- 29
- 0
* - LINUX
- 1,000
- ~7.6
- ~13.9
- 0
- 0
* - ALKANE
- 150
- ~8.9
- ~15.8
- 0
- 0
* - IMDBMulti
- 1,500
- ~13.0
- ~131.9
- 0
- 0
"""
datasets = {
'AIDS700nef': {
'id': '10czBPJDEzEDI2tq7Z7mkBjLhj55F-a2z',
'extract': extract_zip,
'pickle': '1OpV4bCHjBkdpqI6H5Mg0-BqlA2ee2eBW',
},
'LINUX': {
'id': '1nw0RRVgyLpit4V4XFQyDy0pI6wUEXSOI',
'extract': extract_tar,
'pickle': '14FDm3NSnrBvB7eNpLeGy5Bz6FjuCSF5v',
},
'ALKANE': {
'id': '1-LmxaWW3KulLh00YqscVEflbqr0g4cXt',
'extract': extract_tar,
'pickle': '15BpvMuHx77-yUGYgM27_sQett02HQNYu',
},
'IMDBMulti': {
'id': '12QxZ7EhYA7pJiF4cO-HuE8szhSOWcfST',
'extract': extract_zip,
'pickle': '1wy9VbZvZodkixxVIOuRllC-Lp-0zdoYZ',
},
}
# List of atoms contained in the AIDS700nef dataset:
types = [
'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
'Sb', 'Se', 'Ni', 'Te'
]
def __init__(
self,
root: str,
name: str,
train: bool = True,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
force_reload: bool = False,
) -> None:
self.name = name
assert self.name in self.datasets.keys()
super().__init__(root, transform, pre_transform, pre_filter,
force_reload=force_reload)
path = self.processed_paths[0] if train else self.processed_paths[1]
self.load(path)
path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
self.ged = fs.torch_load(path)
path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
self.norm_ged = fs.torch_load(path)
@property
def raw_file_names(self) -> List[str]:
# Returns, e.g., ['LINUX/train', 'LINUX/test']
return [osp.join(self.name, s) for s in ['train', 'test']]
@property
def processed_file_names(self) -> List[str]:
# Returns, e.g., ['LINUX_training.pt', 'LINUX_test.pt']
return [f'{self.name}_{s}.pt' for s in ['training', 'test']]
def download(self) -> None:
# Downloads the .tar/.zip file of the graphs and extracts them:
id = self.datasets[self.name]['id']
assert isinstance(id, str)
path = download_google_url(id, self.raw_dir, 'data')
extract_fn = self.datasets[self.name]['extract']
assert callable(extract_fn)
extract_fn(path, self.raw_dir)
os.unlink(path)
# Downloads the pickle file containing pre-computed GEDs:
id = self.datasets[self.name]['pickle']
assert isinstance(id, str)
path = download_google_url(id, self.raw_dir, 'ged.pickle')
def process(self) -> None:
import networkx as nx
ids, Ns = [], []
# Iterating over paths for raw and processed data (train + test):
for r_path, p_path in zip(self.raw_paths, self.processed_paths):
# Find the paths of all raw graphs:
names = glob.glob(osp.join(r_path, '*.gexf'))
# Get sorted graph IDs given filename: 123.gexf -> 123
ids.append(sorted([int(osp.basename(i)[:-5]) for i in names]))
data_list = []
# Convert graphs in .gexf format to a NetworkX Graph:
for i, idx in enumerate(ids[-1]):
i = i if len(ids) == 1 else i + len(ids[0])
# Reading the raw `*.gexf` graph:
G = nx.read_gexf(osp.join(r_path, f'{idx}.gexf'))
# Mapping of nodes in `G` to a contiguous number:
mapping = {name: j for j, name in enumerate(G.nodes())}
G = nx.relabel_nodes(G, mapping)
Ns.append(G.number_of_nodes())
edge_index = torch.tensor(list(G.edges)).t().contiguous()
if edge_index.numel() == 0:
edge_index = torch.empty((2, 0), dtype=torch.long)
edge_index = to_undirected(edge_index, num_nodes=Ns[-1])
data = Data(edge_index=edge_index, i=i)
data.num_nodes = Ns[-1]
# Create a one-hot encoded feature matrix denoting the atom
# type (for the `AIDS700nef` dataset):
if self.name == 'AIDS700nef':
assert data.num_nodes is not None
x = torch.zeros(data.num_nodes, dtype=torch.long)
for node, info in G.nodes(data=True):
x[int(node)] = self.types.index(info['type'])
data.x = one_hot(x, num_classes=len(self.types))
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
data_list.append(data)
self.save(data_list, p_path)
assoc = {idx: i for i, idx in enumerate(ids[0])}
assoc.update({idx: i + len(ids[0]) for i, idx in enumerate(ids[1])})
# Extracting ground-truth GEDs from the GED pickle file
path = osp.join(self.raw_dir, self.name, 'ged.pickle')
# Initialize GEDs as float('inf'):
mat = torch.full((len(assoc), len(assoc)), float('inf'))
with open(path, 'rb') as f:
obj = pickle.load(f)
xs, ys, gs = [], [], []
for (_x, _y), g in obj.items():
xs += [assoc[_x]]
ys += [assoc[_y]]
gs += [g]
# The pickle file does not contain GEDs for test graph pairs, i.e.
# GEDs for (test_graph, test_graph) pairs are still float('inf'):
x, y = torch.tensor(xs), torch.tensor(ys)
ged = torch.tensor(gs, dtype=torch.float)
mat[x, y], mat[y, x] = ged, ged
path = osp.join(self.processed_dir, f'{self.name}_ged.pt')
torch.save(mat, path)
# Calculate the normalized GEDs:
N = torch.tensor(Ns, dtype=torch.float)
norm_mat = mat / (0.5 * (N.view(-1, 1) + N.view(1, -1)))
path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')
torch.save(norm_mat, path)
def __repr__(self) -> str:
return f'{self.name}({len(self)})'