from typing import Callable, Optional
import torch
from torch_geometric.data import Data, InMemoryDataset
class KarateClub(InMemoryDataset):
    r"""Zachary's karate club network from the `"An Information Flow Model for
    Conflict and Fission in Small Groups"
    <https://www.journals.uchicago.edu/doi/abs/10.1086/jar.33.4.3629752>`_
    paper, containing 34 nodes,
    connected by 156 (undirected and unweighted) edges.
    Every node is labeled by one of four classes obtained via modularity-based
    clustering, following the `"Semi-supervised Classification with Graph
    Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_ paper.
    Training is based on a single labeled example per class, *i.e.* a total
    number of 4 labeled nodes.

    Args:
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #classes
        * - 34
          - 156
          - 34
          - 4
    """
    def __init__(self, transform: Optional[Callable] = None):
        # The graph is hard-coded below, so no root directory is passed to
        # the base class (first positional argument is ``None``).
        super().__init__(None, transform)

        # Source node of every directed edge. Each undirected edge is stored
        # in both directions, giving 156 entries in total.
        row = [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4,
            5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 10, 10,
            10, 11, 12, 12, 13, 13, 13, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17,
            18, 18, 19, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 23, 23, 23, 24,
            24, 24, 25, 25, 25, 26, 26, 27, 27, 27, 27, 28, 28, 28, 29, 29, 29,
            29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32,
            32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,
            33, 33, 33, 33, 33, 33
        ]
        # Target node of every directed edge, aligned index-wise with `row`.
        col = [
            1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31, 0, 2, 3, 7,
            13, 17, 19, 21, 30, 0, 1, 3, 7, 8, 9, 13, 27, 28, 32, 0, 1, 2, 7,
            12, 13, 0, 6, 10, 0, 6, 10, 16, 0, 4, 5, 16, 0, 1, 2, 3, 0, 2, 30,
            32, 33, 2, 33, 0, 4, 5, 0, 0, 3, 0, 1, 2, 3, 33, 32, 33, 32, 33, 5,
            6, 0, 1, 32, 33, 0, 1, 33, 32, 33, 0, 1, 32, 33, 25, 27, 29, 32,
            33, 25, 27, 31, 23, 24, 31, 29, 33, 2, 23, 24, 33, 2, 31, 33, 23,
            26, 32, 33, 1, 8, 32, 33, 0, 24, 25, 28, 32, 33, 2, 8, 14, 15, 18,
            20, 22, 23, 29, 30, 31, 33, 8, 9, 13, 14, 15, 18, 19, 20, 22, 23,
            26, 27, 28, 29, 30, 31, 32
        ]
        # Stack sources and targets into a single [2, num_edges] tensor.
        edge_index = torch.tensor([row, col])

        y = torch.tensor([  # Create communities.
            1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1,
            0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0
        ])

        # Node features are one-hot node identifiers: an identity matrix with
        # one row (and one column) per node.
        x = torch.eye(y.size(0), dtype=torch.float)

        # Select a single training node for each community
        # (we just use the first one).
        train_mask = torch.zeros(y.size(0), dtype=torch.bool)
        for i in range(int(y.max()) + 1):
            # `nonzero` lists the indices of all nodes with label `i`;
            # `[0]` keeps only the lowest such index.
            train_mask[(y == i).nonzero(as_tuple=False)[0]] = True

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask)

        # Store the single graph in the collated form kept by the base class.
        self.data, self.slices = self.collate([data])