# Source code for the IMDB dataset class.

import os
import os.path as osp
from itertools import product
from typing import Callable, List, Optional

import numpy as np
import scipy.sparse as sp
import torch

from torch_geometric.data import (
    HeteroData,
    InMemoryDataset,
    download_url,
    extract_zip,
)

class IMDB(InMemoryDataset):
    r"""A subset of the Internet Movie Database (IMDB), as collected in the
    "MAGNN: Metapath Aggregated Graph Neural Network for Heterogeneous Graph
    Embedding" paper.

    IMDB is a heterogeneous graph containing three types of entities - movies
    (4,278 nodes), actors (5,257 nodes), and directors (2,081 nodes).
    The movies are divided into three classes (action, comedy, drama)
    according to their genre.
    Movie features correspond to elements of a bag-of-words representation of
    its plot keywords.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in a
            :obj:`HeteroData` object and returns a transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes
            in a :obj:`HeteroData` object and returns a transformed version.
            The data object will be transformed before being saved to disk.
            (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """

    # NOTE(review): the download URL is empty in this copy of the file, so
    # `download()` cannot fetch anything until the real URL is restored.
    url = ''

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        # Read back the single processed graph as a `HeteroData` object.
        self.load(self.processed_paths[0], data_cls=HeteroData)

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the raw files that `process()` reads from `raw_dir`."""
        return [
            'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npz',
            'labels.npy', 'train_val_test_idx.npz'
        ]

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file written by `process()`."""
        # NOTE(review): this returns an empty string in this copy of the
        # file; presumably the real file name was lost - confirm and restore
        # it, since `processed_paths[0]` is used as the save/load target.
        return ''

    def download(self) -> None:
        """Download the archive at `url`, unpack it into `raw_dir`, and
        delete the archive afterwards."""
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.remove(path)

    def process(self) -> None:
        """Build one `HeteroData` graph from the raw files and save it to
        `processed_paths[0]`."""
        data = HeteroData()

        # `features_{i}.npz` holds the sparse feature matrix of the i-th
        # node type in this order.
        node_types = ['movie', 'director', 'actor']
        for i, node_type in enumerate(node_types):
            x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz'))
            data[node_type].x = torch.from_numpy(x.todense()).to(torch.float)

        # Labels are attached to the movie nodes only.
        y = np.load(osp.join(self.raw_dir, 'labels.npy'))
        data['movie'].y = torch.from_numpy(y).to(torch.long)

        # Turn the stored index arrays into boolean masks over movie nodes.
        split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz'))
        for name in ['train', 'val', 'test']:
            idx = split[f'{name}_idx']
            idx = torch.from_numpy(idx).to(torch.long)
            mask = torch.zeros(data['movie'].num_nodes, dtype=torch.bool)
            mask[idx] = True
            data['movie'][f'{name}_mask'] = mask

        # The adjacency matrix is sliced as consecutive per-type ranges in
        # `node_types` order; record each type's half-open (start, end) range.
        offsets = {}
        start = 0
        for node_type in node_types:
            end = start + data[node_type].num_nodes
            offsets[node_type] = (start, end)
            start = end

        A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz'))
        for src, dst in product(node_types, node_types):
            src_start, src_end = offsets[src]
            dst_start, dst_end = offsets[dst]
            A_sub = A[src_start:src_end, dst_start:dst_end].tocoo()
            # Only create an edge type when the sub-block has entries.
            if A_sub.nnz > 0:
                row = torch.from_numpy(A_sub.row).to(torch.long)
                col = torch.from_numpy(A_sub.col).to(torch.long)
                data[src, dst].edge_index = torch.stack([row, col], dim=0)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'