torch_geometric.datasets.ExplainerDataset
class ExplainerDataset(graph_generator: Union[GraphGenerator, str], motif_generator: Union[MotifGenerator, str], num_motifs: int, num_graphs: int = 1, graph_generator_kwargs: Optional[Dict[str, Any]] = None, motif_generator_kwargs: Optional[Dict[str, Any]] = None, transform: Optional[Callable] = None)
Bases: InMemoryDataset
Generates a synthetic dataset for evaluating explainability algorithms, as described in the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper. The ExplainerDataset creates synthetic graphs from a GraphGenerator and randomly attaches num_motifs motifs, produced by a MotifGenerator, to each generated graph. Ground-truth node-level and edge-level explainability masks are given based on whether nodes and edges are part of a motif or not.

For example, to generate a random Barabasi-Albert (BA) graph with 300 nodes to which 80 "house" motifs are randomly attached, write:

    from torch_geometric.datasets import ExplainerDataset
    from torch_geometric.datasets.graph_generator import BAGraph

    dataset = ExplainerDataset(
        graph_generator=BAGraph(num_nodes=300, num_edges=5),
        motif_generator='house',
        num_motifs=80,
    )
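The construction described above can be sketched in plain Python to show where the ground-truth masks come from. This is a minimal sketch under stated assumptions, not PyG's implementation: it uses a simplified Barabasi-Albert generator instead of BAGraph, hard-codes an assumed five-node edge layout for the "house" motif, assumes each motif is joined to a distinct random base node by a single connecting edge, and stores undirected edges as plain Python lists instead of tensors. The names ba_graph, attach_motifs, HOUSE_EDGES, and HOUSE_NUM_NODES are hypothetical.

```python
import random

# Assumed layout of the "house" motif (5 nodes): square 0-1-2-3 plus roof node 4
# attached to nodes 0 and 1. Stored as an undirected edge list (one entry per edge).
HOUSE_EDGES = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 4), (1, 4)]
HOUSE_NUM_NODES = 5


def ba_graph(num_nodes, num_edges, rng):
    """Simplified Barabasi-Albert graph: every new node links to `num_edges`
    distinct existing nodes, picked with probability proportional to degree."""
    edges = []
    targets = list(range(num_edges))  # the first new node links to these seeds
    repeated = []  # each node appears here once per incident edge
    for new in range(num_edges, num_nodes):
        edges.extend((new, t) for t in targets)
        repeated.extend(targets)
        repeated.extend([new] * num_edges)
        # Degree-proportional sampling of distinct targets for the next node.
        chosen = set()
        while len(chosen) < num_edges:
            chosen.add(rng.choice(repeated))
        targets = list(chosen)
    return edges


def attach_motifs(base_edges, base_num_nodes, num_motifs, rng):
    """Attach `num_motifs` house motifs to distinct random base nodes and
    return (edges, node_mask, edge_mask, num_nodes)."""
    edges = list(base_edges)
    edge_mask = [0] * len(base_edges)  # base-graph edges are not ground truth
    node_mask = [0] * base_num_nodes   # base-graph nodes are not ground truth
    num_nodes = base_num_nodes
    for anchor in rng.sample(range(base_num_nodes), num_motifs):
        offset = num_nodes  # motif nodes get fresh ids starting here
        for u, v in HOUSE_EDGES:
            edges.append((u + offset, v + offset))
            edge_mask.append(1)  # motif edges are part of the explanation
        node_mask.extend([1] * HOUSE_NUM_NODES)
        # Assumption: one edge links the anchor to a random motif node;
        # it is not part of the motif, so its mask entry is 0.
        edges.append((anchor, offset + rng.randrange(HOUSE_NUM_NODES)))
        edge_mask.append(0)
        num_nodes += HOUSE_NUM_NODES
    return edges, node_mask, edge_mask, num_nodes


rng = random.Random(0)
base = ba_graph(num_nodes=300, num_edges=5, rng=rng)
edges, node_mask, edge_mask, num_nodes = attach_motifs(base, 300, 80, rng)
print(num_nodes, sum(node_mask), sum(edge_mask), len(edges))  # 700 400 480 2035
```

With 300 base nodes and 80 five-node motifs, the sketch yields 700 nodes, of which the 400 motif nodes carry a node-mask value of 1; only the 480 motif edges carry an edge-mask value of 1, while base edges and connecting edges stay at 0.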
Note
For an example of using ExplainerDataset, see examples/explain/gnn_explainer_ba_shapes.py.

Parameters:
- graph_generator (GraphGenerator or str) – The graph generator to be used, e.g., torch_geometric.datasets.graph_generator.BAGraph (or any string that automatically resolves to it).
- motif_generator (MotifGenerator or str) – The motif generator to be used, e.g., torch_geometric.datasets.motif_generator.HouseMotif (or any string that automatically resolves to it).
- num_motifs (int) – The number of motifs to attach to each generated graph.
- num_graphs (int, optional) – The number of graphs to generate. (default: 1)
- graph_generator_kwargs (Dict[str, Any], optional) – Arguments passed to the respective graph generator module in case it gets automatically resolved from a string. (default: None)
- motif_generator_kwargs (Dict[str, Any], optional) – Arguments passed to the respective motif generator module in case it gets automatically resolved from a string. (default: None)
- transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)
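The graph_generator_kwargs and motif_generator_kwargs parameters only matter when a generator is passed as a string that has to be resolved. The snippet below is a hypothetical pure-Python sketch of such a resolution step. The stand-in classes, the REGISTRY mapping, the name normalization (lowercasing and dropping "_" and "-"), and the resolve function are all illustrative assumptions, not PyG's actual resolver.

```python
# Hypothetical stand-ins for generator classes; not PyG's real classes.
class BAGraph:
    def __init__(self, num_nodes, num_edges):
        self.num_nodes = num_nodes
        self.num_edges = num_edges


class HouseMotif:
    pass


# Hypothetical registry: normalized names (with and without the class-kind
# suffix) mapped to the classes they should resolve to.
REGISTRY = {
    "bagraph": BAGraph, "ba": BAGraph,
    "housemotif": HouseMotif, "house": HouseMotif,
}


def resolve(query, kwargs=None):
    """Return `query` as-is if it is already an instance; otherwise look the
    string up and instantiate the matching class with `kwargs`."""
    if not isinstance(query, str):
        return query  # kwargs are ignored for pre-built instances
    key = query.lower().replace("_", "").replace("-", "")
    if key not in REGISTRY:
        raise ValueError(f"Could not resolve {query!r}")
    return REGISTRY[key](**(kwargs or {}))


graph_gen = resolve("ba", {"num_nodes": 300, "num_edges": 5})
motif_gen = resolve("house")
print(type(graph_gen).__name__, graph_gen.num_nodes, type(motif_gen).__name__)
# BAGraph 300 HouseMotif
```

In this sketch, passing an already constructed generator, such as BAGraph(num_nodes=300, num_edges=5), skips the lookup entirely, so any keyword-argument dictionary supplied alongside it has no effect.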