torch_geometric.datasets.ExplainerDataset

class ExplainerDataset(graph_generator: Union[GraphGenerator, str], motif_generator: Union[MotifGenerator, str], num_motifs: int, num_graphs: int = 1, graph_generator_kwargs: Optional[Dict[str, Any]] = None, motif_generator_kwargs: Optional[Dict[str, Any]] = None, transform: Optional[Callable] = None)

Bases: InMemoryDataset

Generates a synthetic dataset for evaluating explainability algorithms, as described in the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper. The ExplainerDataset creates synthetic graphs from a GraphGenerator and randomly attaches num_motifs motifs, produced by a MotifGenerator, to each of them. Ground-truth node-level and edge-level explainability masks indicate whether each node and edge belongs to one of the attached motifs.
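The attach-and-mask idea can be sketched in plain Python. This is a minimal illustration under stated assumptions, not PyG's implementation: graphs are lists of undirected (u, v) pairs rather than edge_index tensors, a "house" motif is hard-coded, and attach_motifs is a hypothetical helper. Each motif is linked to a distinct, randomly chosen base node by one extra edge. Motif nodes and motif edges get mask value 1, while base-graph elements and the connecting edges get 0.

```python
import random
from typing import List, Tuple

Edge = Tuple[int, int]

# Hypothetical hard-coded "house" motif: a square (0-1-2-3) plus a roof
# node 4 attached to nodes 0 and 1. Node ids are local to the motif.
HOUSE_EDGES: List[Edge] = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 4), (1, 4)]
HOUSE_NUM_NODES = 5


def attach_motifs(base_edges: List[Edge], base_num_nodes: int,
                  num_motifs: int, seed: int = 0):
    """Illustrative sketch: attach `num_motifs` house motifs to a base graph.

    Returns (edges, num_nodes, node_mask, edge_mask), where a mask entry is
    1 if the node/edge belongs to a motif and 0 otherwise.
    """
    rng = random.Random(seed)
    edges = list(base_edges)
    edge_mask = [0] * len(edges)        # base edges are not explanations
    node_mask = [0] * base_num_nodes    # base nodes are not explanations
    num_nodes = base_num_nodes

    # Distinct base nodes that serve as attachment points.
    anchors = rng.sample(range(base_num_nodes), num_motifs)
    for anchor in anchors:
        offset = num_nodes
        # Copy the motif's edges, shifted into a fresh node-id range.
        for u, v in HOUSE_EDGES:
            edges.append((u + offset, v + offset))
            edge_mask.append(1)
        node_mask.extend([1] * HOUSE_NUM_NODES)
        # One connecting edge from the anchor to a random motif node; it is
        # not part of the motif itself, so its mask entry is 0.
        edges.append((anchor, offset + rng.randrange(HOUSE_NUM_NODES)))
        edge_mask.append(0)
        num_nodes += HOUSE_NUM_NODES

    return edges, num_nodes, node_mask, edge_mask


# A 4-node cycle stands in for the output of a graph generator.
base = [(0, 1), (1, 2), (2, 3), (3, 0)]
edges, n, node_mask, edge_mask = attach_motifs(base, 4, num_motifs=2)
print(n)               # 14
print(sum(node_mask))  # 10
print(sum(edge_mask))  # 12
```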

For example, to generate a random Barabasi-Albert (BA) graph with 300 nodes, in which we want to randomly attach 80 "house" motifs, write:

```python
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph

dataset = ExplainerDataset(
    graph_generator=BAGraph(num_nodes=300, num_edges=5),
    motif_generator='house',
    num_motifs=80,
)
```

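BAGraph builds the base graph through preferential attachment. The textbook Barabási-Albert procedure is sketched below; barabasi_albert_edges is a hypothetical helper, and PyG's own generator may differ in details such as the seed graph and edge storage (PyG's edge_index typically stores each undirected edge in both directions). Each new node links to num_edges distinct existing nodes, sampled with probability proportional to their degree, with duplicates redrawn.

```python
import random
from typing import List, Tuple


def barabasi_albert_edges(num_nodes: int, num_edges: int,
                          seed: int = 0) -> List[Tuple[int, int]]:
    """Illustrative textbook Barabasi-Albert generator (undirected pairs)."""
    rng = random.Random(seed)
    edges: List[Tuple[int, int]] = []
    # The first new node links to the initial `num_edges` nodes.
    targets = list(range(num_edges))
    # Each node appears here once per incident edge, so a uniform draw from
    # this list picks a node with probability proportional to its degree.
    repeated: List[int] = []
    for source in range(num_edges, num_nodes):
        edges.extend((source, t) for t in targets)
        repeated.extend(targets)
        repeated.extend([source] * num_edges)
        # Draw `num_edges` distinct targets for the next node.
        chosen = set()
        while len(chosen) < num_edges:
            chosen.add(rng.choice(repeated))
        targets = list(chosen)
    return edges


edges = barabasi_albert_edges(300, 5)
print(len(edges))  # 1475
```

Each of the 300 - 5 = 295 added nodes contributes 5 edges, giving 1475 undirected edges before any motifs are attached.
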
Parameters:
  • graph_generator (GraphGenerator or str) – The graph generator to be used, e.g., torch_geometric.datasets.graph_generator.BAGraph (or any string that automatically resolves to it).

  • motif_generator (MotifGenerator or str) – The motif generator to be used, e.g., torch_geometric.datasets.motif_generator.HouseMotif (or any string that automatically resolves to it).

  • num_motifs (int) – The number of motifs to attach to each generated graph.

  • num_graphs (int, optional) – The number of graphs to generate. (default: 1)

  • graph_generator_kwargs (Dict[str, Any], optional) – Arguments passed to the graph generator when graph_generator is given as a string and resolved automatically. (default: None)

  • motif_generator_kwargs (Dict[str, Any], optional) – Arguments passed to the motif generator when motif_generator is given as a string and resolved automatically. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)
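
Several of these parameters interact: a string generator is resolved to a class and instantiated with the matching *_kwargs, num_graphs graphs are generated once and kept in memory, and transform runs on each access rather than at generation time. The toy sketch below mimics that behavior with hypothetical names (ToyGraphGenerator, REGISTRY, resolve, ToyDataset) and plain dicts instead of Data objects; it does not reproduce PyG's resolver or InMemoryDataset.

```python
from typing import Any, Callable, Dict, List, Optional, Union


class ToyGraphGenerator:
    """Hypothetical generator that returns a path graph with `num_nodes` nodes."""

    def __init__(self, num_nodes: int):
        self.num_nodes = num_nodes

    def __call__(self) -> Dict[str, Any]:
        edges = [(i, i + 1) for i in range(self.num_nodes - 1)]
        return {"num_nodes": self.num_nodes, "edges": edges}


# Hypothetical name -> class registry used for string resolution.
REGISTRY = {"toy": ToyGraphGenerator}


def resolve(generator: Union[ToyGraphGenerator, str],
            kwargs: Optional[Dict[str, Any]] = None) -> ToyGraphGenerator:
    if isinstance(generator, str):
        # Strings are looked up and instantiated with the given kwargs.
        return REGISTRY[generator.lower()](**(kwargs or {}))
    return generator  # already an instance; kwargs are not used


class ToyDataset:
    def __init__(self, generator, num_graphs: int = 1,
                 generator_kwargs: Optional[Dict[str, Any]] = None,
                 transform: Optional[Callable] = None):
        gen = resolve(generator, generator_kwargs)
        # All graphs are generated once, up front, and stored in memory.
        self._graphs: List[Dict[str, Any]] = [gen() for _ in range(num_graphs)]
        self.transform = transform

    def __len__(self) -> int:
        return len(self._graphs)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        graph = self._graphs[idx]
        # The transform runs on every access; the stored graph is unchanged.
        return graph if self.transform is None else self.transform(graph)


ds = ToyDataset("toy", num_graphs=3, generator_kwargs={"num_nodes": 4},
                transform=lambda g: {**g, "num_edges": len(g["edges"])})
print(len(ds))             # 3
print(ds[0]["num_edges"])  # 3
```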