# Source code for torch_geometric.datasets.karate

from typing import Callable, Optional

import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset


class KarateClub(InMemoryDataset):
    r"""Zachary's karate club network from the `"An Information Flow Model for
    Conflict and Fission in Small Groups"
    <http://www1.ind.ku.dk/complexLearning/zachary1977.pdf>`_ paper, containing
    34 nodes, connected by 156 (undirected and unweighted) edges. Each
    undirected connection is stored in both directions in :obj:`edge_index`.
    Every node is labeled by one of four classes obtained via modularity-based
    clustering, following the `"Semi-supervised Classification with Graph
    Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_ paper.
    Training is based on a single labeled example per class, *i.e.* a total
    number of 4 labeled nodes.

    Args:
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)

    Stats:
        .. list-table::
            :widths: 10 10 10 10
            :header-rows: 1

            * - #nodes
              - #edges
              - #features
              - #classes
            * - 34
              - 156
              - 34
              - 4
    """
    def __init__(self, transform: Optional[Callable] = None):
        super().__init__('.', transform)

        # Imported locally so that networkx is only required when this
        # dataset is actually instantiated.
        import networkx as nx

        G = nx.karate_club_graph()

        # The graph carries no node features, so one-hot node identities
        # (an identity matrix) are used as input features.
        x = torch.eye(G.number_of_nodes(), dtype=torch.float)

        # Prefer `to_scipy_sparse_array`; fall back to the older
        # `to_scipy_sparse_matrix` for networkx releases lacking the former.
        to_scipy = getattr(nx, 'to_scipy_sparse_array', None)
        if to_scipy is None:
            to_scipy = nx.to_scipy_sparse_matrix
        adj = to_scipy(G).tocoo()

        # COO row/col indices -> a (2, num_edges) int64 tensor in a single
        # conversion. Only the sparsity pattern is kept: any edge weights
        # stored in `adj.data` are discarded, yielding an unweighted graph.
        edge_index = torch.from_numpy(
            np.stack([adj.row, adj.col]).astype(np.int64))

        # Community assignment per node (node i -> y[i]).
        y = torch.tensor([
            1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1,
            0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0
        ], dtype=torch.long)

        # Select a single training node for each community: the node with the
        # lowest index that belongs to it.
        num_classes = int(y.max()) + 1
        train_mask = torch.zeros(y.size(0), dtype=torch.bool)
        for c in range(num_classes):
            train_mask[(y == c).nonzero(as_tuple=False)[0]] = True

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask)
        self.data, self.slices = self.collate([data])