Source code for torch_geometric.datasets.explainer_dataset

from typing import Any, Callable, Dict, Optional, Union

import torch
from torch import Tensor

from torch_geometric.data import InMemoryDataset
from torch_geometric.datasets.graph_generator import GraphGenerator
from torch_geometric.datasets.motif_generator import MotifGenerator
from torch_geometric.explain import Explanation


class ExplainerDataset(InMemoryDataset):
    r"""Generates a synthetic dataset for evaluating explainability
    algorithms, as described in the `"GNNExplainer: Generating Explanations
    for Graph Neural Networks" <https://arxiv.org/abs/1903.03894>`__ paper.
    The :class:`~torch_geometric.datasets.ExplainerDataset` creates synthetic
    graphs coming from a
    :class:`~torch_geometric.datasets.graph_generator.GraphGenerator`, and
    randomly attaches :obj:`num_motifs` many motifs to it coming from a
    :class:`~torch_geometric.datasets.motif_generator.MotifGenerator`.
    Ground-truth node-level and edge-level explainability masks are given
    based on whether nodes and edges are part of a certain motif or not.

    For example, to generate a random Barabasi-Albert (BA) graph with 300
    nodes, in which we want to randomly attach 80 :obj:`"house"` motifs, write:

    .. code-block:: python

        from torch_geometric.datasets import ExplainerDataset
        from torch_geometric.datasets.graph_generator import BAGraph

        dataset = ExplainerDataset(
            graph_generator=BAGraph(num_nodes=300, num_edges=5),
            motif_generator='house',
            num_motifs=80,
        )

    .. note::

        For an example of using :class:`ExplainerDataset`, see
        `examples/explain/gnn_explainer_ba_shapes.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        explain/gnn_explainer_ba_shapes.py>`_.

    Args:
        graph_generator (GraphGenerator or str): The graph generator to be
            used, *e.g.*,
            :class:`torch_geometric.datasets.graph_generator.BAGraph`
            (or any string that automatically resolves to it).
        motif_generator (MotifGenerator or str): The motif generator to be
            used, *e.g.*,
            :class:`torch_geometric.datasets.motif_generator.HouseMotif`
            (or any string that automatically resolves to it).
        num_motifs (int): The number of motifs to attach to the graph.
            Motifs are attached to distinct nodes of the generated graph as
            long as it has at least :obj:`num_motifs` nodes; otherwise, some
            nodes receive more than one motif.
        num_graphs (int, optional): The number of graphs to generate.
            (default: :obj:`1`)
        graph_generator_kwargs (Dict[str, Any], optional): Arguments passed to
            the respective graph generator module in case it gets
            automatically resolved. (default: :obj:`None`)
        motif_generator_kwargs (Dict[str, Any], optional): Arguments passed to
            the respective motif generator module in case it gets
            automatically resolved. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`num_motifs` or :obj:`num_graphs` is not positive.
    """
    def __init__(
        self,
        graph_generator: Union[GraphGenerator, str],
        motif_generator: Union[MotifGenerator, str],
        num_motifs: int,
        num_graphs: int = 1,
        graph_generator_kwargs: Optional[Dict[str, Any]] = None,
        motif_generator_kwargs: Optional[Dict[str, Any]] = None,
        transform: Optional[Callable] = None,
    ):
        super().__init__(root=None, transform=transform)

        if num_motifs <= 0:
            raise ValueError("At least one motif needs to be attached to the "
                             f"graph (got {num_motifs})")

        # Collating an empty list of graphs fails with an unrelated error, so
        # reject a non-positive number of graphs up front.
        if num_graphs <= 0:
            raise ValueError("At least one graph needs to be generated "
                             f"(got {num_graphs})")

        self.graph_generator = GraphGenerator.resolve(
            graph_generator,
            **(graph_generator_kwargs or {}),
        )
        self.motif_generator = MotifGenerator.resolve(
            motif_generator,
            **(motif_generator_kwargs or {}),
        )
        self.num_motifs = num_motifs

        # TODO (matthias) support on-the-fly graph generation.
        data_list = [self.get_graph() for _ in range(num_graphs)]
        self.data, self.slices = self.collate(data_list)

    def get_graph(self) -> Explanation:
        r"""Generates a single graph with motifs attached to it.

        A base graph is drawn from :obj:`self.graph_generator`, and
        :obj:`self.num_motifs` motifs drawn from :obj:`self.motif_generator`
        are each connected to a randomly chosen base node via one undirected
        edge (stored as two directed edges).

        Returns:
            Explanation: An object holding

            * :obj:`edge_index`: Base edges, motif edges (with node indices
              shifted behind the nodes added so far) and connecting edges.
            * :obj:`y`: :obj:`0` for base nodes. For motif nodes, the motif's
              own :obj:`y` (shifted by one if its minimum is :obj:`0`) when
              it is a tensor, and :obj:`1` otherwise.
            * :obj:`node_mask`: :obj:`0` for base nodes, :obj:`1` for motif
              nodes.
            * :obj:`edge_mask`: :obj:`1` for edges inside a motif, :obj:`0`
              for base edges and for connecting edges.
        """
        data = self.graph_generator()
        assert data.num_nodes is not None
        assert data.edge_index is not None

        edge_indices = [data.edge_index]
        num_nodes = data.num_nodes
        node_masks = [torch.zeros(data.num_nodes)]
        edge_masks = [torch.zeros(data.num_edges)]
        ys = [torch.zeros(num_nodes, dtype=torch.long)]

        # Pick one base node per motif. A permutation yields distinct nodes
        # and is sufficient whenever the base graph has enough of them.
        connecting_nodes = torch.randperm(num_nodes)
        num_missing = self.num_motifs - num_nodes
        if num_missing <= 0:
            connecting_nodes = connecting_nodes[:self.num_motifs]
        elif num_nodes > 0:
            # More motifs than base nodes: slicing the permutation alone
            # would silently attach fewer motifs than requested, so draw the
            # remaining anchors with replacement.
            extra_nodes = torch.randint(0, num_nodes, (num_missing, ))
            connecting_nodes = torch.cat([connecting_nodes, extra_nodes])
        # (With an empty base graph there is no node to attach a motif to,
        # and `connecting_nodes` stays empty.)

        for i in connecting_nodes.tolist():
            motif = self.motif_generator()
            assert motif.num_nodes is not None
            assert motif.edge_index is not None

            # Add motif to the graph.
            edge_indices.append(motif.edge_index + num_nodes)
            node_masks.append(torch.ones(motif.num_nodes))
            edge_masks.append(torch.ones(motif.num_edges))

            # Add random motif connection to the graph.
            j = int(torch.randint(0, motif.num_nodes, (1, ))) + num_nodes
            edge_indices.append(torch.tensor([[i, j], [j, i]]))
            edge_masks.append(torch.zeros(2))

            if isinstance(motif.y, Tensor):
                # Label 0 is used for base nodes, so shift motif labels that
                # start at zero.
                ys.append(motif.y + 1 if motif.y.min() == 0 else motif.y)
            else:
                ys.append(torch.ones(motif.num_nodes, dtype=torch.long))

            num_nodes += motif.num_nodes

        return Explanation(
            edge_index=torch.cat(edge_indices, dim=1),
            y=torch.cat(ys, dim=0),
            edge_mask=torch.cat(edge_masks, dim=0),
            node_mask=torch.cat(node_masks, dim=0),
        )

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({len(self)}, '
                f'graph_generator={self.graph_generator}, '
                f'motif_generator={self.motif_generator}, '
                f'num_motifs={self.num_motifs})')