torch_geometric.datasets.InfectionDataset

class InfectionDataset(graph_generator: Union[GraphGenerator, str], num_infected_nodes: Union[int, List[int]], max_path_length: Union[int, List[int]], num_graphs: Optional[int] = None, graph_generator_kwargs: Optional[Dict[str, Any]] = None, transform: Optional[Callable] = None)

Bases: InMemoryDataset

Generates a synthetic infection dataset for evaluating explainability algorithms, as described in the “Explainability Techniques for Graph Convolutional Networks” paper. The InfectionDataset creates synthetic graphs from a GraphGenerator and randomly marks num_infected_nodes of their nodes as infected. The dataset describes a node classification task of predicting the length of each node's shortest path to an infected node, with corresponding ground-truth edge-level masks.
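The classification target described above can be pictured with a small pure-Python sketch. This is not the library's implementation: the helper name shortest_path_labels, the adjacency-dict graph format, and the choice to give every node farther than max_path_length the label max_path_length + 1 are all assumptions made for illustration.

```python
from collections import deque
from typing import Dict, Iterable, List


def shortest_path_labels(
    adjacency: Dict[int, List[int]],
    infected: Iterable[int],
    max_path_length: int,
) -> Dict[int, int]:
    """Label each node with its hop distance to the nearest infected node.

    Hypothetical helper. Nodes farther away than ``max_path_length`` (or
    unreachable) receive the label ``max_path_length + 1`` (an assumption
    of this sketch).
    """
    far_away = max_path_length + 1
    labels = {node: far_away for node in adjacency}

    # Multi-source breadth-first search starting from all infected nodes.
    queue = deque()
    for node in infected:
        labels[node] = 0
        queue.append(node)

    while queue:
        node = queue.popleft()
        if labels[node] == max_path_length:
            continue  # Do not expand beyond the maximum path length.
        for neighbor in adjacency[node]:
            if labels[neighbor] > labels[node] + 1:
                labels[neighbor] = labels[node] + 1
                queue.append(neighbor)

    return labels


# Undirected path graph 0 - 1 - 2 - 3 - 4 - 5 with node 0 infected.
adjacency = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}
labels = shortest_path_labels(adjacency, infected=[0], max_path_length=2)
print(labels)
```

In this toy graph, nodes 1 and 2 lie within two hops of the infected node, while nodes 3 to 5 all fall into the shared "too far" class.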

For example, to generate a random Erdos-Renyi (ER) infection graph with 500 nodes and 0.004 edge probability, write:

from torch_geometric.datasets import InfectionDataset
from torch_geometric.datasets.graph_generator import ERGraph

dataset = InfectionDataset(
    graph_generator=ERGraph(num_nodes=500, edge_prob=0.004),
    num_infected_nodes=50,
    max_path_length=3,
)

Parameters:
  • graph_generator (GraphGenerator or str) – The graph generator to be used, e.g., torch_geometric.datasets.graph_generator.BAGraph (or any string that automatically resolves to it).

  • num_infected_nodes (int or List[int]) – The number of randomly selected infected nodes in the graph. If given as a list, will select a different number of infected nodes for different graphs.

  • max_path_length (int or List[int]) – The maximum shortest path length to determine whether a node will be infected. If given as a list, will apply different shortest path lengths for different graphs.

  • num_graphs (int, optional) – The number of graphs to generate. The number of graphs will be automatically determined by len(num_infected_nodes) or len(max_path_length) in case either of them is given as a list, and should only be set in case one wants to create multiple graphs while num_infected_nodes and max_path_length are given as an integer. (default: None)

  • graph_generator_kwargs (Dict[str, Any], optional) – Arguments passed to the respective graph generator module in case it gets automatically resolved. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)
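The interplay between num_infected_nodes, max_path_length, and num_graphs described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: resolve_graph_settings is a hypothetical helper, and raising a ValueError on mismatched lengths is an assumption of this sketch.

```python
from typing import List, Optional, Tuple, Union


def resolve_graph_settings(
    num_infected_nodes: Union[int, List[int]],
    max_path_length: Union[int, List[int]],
    num_graphs: Optional[int] = None,
) -> List[Tuple[int, int]]:
    """Return one (num_infected_nodes, max_path_length) pair per graph.

    Hypothetical helper mirroring the documented behavior: the number of
    graphs is inferred from whichever argument is a list, and defaults
    to a single graph when both arguments are integers.
    """
    if num_graphs is None:
        if isinstance(num_infected_nodes, list):
            num_graphs = len(num_infected_nodes)
        elif isinstance(max_path_length, list):
            num_graphs = len(max_path_length)
        else:
            num_graphs = 1

    # Broadcast integer arguments so that every graph has its own value.
    if isinstance(num_infected_nodes, int):
        num_infected_nodes = [num_infected_nodes] * num_graphs
    if isinstance(max_path_length, int):
        max_path_length = [max_path_length] * num_graphs

    # Assumption of this sketch: inconsistent lengths are rejected.
    if len(num_infected_nodes) != num_graphs or len(max_path_length) != num_graphs:
        raise ValueError("List lengths must match the number of graphs")

    return list(zip(num_infected_nodes, max_path_length))


# Three graphs with different infection counts and a shared path length.
settings = resolve_graph_settings([10, 20, 30], max_path_length=3)
print(settings)
```

Setting num_graphs explicitly only matters in the all-integer case, e.g. resolve_graph_settings(5, 2, num_graphs=4) yields four identical settings.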