# Source code for torch_geometric.datasets.entities

import logging
import os
import os.path as osp
from collections import Counter
from typing import Optional, Callable, List

import numpy as np
import torch

from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.data import download_url, extract_tar

class Entities(InMemoryDataset):
    r"""The relational entities networks "AIFB", "MUTAG", "BGS" and "AM" from
    the `"Modeling Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper.
    Training and test splits are given by node indices.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"AIFB"`,
            :obj:`"MUTAG"`, :obj:`"BGS"`, :obj:`"AM"`).
        hetero (bool, optional): If set to :obj:`True`, will save the dataset
            as a :class:`~torch_geometric.data.HeteroData` object.
            (default: :obj:`False`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """

    # NOTE(review): the download host and the two processed file names below
    # were stripped from the copy of this file under review and have been
    # restored from the upstream project -- confirm before release.
    url = 'https://data.dgl.ai/dataset/{}.tgz'

    # Dataset name -> (label column, node column) of the task TSV files.
    # 'label_cateogory' is reproduced verbatim; presumably it mirrors the
    # spelling of the column header in the AM files -- verify before "fixing".
    _TASK_HEADERS = {
        'am': ('label_cateogory', 'proxy'),
        'aifb': ('label_affiliation', 'person'),
        'mutag': ('label_mutagenic', 'bond'),
        'bgs': ('label_lithogenesis', 'rock'),
    }

    def __init__(self, root: str, name: str, hetero: bool = False,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        # `name` and `hetero` must be set before the base constructor runs:
        # the directory / file-name properties below read them.
        self.name = name.lower()
        self.hetero = hetero
        assert self.name in ['aifb', 'am', 'mutag', 'bgs'], (
            f"Unknown dataset name '{name}'")
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Directory holding the extracted raw files of this dataset."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed file of this dataset."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def num_relations(self) -> int:
        """Largest edge type id plus one."""
        return self.data.edge_type.max().item() + 1

    @property
    def num_classes(self) -> int:
        """Largest training label id plus one."""
        return self.data.train_y.max().item() + 1

    @property
    def raw_file_names(self) -> List[str]:
        """Raw files, in the order ``process`` unpacks ``raw_paths``."""
        return [
            f'{self.name}_stripped.nt.gz',
            'completeDataset.tsv',
            'trainingSet.tsv',
            'testSet.tsv',
        ]

    @property
    def processed_file_names(self) -> str:
        """Processed file name; differs between hetero and homogeneous."""
        return 'hetero_data.pt' if self.hetero else 'data.pt'

    def download(self):
        """Fetch the dataset archive, extract it to ``raw_dir``, delete it."""
        path = download_url(self.url.format(self.name), self.root)
        extract_tar(path, self.raw_dir)
        os.unlink(path)

    @staticmethod
    def _read_split(path, nodes_header, label_header, nodes_dict,
                    labels_dict):
        """Read one split TSV and return ``(node_idx, labels)`` as long
        tensors, translating nodes and labels through the given dicts."""
        import pandas as pd

        frame = pd.read_csv(path, sep='\t')
        indices = [nodes_dict[node] for node in frame[nodes_header].values]
        labels = [labels_dict[lab] for lab in frame[label_header].values]
        return (torch.tensor(indices, dtype=torch.long),
                torch.tensor(labels, dtype=torch.long))

    def process(self):
        """Parse the RDF graph and the task TSV files into a single data
        object and save it to ``processed_paths[0]``."""
        import gzip
        import pandas as pd
        import rdflib as rdf

        graph_file, task_file, train_file, test_file = self.raw_paths

        with hide_stdout():
            g = rdf.Graph()
            with gzip.open(graph_file, 'rb') as f:
                g.parse(file=f, format='nt')

        # Relations ordered from most to least frequent. `most_common()`
        # breaks ties by first occurrence, so a single pass over the
        # predicates is enough and no set ordering is involved.
        freq = Counter(g.predicates())
        relations = [rel for rel, _ in freq.most_common()]
        subjects = set(g.subjects())
        objects = set(g.objects())
        nodes = list(subjects.union(objects))

        N = len(nodes)
        R = 2 * len(relations)

        relations_dict = {rel: i for i, rel in enumerate(relations)}
        nodes_dict = {node: i for i, node in enumerate(nodes)}

        # Every triple yields a forward edge (type 2 * rel) and an inverse
        # edge (type 2 * rel + 1).
        edges = []
        for s, p, o in g.triples((None, None, None)):
            src, dst, rel = nodes_dict[s], nodes_dict[o], relations_dict[p]
            edges.append([src, dst, 2 * rel])
            edges.append([dst, src, 2 * rel + 1])

        edges = torch.tensor(edges, dtype=torch.long).t().contiguous()
        # dst < N and type < R, so this key orders edges lexicographically
        # by (src, dst, type).
        perm = (N * R * edges[0] + R * edges[1] + edges[2]).argsort()
        edges = edges[:, perm]

        edge_index, edge_type = edges[:2], edges[2]

        label_header, nodes_header = self._TASK_HEADERS[self.name]

        labels_df = pd.read_csv(task_file, sep='\t')
        labels_set = set(labels_df[label_header].values.tolist())
        # Sorted so that label ids are reproducible across runs; plain set
        # iteration order varies with string hash randomisation.
        labels_dict = {
            lab: i for i, lab in enumerate(sorted(labels_set, key=str))
        }
        # Re-key by plain `str` so values read from the TSV files match.
        # (`np.unicode`, used here previously, no longer exists in NumPy.)
        nodes_dict = {str(key): val for key, val in nodes_dict.items()}

        train_idx, train_y = self._read_split(
            train_file, nodes_header, label_header, nodes_dict, labels_dict)
        test_idx, test_y = self._read_split(
            test_file, nodes_header, label_header, nodes_dict, labels_dict)

        data = Data(edge_index=edge_index, edge_type=edge_type,
                    train_idx=train_idx, train_y=train_y, test_idx=test_idx,
                    test_y=test_y, num_nodes=N)

        if self.hetero:
            data = data.to_heterogeneous(node_type_names=['v'])

        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name.upper()}{self.__class__.__name__}()'
class hide_stdout(object):
    """Context manager that mutes root-logger output below ``ERROR``.

    Despite the name, ``sys.stdout`` is not touched: only the level of the
    root logger is changed on entry and put back on exit.
    """

    def __enter__(self):
        root = logging.getLogger()
        self.level = root.level
        # Only ever tighten the threshold: a root logger that is already
        # stricter than ERROR (e.g. CRITICAL) must not become more verbose.
        root.setLevel(max(self.level, logging.ERROR))
        return self

    def __exit__(self, *args):
        # Restore the level recorded on entry; returning None lets any
        # exception raised inside the `with` body propagate.
        logging.getLogger().setLevel(self.level)