Source code for torch_geometric.datasets.fake

import random
from collections import defaultdict
from itertools import product
from typing import Callable, Optional

import torch

from torch_geometric.data import Data, HeteroData, InMemoryDataset
from torch_geometric.utils import coalesce, remove_self_loops, to_undirected

class FakeDataset(InMemoryDataset):
    r"""A fake dataset that returns randomly generated
    :class:`~torch_geometric.data.Data` objects.

    Args:
        num_graphs (int, optional): The number of graphs. Values below
            :obj:`1` are treated as :obj:`1`. (default: :obj:`1`)
        avg_num_nodes (int, optional): The average number of nodes in a graph.
            (default: :obj:`1000`)
        avg_degree (int, optional): The average degree per node. Values below
            :obj:`1` are treated as :obj:`1`. (default: :obj:`10`)
        num_channels (int, optional): The number of node features. If not
            positive, no node features are generated and only
            :obj:`num_nodes` is stored. (default: :obj:`64`)
        edge_dim (int, optional): The number of edge features. A value of
            :obj:`1` produces a one-dimensional :obj:`edge_weight` instead of
            :obj:`edge_attr`. (default: :obj:`0`)
        num_classes (int, optional): The number of classes in the dataset.
            If not positive, no labels are generated. (default: :obj:`10`)
        task (str, optional): Whether to return node-level or graph-level
            labels (:obj:`"node"`, :obj:`"graph"`, :obj:`"auto"`).
            If set to :obj:`"auto"`, will return graph-level labels if
            :obj:`num_graphs > 1`, and node-level labels other-wise.
            (default: :obj:`"auto"`)
        is_undirected (bool, optional): Whether the graphs to generate are
            undirected. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): Accepted for signature
            compatibility only; this class never applies it.
            (default: :obj:`None`)
        **kwargs (optional): Additional attributes and their shapes
            *e.g.* :obj:`global_features=5`.
    """
    def __init__(
        self,
        num_graphs: int = 1,
        avg_num_nodes: int = 1000,
        avg_degree: int = 10,
        num_channels: int = 64,
        edge_dim: int = 0,
        num_classes: int = 10,
        task: str = "auto",
        is_undirected: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        **kwargs,
    ):
        super().__init__('.', transform)

        if task == 'auto':
            task = 'graph' if num_graphs > 1 else 'node'
        assert task in ['node', 'graph']

        # Clamp the degree first so that the node count is bounded by the
        # degree that is actually used; otherwise `avg_num_nodes=0` with
        # `avg_degree=0` yields an empty sampling range in `get_num_nodes`.
        self.avg_degree = max(avg_degree, 1)
        self.avg_num_nodes = max(avg_num_nodes, self.avg_degree)
        self.num_channels = num_channels
        self.edge_dim = edge_dim
        # Stored under a private name; the public `num_classes` is not
        # assigned here.
        self._num_classes = num_classes
        self.task = task
        self.is_undirected = is_undirected
        self.kwargs = kwargs

        data_list = [self.generate_data() for _ in range(max(num_graphs, 1))]
        self.data, self.slices = self.collate(data_list)

    def generate_data(self) -> Data:
        """Generate a single random graph according to the dataset settings.

        Returns:
            Data: A graph holding :obj:`edge_index`, optional labels
            :obj:`y`, optional features :obj:`x` (or :obj:`num_nodes` when no
            features are requested), optional edge features, and one random
            tensor per entry in :obj:`self.kwargs`.
        """
        num_nodes = get_num_nodes(self.avg_num_nodes, self.avg_degree)

        data = Data()

        has_labels = self._num_classes > 0
        if has_labels and self.task == 'node':
            data.y = torch.randint(self._num_classes, (num_nodes, ))
        elif has_labels and self.task == 'graph':
            data.y = torch.tensor([random.randint(0, self._num_classes - 1)])

        data.edge_index = get_edge_index(
            num_nodes,
            num_nodes,
            self.avg_degree,
            self.is_undirected,
            remove_loops=True,
        )

        if self.num_channels > 0:
            x = torch.randn(num_nodes, self.num_channels)
            # Shift the features by the label so that they correlate with the
            # target. Only done when labels exist; without this guard a
            # dataset with `num_classes <= 0` would fail on the missing `y`.
            if has_labels and self.task == 'node':
                x = x + data.y.unsqueeze(1)
            elif has_labels and self.task == 'graph':
                x = x + data.y
            data.x = x
        else:
            data.num_nodes = num_nodes

        if self.edge_dim > 1:
            data.edge_attr = torch.rand(data.num_edges, self.edge_dim)
        elif self.edge_dim == 1:
            data.edge_weight = torch.rand(data.num_edges)

        for feature_name, feature_shape in self.kwargs.items():
            setattr(data, feature_name, torch.randn(feature_shape))

        return data
class FakeHeteroDataset(InMemoryDataset):
    r"""A fake dataset that returns randomly generated
    :class:`~torch_geometric.data.HeteroData` objects.

    Args:
        num_graphs (int, optional): The number of graphs. Values below
            :obj:`1` are treated as :obj:`1`. (default: :obj:`1`)
        num_node_types (int, optional): The number of node types. Values
            below :obj:`1` are treated as :obj:`1`. (default: :obj:`3`)
        num_edge_types (int, optional): The number of edge types. Values
            below :obj:`1` are treated as :obj:`1`. (default: :obj:`6`)
        avg_num_nodes (int, optional): The average number of nodes in a graph.
            (default: :obj:`1000`)
        avg_degree (int, optional): The average degree per node. Values below
            :obj:`1` are treated as :obj:`1`. (default: :obj:`10`)
        avg_num_channels (int, optional): The average number of node features.
            (default: :obj:`64`)
        edge_dim (int, optional): The number of edge features. A value of
            :obj:`1` produces a one-dimensional :obj:`edge_weight` instead of
            :obj:`edge_attr`. (default: :obj:`0`)
        num_classes (int, optional): The number of classes in the dataset.
            If not positive, no labels are generated. (default: :obj:`10`)
        task (str, optional): Whether to return node-level or graph-level
            labels (:obj:`"node"`, :obj:`"graph"`, :obj:`"auto"`).
            If set to :obj:`"auto"`, will return graph-level labels if
            :obj:`num_graphs > 1`, and node-level labels other-wise.
            (default: :obj:`"auto"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): Accepted for signature
            compatibility only; this class never applies it.
            (default: :obj:`None`)
        **kwargs (optional): Additional attributes and their shapes
            *e.g.* :obj:`global_features=5`.
    """
    def __init__(
        self,
        num_graphs: int = 1,
        num_node_types: int = 3,
        num_edge_types: int = 6,
        avg_num_nodes: int = 1000,
        avg_degree: int = 10,
        avg_num_channels: int = 64,
        edge_dim: int = 0,
        num_classes: int = 10,
        task: str = "auto",
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        **kwargs,
    ):
        super().__init__('.', transform)

        if task == 'auto':
            task = 'graph' if num_graphs > 1 else 'node'
        assert task in ['node', 'graph']

        self.node_types = [f'v{i}' for i in range(max(num_node_types, 1))]

        # Repeat the full set of (src, dst) pairs until there are enough
        # candidates, then shuffle and keep the first `num_edge_types`.
        num_edge_types = max(num_edge_types, 1)
        edge_types = []
        edge_type_product = list(product(self.node_types, self.node_types))
        while len(edge_types) < num_edge_types:
            edge_types.extend(edge_type_product)
        random.shuffle(edge_types)

        # The same (src, dst) pair may be selected more than once; number the
        # relation per pair ('e0', 'e1', ...) so every triplet stays unique.
        self.edge_types = []
        count = defaultdict(int)
        for edge_type in edge_types[:num_edge_types]:
            rel = f'e{count[edge_type]}'
            count[edge_type] += 1
            self.edge_types.append((edge_type[0], rel, edge_type[1]))

        # Clamp the degree first so that the node count is bounded by the
        # degree that is actually used; otherwise `avg_num_nodes=0` with
        # `avg_degree=0` yields an empty sampling range in `get_num_nodes`.
        self.avg_degree = max(avg_degree, 1)
        self.avg_num_nodes = max(avg_num_nodes, self.avg_degree)
        # One feature dimensionality per node type, drawn once and shared by
        # all generated graphs.
        self.num_channels = [
            get_num_channels(avg_num_channels) for _ in self.node_types
        ]
        self.edge_dim = edge_dim
        # Stored under a private name; the public `num_classes` is not
        # assigned here.
        self._num_classes = num_classes
        self.task = task
        self.kwargs = kwargs

        data_list = [self.generate_data() for _ in range(max(num_graphs, 1))]
        self.data, self.slices = self.collate(data_list)

    def generate_data(self) -> HeteroData:
        """Generate a single random heterogeneous graph.

        Returns:
            HeteroData: A graph with one node store per entry in
            :obj:`self.node_types`, one edge store per entry in
            :obj:`self.edge_types`, optional labels, and one random tensor
            per entry in :obj:`self.kwargs`.
        """
        data = HeteroData()

        iterator = zip(self.node_types, self.num_channels)
        for i, (node_type, num_channels) in enumerate(iterator):
            num_nodes = get_num_nodes(self.avg_num_nodes, self.avg_degree)

            store = data[node_type]

            if num_channels > 0:
                store.x = torch.randn(num_nodes, num_channels)
            else:
                store.num_nodes = num_nodes

            # Node-level labels are attached to the first node type only.
            if self._num_classes > 0 and self.task == 'node' and i == 0:
                store.y = torch.randint(self._num_classes, (num_nodes, ))

        for (src, rel, dst) in self.edge_types:
            store = data[(src, rel, dst)]

            store.edge_index = get_edge_index(
                data[src].num_nodes,
                data[dst].num_nodes,
                self.avg_degree,
                is_undirected=False,
                remove_loops=False,
            )

            if self.edge_dim > 1:
                store.edge_attr = torch.rand(store.num_edges, self.edge_dim)
            elif self.edge_dim == 1:
                store.edge_weight = torch.rand(store.num_edges)

        if self._num_classes > 0 and self.task == 'graph':
            data.y = torch.tensor([random.randint(0, self._num_classes - 1)])

        for feature_name, feature_shape in self.kwargs.items():
            setattr(data, feature_name, torch.randn(feature_shape))

        return data
###############################################################################


def get_num_nodes(avg_num_nodes: int, avg_degree: int) -> int:
    """Draw a random node count around ``avg_num_nodes``.

    The count is sampled uniformly from roughly 75% to 125% of
    ``avg_num_nodes`` and is never smaller than ``avg_degree``.

    Args:
        avg_num_nodes: The target average number of nodes.
        avg_degree: The lower bound for the returned node count.

    Returns:
        The sampled number of nodes.
    """
    min_num_nodes = max(3 * avg_num_nodes // 4, avg_degree)
    # Keep the range non-empty when `avg_degree` exceeds 125% of
    # `avg_num_nodes`; `random.randint` raises if the bounds are reversed.
    max_num_nodes = max(5 * avg_num_nodes // 4, min_num_nodes)
    return random.randint(min_num_nodes, max_num_nodes)


def get_num_channels(num_channels: int) -> int:
    """Draw a random feature dimensionality around ``num_channels``.

    Args:
        num_channels: The target average number of channels. Negative values
            are treated as ``0``.

    Returns:
        A value sampled uniformly from roughly 75% to 125% of
        ``num_channels``.
    """
    # A negative average would reverse the bounds below and make
    # `random.randint` raise; treat it as "no channels" instead.
    num_channels = max(num_channels, 0)
    min_num_channels = 3 * num_channels // 4
    max_num_channels = 5 * num_channels // 4
    return random.randint(min_num_channels, max_num_channels)


def get_edge_index(num_src_nodes: int, num_dst_nodes: int, avg_degree: int,
                   is_undirected: bool = False,
                   remove_loops: bool = False) -> torch.Tensor:
    """Generate a random edge index between two sets of nodes.

    ``num_src_nodes * avg_degree`` (source, destination) pairs are sampled
    uniformly at random and then post-processed with the PyG utilities
    imported at the top of this file.

    Args:
        num_src_nodes: The number of source nodes.
        num_dst_nodes: The number of destination nodes.
        avg_degree: The number of edges sampled per source node.
        is_undirected: If set, the result is passed through
            ``to_undirected``; otherwise through ``coalesce``.
        remove_loops: If set, self-loops are removed via
            ``remove_self_loops`` before the step above.

    Returns:
        An ``int64`` tensor with two rows (source indices, destination
        indices). It has zero columns when there is nothing to sample from.
    """
    # Nothing can be sampled from an empty node set or with a non-positive
    # degree, so return an empty edge index rather than calling
    # `torch.randint` with an empty range or a negative size.
    if num_src_nodes <= 0 or num_dst_nodes <= 0 or avg_degree <= 0:
        return torch.empty((2, 0), dtype=torch.int64)

    num_edges = num_src_nodes * avg_degree
    row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.int64)
    col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.int64)
    edge_index = torch.stack([row, col], dim=0)

    if remove_loops:
        edge_index, _ = remove_self_loops(edge_index)

    num_nodes = max(num_src_nodes, num_dst_nodes)
    if is_undirected:
        edge_index = to_undirected(edge_index, num_nodes=num_nodes)
    else:
        edge_index = coalesce(edge_index, num_nodes=num_nodes)

    return edge_index