torch_geometric.datasets.InfectionDataset
- class InfectionDataset(graph_generator: Union[GraphGenerator, str], num_infected_nodes: Union[int, List[int]], max_path_length: Union[int, List[int]], num_graphs: Optional[int] = None, graph_generator_kwargs: Optional[Dict[str, Any]] = None, transform: Optional[Callable] = None)
Bases: InMemoryDataset
Generates a synthetic infection dataset for evaluating explainability algorithms, as described in the “Explainability Techniques for Graph Convolutional Networks” paper. The InfectionDataset creates synthetic graphs coming from a GraphGenerator with num_infected_nodes
randomly assigned infected nodes. The dataset describes a node classification task of predicting the length of the shortest path to infected nodes, with corresponding ground-truth edge-level masks.

For example, to generate a random Erdos-Renyi (ER) infection graph with 500 nodes and 0.004 edge probability, write:

    from torch_geometric.datasets import InfectionDataset
    from torch_geometric.datasets.graph_generator import ERGraph

    dataset = InfectionDataset(
        graph_generator=ERGraph(num_nodes=500, edge_prob=0.004),
        num_infected_nodes=50,
        max_path_length=3,
    )
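The labeling task described above can be illustrated without PyG. The sketch below is a simplified, hypothetical reimplementation, not the library's code: it runs a multi-source breadth-first search from the infected nodes over an undirected edge list and gives infected nodes label 0. As an assumption, every node farther than max_path_length hops (or unreachable) is put into one extra class, max_path_length + 1. The helper name infection_labels is illustrative only, and the ground-truth edge masks are not modeled.

```python
from collections import deque
from typing import Deque, List, Sequence, Tuple


def infection_labels(
    num_nodes: int,
    edges: Sequence[Tuple[int, int]],
    infected: Sequence[int],
    max_path_length: int,
) -> List[int]:
    """Hypothetical helper: label each node by its hop distance to the
    nearest infected node, capped at max_path_length + 1 (assumption)."""
    # Build an undirected adjacency list from the edge list.
    adj: List[List[int]] = [[] for _ in range(num_nodes)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Every node starts in the "too far or unreachable" class.
    far = max_path_length + 1
    labels = [far] * num_nodes

    # Multi-source BFS: all infected nodes start at distance 0.
    queue: Deque[int] = deque()
    for node in infected:
        if labels[node] != 0:  # Skip duplicate entries.
            labels[node] = 0
            queue.append(node)

    while queue:
        node = queue.popleft()
        dist = labels[node]
        if dist == max_path_length:
            continue  # Do not expand beyond the maximum path length.
        for nbr in adj[node]:
            if labels[nbr] == far:  # Not yet reached within the limit.
                labels[nbr] = dist + 1
                queue.append(nbr)
    return labels


# A path graph 0-1-2-3-4-5 with node 0 infected and max_path_length=3.
path_edges = [(i, i + 1) for i in range(5)]
print(infection_labels(6, path_edges, infected=[0], max_path_length=3))
# → [0, 1, 2, 3, 4, 4]
```

Because BFS visits nodes in order of increasing hop count, the first label assigned to a node is its shortest distance to any infected node, which is why a single pass suffices.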
- Parameters:
graph_generator (GraphGenerator or str) – The graph generator to be used, e.g., torch_geometric.datasets.graph_generator.BAGraph (or any string that automatically resolves to it).
num_infected_nodes (int or List[int]) – The number of randomly selected infected nodes in the graph. If given as a list, a different number of infected nodes is selected for each graph.
max_path_length (int or List[int]) – The maximum shortest path length used to determine whether a node will be infected. If given as a list, a different maximum shortest path length is applied to each graph. (default: 5)
num_graphs (int, optional) – The number of graphs to generate. If either num_infected_nodes or max_path_length is given as a list, the number of graphs is determined automatically by len(num_infected_nodes) or len(max_path_length). Set this argument only to create multiple graphs while both num_infected_nodes and max_path_length are given as integers. (default: None)
graph_generator_kwargs (Dict[str, Any], optional) – Arguments passed to the respective graph generator module in case it is automatically resolved from a string. (default: None)
transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)
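The interplay of num_infected_nodes, max_path_length, and num_graphs can be summarized as a small function. This is a hedged sketch of the documented rules, not PyG's source: the helper name resolve_graph_settings is hypothetical, and raising ValueError when list lengths disagree with the number of graphs is an assumption about how invalid input is handled.

```python
from typing import List, Optional, Tuple, Union


def resolve_graph_settings(
    num_infected_nodes: Union[int, List[int]],
    max_path_length: Union[int, List[int]],
    num_graphs: Optional[int] = None,
) -> List[Tuple[int, int]]:
    """Hypothetical helper: return one (num_infected_nodes, max_path_length)
    pair per graph to generate."""
    if num_graphs is None:
        # A list argument determines the number of graphs automatically.
        if isinstance(num_infected_nodes, list):
            num_graphs = len(num_infected_nodes)
        elif isinstance(max_path_length, list):
            num_graphs = len(max_path_length)
        else:
            num_graphs = 1  # Both are integers: a single graph.

    # Broadcast integer arguments to one value per graph.
    if isinstance(num_infected_nodes, int):
        num_infected_nodes = [num_infected_nodes] * num_graphs
    if isinstance(max_path_length, int):
        max_path_length = [max_path_length] * num_graphs

    # Assumption: mismatched lengths are rejected rather than truncated.
    if len(num_infected_nodes) != num_graphs or len(max_path_length) != num_graphs:
        raise ValueError("list arguments must all have length num_graphs")

    return list(zip(num_infected_nodes, max_path_length))


print(resolve_graph_settings(50, 3, num_graphs=2))  # → [(50, 3), (50, 3)]
print(resolve_graph_settings([10, 20, 30], 3))      # → [(10, 3), (20, 3), (30, 3)]
```

Broadcasting integers up to the list length keeps the per-graph loop uniform: each generated graph simply consumes one pair.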