# Source code for torch_geometric.datasets.ba_multi_shapes

import pickle
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset, download_url


class BAMultiShapesDataset(InMemoryDataset):
    r"""The synthetic BA-Multi-Shapes graph classification dataset for
    evaluating explainability algorithms, as described in the
    `"Global Explainability of GNNs via Logic Combination of Learned Concepts"
    <https://arxiv.org/abs/2210.07147>`_ paper.

    Given three atomic motifs, namely House (H), Wheel (W), and Grid (G),
    :class:`~torch_geometric.datasets.BAMultiShapesDataset` contains 1,000
    graphs where each graph is obtained by attaching the motifs to a random
    Barabasi-Albert (BA) graph as follows:

    * class 0: :math:`\emptyset \lor H \lor W \lor G \lor \{ H, W, G \}`

    * class 1: :math:`(H \land W) \lor (H \land G) \lor (W \land G)`

    This dataset is pre-computed from the official implementation.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - 1000
          - 40
          - ~87.0
          - 10
          - 2
    """

    # Location of the pre-computed pickle published with the official
    # implementation of the paper.
    url = ('https://github.com/steveazzolin/gnn_logic_global_expl/raw/master/'
           'datasets/BAMultiShapes/BAMultiShapes.pkl')

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        """Initialize the base dataset, then load the processed file."""
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> str:
        """Name of the raw pickle file expected in the raw directory."""
        return 'BAMultiShapes.pkl'

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file written by :meth:`process`."""
        return 'data.pt'

    def download(self) -> None:
        """Download the raw pickle from :attr:`url` into ``self.raw_dir``."""
        download_url(self.url, self.raw_dir)

    def process(self) -> None:
        """Convert the raw pickle into :class:`Data` objects and save them.

        The pickle is unpacked into three parallel sequences (adjacency
        arrays, node features and labels). Every triple becomes one
        :class:`Data` object, which is optionally filtered by ``pre_filter``
        and transformed by ``pre_transform`` before the whole list is saved
        to ``self.processed_paths[0]``.
        """
        # SECURITY NOTE: `pickle.load` can execute arbitrary code while
        # deserializing. The file read here is fetched from `url`, so this is
        # only safe as long as that remote source is trusted.
        with open(self.raw_paths[0], 'rb') as f:
            adjs, xs, ys = pickle.load(f)

        data_list: List[Data] = []
        for adj, x, y in zip(adjs, xs, ys):
            # `adj` must be a NumPy array (required by `torch.from_numpy`);
            # the indices of its non-zero entries are transposed so that each
            # column describes one edge.
            # NOTE(review): presumably `adj` is a dense 2D adjacency matrix,
            # giving a `[2, num_edges]` tensor - confirm against the raw data.
            edge_index = torch.from_numpy(adj).nonzero().t()
            x = torch.from_numpy(np.array(x)).to(torch.float)

            # `y` is stored exactly as it comes out of the pickle.
            data = Data(x=x, edge_index=edge_index, y=y)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        self.save(data_list, self.processed_paths[0])