from typing import Optional, Callable
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import barabasi_albert_graph
def house():
    """Return the edge list and per-node role labels of a 5-node "house" motif.

    The motif is the square 0-1-2-3 plus an apex node 4 joined to nodes 0
    and 1.

    Returns:
        A tuple ``(edge_index, label)``. ``edge_index`` is a ``[2, 12]``
        integer tensor that lists each of the 6 undirected edges in both
        directions. ``label`` holds one role per node: ``1`` for the two
        nodes adjacent to the apex, ``2`` for the two remaining square nodes,
        and ``3`` for the apex itself.
    """
    # (source, target) pairs, grouped by source node; the order matters
    # because it fixes the column order of the returned edge_index.
    pairs = [
        (0, 1), (0, 3), (0, 4),
        (1, 4), (1, 2), (1, 0),
        (2, 1), (2, 3),
        (3, 2), (3, 0),
        (4, 0), (4, 1),
    ]
    # Transpose [12, 2] -> [2, 12]; contiguous() materialises the view.
    edge_index = torch.tensor(pairs).t().contiguous()
    roles = torch.tensor([1, 1, 2, 2, 3])
    return edge_index, roles
class BAShapes(InMemoryDataset):
    r"""The BA-Shapes dataset from the `"GNNExplainer: Generating Explanations
    for Graph Neural Networks" <https://arxiv.org/pdf/1903.03894.pdf>`_ paper,
    containing a Barabasi-Albert (BA) graph with 300 nodes and a set of 80
    "house"-structured graphs connected to it.

    The generated graph is stored as a single
    :class:`~torch_geometric.data.Data` object with the following attributes:

    * :obj:`x`: all-ones node features of shape :obj:`[700, 10]`.
    * :obj:`y`: node labels. :obj:`0` marks BA graph nodes, and :obj:`1` to
      :obj:`3` give the structural role of a node within its house.
    * :obj:`edge_label`: :obj:`1` for edges inside a house. :obj:`0` for BA
      graph edges and for the edges attaching a house to the BA graph.
    * :obj:`expl_mask`: marks the first node of every house starting at
      node :obj:`400`, i.e. houses 20 to 79.

    Args:
        connection_distribution (string, optional): Specifies how the houses
            and the BA graph get connected. Valid inputs are :obj:`"random"`
            (random BA graph nodes are selected for connection to the houses),
            and :obj:`"uniform"` (uniformly distributed BA graph nodes are
            selected for connection to the houses). (default: :obj:`"random"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`connection_distribution` is neither
            :obj:`"random"` nor :obj:`"uniform"`.
    """
    def __init__(self, connection_distribution: str = "random",
                 transform: Optional[Callable] = None):
        # Validate before doing any work. A plain `assert` would be stripped
        # when Python runs with -O.
        if connection_distribution not in ('random', 'uniform'):
            raise ValueError(
                f"Invalid 'connection_distribution' (got "
                f"'{connection_distribution}', expected 'random' or "
                f"'uniform')")
        super().__init__('.', transform)

        # Build the Barabasi-Albert graph:
        num_base_nodes = 300
        num_houses = 80
        edge_index = barabasi_albert_graph(num_base_nodes, num_edges=5)

        # Select the BA graph node each house gets attached to:
        if connection_distribution == 'random':
            connecting_nodes = torch.randperm(num_base_nodes)[:num_houses]
        else:
            # Spread exactly `num_houses` anchors evenly over the *whole* BA
            # graph. The old approach used a fixed step of
            # `num_base_nodes // num_houses` (3 here) and kept only the first
            # `num_houses` entries. That stopped at node 237, so nodes
            # 240-299 never received a house.
            connecting_nodes = torch.tensor(
                [i * num_base_nodes // num_houses for i in range(num_houses)])

        # The house motif is identical for every copy, so build it once.
        house_edge_index, house_label = house()
        house_size = house_label.numel()
        house_edge_label = torch.ones(house_edge_index.size(1),
                                      dtype=torch.long)
        link_edge_label = torch.zeros(2, dtype=torch.long)

        edge_indices = [edge_index]
        edge_labels = [torch.zeros(edge_index.size(1), dtype=torch.long)]
        node_labels = [torch.zeros(num_base_nodes, dtype=torch.long)]

        # Connect houses to the Barabasi-Albert graph. `num_nodes` is the
        # running node count, i.e. the id of the next house's node 0.
        num_nodes = num_base_nodes
        for anchor in connecting_nodes.tolist():
            # Shift the house's local node ids past all existing nodes:
            edge_indices.append(house_edge_index + num_nodes)
            # Undirected edge between the BA anchor and the house's node 0:
            edge_indices.append(
                torch.tensor([[anchor, num_nodes], [num_nodes, anchor]]))
            edge_labels.append(house_edge_label)
            edge_labels.append(link_edge_label)
            node_labels.append(house_label)
            num_nodes += house_size

        edge_index = torch.cat(edge_indices, dim=1)
        edge_label = torch.cat(edge_labels, dim=0)
        node_label = torch.cat(node_labels, dim=0)

        x = torch.ones((num_nodes, 10), dtype=torch.float)

        # Node 400 is node 0 of house 20 (300 + 5 * 20). Stepping by the
        # house size selects node 0 of every later house.
        # NOTE(review): the start at 400 presumably mirrors the evaluation
        # node range of the original GNNExplainer code — confirm before
        # changing it.
        expl_mask = torch.zeros(num_nodes, dtype=torch.bool)
        expl_mask[torch.arange(400, num_nodes, house_size)] = True

        data = Data(x=x, edge_index=edge_index, y=node_label,
                    expl_mask=expl_mask, edge_label=edge_label)

        self.data, self.slices = self.collate([data])