# Source code for torch_geometric.datasets.dgraph

import os.path as osp
from typing import Callable, Optional

import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset, extract_zip
from torch_geometric.utils import index_to_mask

class DGraphFin(InMemoryDataset):
    r"""The DGraphFin networks from the
    `"DGraph: A Large-Scale Financial Dataset for Graph Anomaly Detection"
    <https://arxiv.org/abs/2207.03579>`_ paper.
    It is a directed, unweighted dynamic graph consisting of millions of
    nodes and edges, representing a realistic user-to-user social network
    in the financial industry.
    Each node represents a Finvolution user, and an edge from one user to
    another means that the user regards the other user as their emergency
    contact person.
    Each edge is associated with a timestamp ranging from 1 to 821 and a type
    of emergency contact ranging from 0 to 11.

    The raw archive cannot be downloaded automatically; it must be fetched
    manually from :attr:`url` and placed in :obj:`raw_dir`.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #classes
        * - 3,700,550
          - 4,300,999
          - 17
          - 2
    """

    url = "https://dgraph.xinye.com"

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        # `process()` writes the single collated graph to
        # `processed_paths[0]`; read it back here.
        self.load(self.processed_paths[0])

    def download(self) -> None:
        """Fail with instructions: the archive requires a manual download.

        Raises:
            RuntimeError: Always, naming the expected file, the source URL and
                the directory the file must be moved into.
        """
        raise RuntimeError(
            f"Dataset not found. Please download '{self.raw_file_names}' from "
            f"'{self.url}' and move it to '{self.raw_dir}'")

    @property
    def raw_file_names(self) -> str:
        # Name of the manually downloaded archive expected in `raw_dir`.
        return 'DGraphFin.zip'

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    @property
    def num_classes(self) -> int:
        # Hard-coded instead of inferred from `y`.
        # NOTE(review): presumably because `y` holds extra label values
        # (e.g. background nodes) beyond the two target classes; confirm.
        return 2

    def process(self) -> None:
        """Unpack the raw archive and convert ``dgraphfin.npz`` to a graph.

        Reads node features/labels, edge indices, edge types and edge
        timestamps from the archive, turns the stored split index arrays
        into boolean node masks, applies :obj:`pre_transform` (if any) and
        saves the result to ``processed_paths[0]``.
        """
        extract_zip(self.raw_paths[0], self.raw_dir, log=False)
        path = osp.join(self.raw_dir, "dgraphfin.npz")

        with np.load(path) as loader:
            x = torch.from_numpy(loader['x']).to(torch.float)
            y = torch.from_numpy(loader['y']).to(torch.long)
            edge_index = torch.from_numpy(loader['edge_index']).to(torch.long)
            edge_type = torch.from_numpy(loader['edge_type']).to(torch.long)
            edge_time = torch.from_numpy(loader['edge_timestamp']).to(
                torch.long)
            # Despite their `*_mask` keys, these arrays are node *indices*;
            # they are converted to boolean masks via `index_to_mask` below.
            train_nodes = torch.from_numpy(loader['train_mask']).to(torch.long)
            val_nodes = torch.from_numpy(loader['valid_mask']).to(torch.long)
            test_nodes = torch.from_numpy(loader['test_mask']).to(torch.long)

        num_nodes = x.size(0)
        train_mask = index_to_mask(train_nodes, size=num_nodes)
        val_mask = index_to_mask(val_nodes, size=num_nodes)
        test_mask = index_to_mask(test_nodes, size=num_nodes)

        # NOTE(review): the archive presumably stores edges row-wise as
        # [num_edges, 2]; `.t()` yields PyG's [2, num_edges] layout.
        data = Data(x=x, edge_index=edge_index.t(), edge_type=edge_type,
                    edge_time=edge_time, y=y, train_mask=train_mask,
                    val_mask=val_mask, test_mask=test_mask)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])