# Source code for torch_geometric.datasets.karate

from typing import Callable, Optional

import torch

from torch_geometric.data import Data, InMemoryDataset


class KarateClub(InMemoryDataset):
    r"""Zachary's karate club network from the `"An Information Flow Model for
    Conflict and Fission in Small Groups"
    <https://www.journals.uchicago.edu/doi/abs/10.1086/jar.33.4.3629752>`_
    paper, containing 34 nodes, connected by 156 (undirected and unweighted)
    edges.
    The edges are stored in both directions, so every undirected tie
    contributes two entries to :obj:`edge_index`.
    Every node is labeled by one of four classes obtained via modularity-based
    clustering, following the `"Semi-supervised Classification with Graph
    Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_ paper.
    Node features are one-hot identity vectors.
    Training is based on a single labeled example per class, *i.e.* a total
    number of 4 labeled nodes (the lowest-indexed node of each class).

    Args:
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #classes
        * - 34
          - 156
          - 34
          - 4
    """
    def __init__(self, transform: Optional[Callable] = None):
        super().__init__(None, transform)

        # Entry ``i`` lists the neighbours of node ``i``, in ascending order.
        # Flattening this row by row yields the (source, target) pairs sorted
        # by source node. Each tie appears once under each of its endpoints.
        adjacency = [
            [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31],  # 0
            [0, 2, 3, 7, 13, 17, 19, 21, 30],  # 1
            [0, 1, 3, 7, 8, 9, 13, 27, 28, 32],  # 2
            [0, 1, 2, 7, 12, 13],  # 3
            [0, 6, 10],  # 4
            [0, 6, 10, 16],  # 5
            [0, 4, 5, 16],  # 6
            [0, 1, 2, 3],  # 7
            [0, 2, 30, 32, 33],  # 8
            [2, 33],  # 9
            [0, 4, 5],  # 10
            [0],  # 11
            [0, 3],  # 12
            [0, 1, 2, 3, 33],  # 13
            [32, 33],  # 14
            [32, 33],  # 15
            [5, 6],  # 16
            [0, 1],  # 17
            [32, 33],  # 18
            [0, 1, 33],  # 19
            [32, 33],  # 20
            [0, 1],  # 21
            [32, 33],  # 22
            [25, 27, 29, 32, 33],  # 23
            [25, 27, 31],  # 24
            [23, 24, 31],  # 25
            [29, 33],  # 26
            [2, 23, 24, 33],  # 27
            [2, 31, 33],  # 28
            [23, 26, 32, 33],  # 29
            [1, 8, 32, 33],  # 30
            [0, 24, 25, 28, 32, 33],  # 31
            [2, 8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33],  # 32
            [8, 9, 13, 14, 15, 18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31,
             32],  # 33
        ]
        sources = [node for node, nbrs in enumerate(adjacency) for _ in nbrs]
        targets = [nbr for nbrs in adjacency for nbr in nbrs]
        edge_index = torch.tensor([sources, targets])

        # Community label of every node (node id == position in the list).
        y = torch.tensor([
            1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1,
            0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0
        ])
        num_nodes = y.size(0)
        num_classes = int(y.max()) + 1

        # Featureless graph: each node is represented by its one-hot id.
        x = torch.eye(num_nodes, dtype=torch.float)

        # One labeled training node per community: its lowest-indexed member.
        first_members = [
            (y == label).nonzero(as_tuple=True)[0][0].item()
            for label in range(num_classes)
        ]
        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        train_mask[first_members] = True

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask)
        self.data, self.slices = self.collate([data])