torch_geometric.loader

DataLoader

A data loader which merges data objects from a torch_geometric.data.Dataset to a mini-batch.

NeighborLoader

A data loader that performs neighbor sampling as introduced in the “Inductive Representation Learning on Large Graphs” paper.

HGTLoader

The Heterogeneous Graph Sampler from the “Heterogeneous Graph Transformer” paper.

ClusterData

Clusters/partitions a graph data object into multiple subgraphs, as motivated by the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper.

ClusterLoader

The data loader scheme from the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper, which merges partitioned subgraphs and their between-cluster links from a large-scale graph data object to form a mini-batch.

GraphSAINTSampler

The GraphSAINT sampler base class from the “GraphSAINT: Graph Sampling Based Inductive Learning Method” paper.

GraphSAINTNodeSampler

The GraphSAINT node sampler class (see GraphSAINTSampler).

GraphSAINTEdgeSampler

The GraphSAINT edge sampler class (see GraphSAINTSampler).

GraphSAINTRandomWalkSampler

The GraphSAINT random walk sampler class (see GraphSAINTSampler).

ShaDowKHopSampler

The ShaDow \(k\)-hop sampler from the “Deep Graph Neural Networks with Shallow Subgraph Samplers” paper.

RandomNodeSampler

A data loader that randomly samples nodes within a graph and returns their induced subgraph.

DataListLoader

A data loader which batches data objects from a torch_geometric.data.Dataset to a Python list.

DenseDataLoader

A data loader which batches data objects from a torch_geometric.data.Dataset to a torch_geometric.data.Batch object by stacking all attributes in a new dimension.

NeighborSampler

The neighbor sampler from the “Inductive Representation Learning on Large Graphs” paper, which allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

class DataLoader(dataset: Union[torch_geometric.data.dataset.Dataset, List[torch_geometric.data.data.Data], List[torch_geometric.data.hetero_data.HeteroData]], batch_size: int = 1, shuffle: bool = False, follow_batch: List[str] = [], exclude_keys: List[str] = [], **kwargs)[source]

A data loader which merges data objects from a torch_geometric.data.Dataset to a mini-batch. Data objects can be either of type Data or HeteroData.

Parameters
  • dataset (Dataset) – The dataset from which to load the data.

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • shuffle (bool, optional) – If set to True, the data will be reshuffled at every epoch. (default: False)

  • follow_batch (List[str], optional) – Creates assignment batch vectors for each key in the list. (default: [])

  • exclude_keys (List[str], optional) – Will exclude each key in the list. (default: [])

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader.
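
The merging step can be pictured with a small plain-Python sketch. It assumes the usual disjoint-union batching: node features are concatenated, edge indices are shifted by a running node offset, and a batch vector records which graph each node came from. The dictionary layout and the collate_graphs helper are hypothetical stand-ins for illustration, not the library's implementation.

```python
from typing import Dict, List

Graph = Dict[str, list]  # hypothetical stand-in for a Data object


def collate_graphs(graphs: List[Graph]) -> Graph:
    """Merge small graphs into one disconnected graph (illustrative only)."""
    x, edge_index, batch = [], [[], []], []
    offset = 0
    for graph_id, g in enumerate(graphs):
        num_nodes = len(g["x"])
        x.extend(g["x"])
        # Shift node indices so they point into the concatenated node list.
        for src, dst in zip(*g["edge_index"]):
            edge_index[0].append(src + offset)
            edge_index[1].append(dst + offset)
        # The assignment vector maps every node to the graph it came from.
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return {"x": x, "edge_index": edge_index, "batch": batch}


g1 = {"x": [[1.0], [2.0]], "edge_index": [[0, 1], [1, 0]]}
g2 = {"x": [[3.0], [4.0], [5.0]], "edge_index": [[0, 1], [1, 2]]}
merged = collate_graphs([g1, g2])
print(merged["edge_index"])  # [[0, 1, 2, 3], [1, 0, 3, 4]]
print(merged["batch"])       # [0, 0, 1, 1, 1]
```

The batch list here plays the role of an assignment batch vector; follow_batch requests the same kind of vector for additional keys.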

class NeighborLoader(data: Union[torch_geometric.data.data.Data, torch_geometric.data.hetero_data.HeteroData], num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], input_nodes: Union[torch.Tensor, None, str, Tuple[str, Optional[torch.Tensor]]] = None, replace: bool = False, directed: bool = True, transform: Optional[Callable] = None, **kwargs)[source]

A data loader that performs neighbor sampling as introduced in the “Inductive Representation Learning on Large Graphs” paper. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

More specifically, num_neighbors denotes how many neighbors are sampled for each node in each iteration. NeighborLoader takes in this list of num_neighbors and iteratively samples num_neighbors[i] neighbors for each node involved in iteration i - 1.

Sampled nodes are sorted based on the order in which they were sampled. In particular, the first batch_size nodes represent the set of original mini-batch nodes.

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader

data = Planetoid(path, name='Cora')[0]

loader = NeighborLoader(
    data,
    # Sample 30 neighbors for each node for 2 iterations
    num_neighbors=[30] * 2,
    # Use a batch size of 128 for sampling training nodes
    batch_size=128,
    input_nodes=data.train_mask,
)

sampled_data = next(iter(loader))
print(sampled_data.batch_size)
>>> 128
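
The iterative sampling and the node ordering described above can be sketched without the library. This is a simplified illustration under stated assumptions: the graph is a hypothetical adjacency dictionary mapping each node to the neighbors it receives messages from, sampling is without replacement, and only newly discovered nodes are expanded in the next iteration. The sample_neighbors helper is not part of torch_geometric.

```python
import random
from typing import Dict, List


def sample_neighbors(adj: Dict[int, List[int]], seeds: List[int],
                     num_neighbors: List[int], rng: random.Random) -> List[int]:
    """Return sampled node IDs, seed nodes first, in order of discovery."""
    order = list(seeds)           # seed nodes occupy the first positions
    seen = set(seeds)
    frontier = list(seeds)
    for fanout in num_neighbors:  # one entry per iteration (hop)
        next_frontier = []
        for node in frontier:
            neighbors = adj.get(node, [])
            # Sample without replacement, at most `fanout` neighbors.
            picked = rng.sample(neighbors, min(fanout, len(neighbors)))
            for nbr in picked:
                if nbr not in seen:
                    seen.add(nbr)
                    order.append(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return order


adj = {0: [1, 2, 3], 1: [4], 2: [4, 5], 3: [], 4: [0], 5: []}
nodes = sample_neighbors(adj, seeds=[0], num_neighbors=[2, 2],
                         rng=random.Random(0))
print(nodes)  # the seed node 0 always comes first
```

Because the seeds are placed first, slicing the first len(seeds) entries recovers the original mini-batch nodes, which mirrors the role of batch_size in the sampled output.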

By default, the data loader will only include the edges that were originally sampled (directed = True). This option should only be used in case the number of hops is equivalent to the number of GNN layers. In case the number of GNN layers is greater than the number of hops, consider setting directed = False, which will include all edges between all sampled nodes (but is slightly slower as a result).
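
The difference between the two settings can be illustrated with plain edge lists. The sketch below is an assumption-laden illustration rather than the loader's code: sampled_edges stands for the edges followed during sampling, and the hypothetical induced_edges helper returns every edge of the full graph whose endpoints were both sampled.

```python
from typing import List, Set, Tuple

Edge = Tuple[int, int]  # (source, target)


def induced_edges(all_edges: List[Edge], sampled_nodes: Set[int]) -> List[Edge]:
    """Edges of the full graph whose endpoints were both sampled."""
    return [(s, t) for s, t in all_edges
            if s in sampled_nodes and t in sampled_nodes]


all_edges = [(1, 0), (2, 0), (3, 1), (2, 1), (1, 2)]
# Suppose sampling from seed node 0 followed these edges over two hops.
sampled_edges = [(1, 0), (2, 0), (3, 1)]
sampled_nodes = {n for edge in sampled_edges for n in edge}

print(sampled_edges)                            # analogous to directed=True
print(induced_edges(all_edges, sampled_nodes))  # analogous to directed=False
```

The induced variant additionally keeps (2, 1) and (1, 2), which were never traversed during sampling but connect two sampled nodes.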

Furthermore, NeighborLoader works for both homogeneous graphs stored via Data as well as heterogeneous graphs stored via HeteroData. When operating on heterogeneous graphs, more fine-grained control over the number of sampled neighbors of individual edge types is possible, but not necessary:

from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader

hetero_data = OGB_MAG(path)[0]

loader = NeighborLoader(
    hetero_data,
    # Sample 30 neighbors for each node and edge type for 2 iterations
    num_neighbors={key: [30] * 2 for key in hetero_data.edge_types},
    # Use a batch size of 128 for sampling training nodes of type paper
    batch_size=128,
    input_nodes=('paper', hetero_data['paper'].train_mask),
)

sampled_hetero_data = next(iter(loader))
print(sampled_hetero_data['paper'].batch_size)
>>> 128

Note

For an example of using NeighborLoader, see examples/hetero/to_hetero_mag.py.

Parameters
  • data (torch_geometric.data.Data or torch_geometric.data.HeteroData) – The Data or HeteroData graph object.

  • num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]) – The number of neighbors to sample for each node in each iteration. In heterogeneous graphs, may also take in a dictionary denoting the number of neighbors to sample for each individual edge type.

  • input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]) – The indices of nodes for which neighbors are sampled to create mini-batches. Needs to be either given as a torch.LongTensor or torch.BoolTensor. If set to None, all nodes will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node indices. (default: None)

  • replace (bool, optional) – If set to True, will sample with replacement. (default: False)

  • directed (bool, optional) – If set to False, will include all edges between all sampled nodes. (default: True)

  • transform (Callable, optional) – A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: None)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.

class HGTLoader(data: torch_geometric.data.hetero_data.HeteroData, num_samples: Union[List[int], Dict[str, List[int]]], input_nodes: Union[str, Tuple[str, Optional[torch.Tensor]]], transform: Optional[Callable] = None, **kwargs)[source]

The Heterogeneous Graph Sampler from the “Heterogeneous Graph Transformer” paper. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

HGTLoader tries to (1) keep a similar number of nodes and edges for each type and (2) keep the sampled sub-graph dense to minimize the information loss and reduce the sample variance.

Methodically, HGTLoader keeps track of a node budget for each node type, which is then used to determine the sampling probability of a node. In particular, the probability of sampling a node is determined by the number of connections to already sampled nodes and their node degrees. With this, HGTLoader samples a fixed number of nodes for each node type in each iteration, as given by the num_samples argument.

Sampled nodes are sorted based on the order in which they were sampled. In particular, the first batch_size nodes represent the set of original mini-batch nodes.
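
The budget idea can be sketched in a few lines of plain Python. The sketch assumes the scheme described in the HGT paper: each neighbor of a sampled node receives that node's normalized degree as budget, and the sampling probability within a node type is proportional to the squared budget. The helper names and the toy graph are hypothetical, and the removal of already sampled nodes from the budget is omitted for brevity.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Node = Tuple[str, int]  # (node type, node ID)


def update_budget(budget, in_neighbors: Dict[Node, List[Node]],
                  sampled: Node) -> None:
    """Add the neighbors of a newly sampled node to the per-type budget."""
    neighbors = in_neighbors.get(sampled, [])
    for node_type, node_id in neighbors:
        # Each neighbor receives the normalized degree of the sampled node.
        budget[node_type][node_id] += 1.0 / len(neighbors)


def sampling_probabilities(type_budget: Dict[int, float]) -> Dict[int, float]:
    """Probability proportional to the squared budget value."""
    total = sum(value ** 2 for value in type_budget.values())
    return {node: value ** 2 / total for node, value in type_budget.items()}


in_neighbors = {
    ("paper", 0): [("author", 0), ("author", 1)],
    ("paper", 1): [("author", 1)],
}
budget = defaultdict(lambda: defaultdict(float))
for seed in [("paper", 0), ("paper", 1)]:
    update_budget(budget, in_neighbors, seed)

probs = sampling_probabilities(budget["author"])
print(probs)  # author 1 is connected to both seeds and is favored
```

Author 1 accumulates a budget of 1.5 against 0.5 for author 0, so squaring pushes most of the probability mass toward the node with more connections to already sampled nodes.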

Note

For an example of using HGTLoader, see examples/hetero/to_hetero_mag.py.

from torch_geometric.loader import HGTLoader
from torch_geometric.datasets import OGB_MAG

hetero_data = OGB_MAG(path)[0]

loader = HGTLoader(
    hetero_data,
    # Sample 512 nodes per type and per iteration for 4 iterations
    num_samples={key: [512] * 4 for key in hetero_data.node_types},
    # Use a batch size of 128 for sampling training nodes of type paper
    batch_size=128,
    input_nodes=('paper', hetero_data['paper'].train_mask),
)

sampled_hetero_data = next(iter(loader))
print(sampled_hetero_data['paper'].batch_size)
>>> 128

Parameters
  • data (torch_geometric.data.HeteroData) – The HeteroData graph data object.

  • num_samples (List[int] or Dict[str, List[int]]) – The number of nodes to sample in each iteration and for each node type. If given as a list, will sample the same amount of nodes for each node type.

  • input_nodes (str or Tuple[str, torch.Tensor]) – The indices of nodes for which neighbors are sampled to create mini-batches. Needs to be passed as a tuple that holds the node type and corresponding node indices. Node indices need to be either given as a torch.LongTensor or torch.BoolTensor. If node indices are set to None, all nodes of this specific type will be considered.

  • transform (Callable, optional) – A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: None)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.

class ClusterData(data, num_parts: int, recursive: bool = False, save_dir: Optional[str] = None, log: bool = True)[source]

Clusters/partitions a graph data object into multiple subgraphs, as motivated by the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper.

Parameters
  • data (torch_geometric.data.Data) – The graph data object.

  • num_parts (int) – The number of partitions.

  • recursive (bool, optional) – If set to True, will use multilevel recursive bisection instead of multilevel k-way partitioning. (default: False)

  • save_dir (string, optional) – If set, will save the partitioned data to the save_dir directory for faster re-use. (default: None)

  • log (bool, optional) – If set to False, will not log any progress. (default: True)

class ClusterLoader(cluster_data, **kwargs)[source]

The data loader scheme from the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper, which merges partitioned subgraphs and their between-cluster links from a large-scale graph data object to form a mini-batch.

Note

Use ClusterData and ClusterLoader in conjunction to form mini-batches of clusters. For an example of using Cluster-GCN, see examples/cluster_gcn_reddit.py or examples/cluster_gcn_ppi.py.

Parameters
  • cluster_data (torch_geometric.loader.ClusterData) – The already partitioned data object.

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.
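
What merging clusters "and their between-cluster links" means can be shown on a toy graph. The sketch below is illustrative only: the partition dictionary stands in for the output of a partitioner, and merge_clusters is a hypothetical helper that keeps every edge whose endpoints both fall into the selected clusters.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def merge_clusters(partition: Dict[int, int], edges: List[Edge],
                   cluster_ids: List[int]) -> Tuple[List[int], List[Edge]]:
    """Merge the chosen clusters and keep every edge among their nodes."""
    chosen = set(cluster_ids)
    nodes = sorted(n for n, part in partition.items() if part in chosen)
    node_set = set(nodes)
    # Includes within-cluster edges and links between the chosen clusters.
    kept = [(s, t) for s, t in edges if s in node_set and t in node_set]
    return nodes, kept


partition = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}  # node -> cluster (hypothetical)
edges = [(0, 1), (2, 3), (4, 5), (1, 2), (3, 4)]
nodes, kept = merge_clusters(partition, edges, cluster_ids=[0, 1])
print(nodes)  # [0, 1, 2, 3]
print(kept)   # [(0, 1), (2, 3), (1, 2)]
```

The edge (1, 2) crosses from cluster 0 to cluster 1 and is retained because both clusters are in the mini-batch, whereas (3, 4) is dropped because cluster 2 was not selected.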

class GraphSAINTSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)[source]

The GraphSAINT sampler base class from the “GraphSAINT: Graph Sampling Based Inductive Learning Method” paper. Given a graph in a data object, this class samples nodes and constructs subgraphs that can be processed in a mini-batch fashion. Normalization coefficients for each mini-batch are given via node_norm and edge_norm data attributes.

Note

See GraphSAINTNodeSampler, GraphSAINTEdgeSampler and GraphSAINTRandomWalkSampler for currently supported samplers. For an example of using GraphSAINT sampling, see examples/graph_saint.py.

Parameters
  • data (torch_geometric.data.Data) – The graph data object.

  • batch_size (int) – The approximate number of samples per batch.

  • num_steps (int, optional) – The number of iterations per epoch. (default: 1)

  • sample_coverage (int, optional) – How many samples per node should be used to compute normalization statistics. (default: 0)

  • save_dir (string, optional) – If set, will save normalization statistics to the save_dir directory for faster re-use. (default: None)

  • log (bool, optional) – If set to False, will not log any pre-processing progress. (default: True)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size or num_workers.

class GraphSAINTNodeSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)[source]

The GraphSAINT node sampler class (see GraphSAINTSampler).

class GraphSAINTEdgeSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)[source]

The GraphSAINT edge sampler class (see GraphSAINTSampler).

class GraphSAINTRandomWalkSampler(data, batch_size: int, walk_length: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)[source]

The GraphSAINT random walk sampler class (see GraphSAINTSampler).

Parameters
  • walk_length (int) – The length of each random walk.
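
A random-walk-based node sample can be sketched as follows. This is a minimal illustration under assumptions: roots are drawn uniformly at random, each walk takes walk_length steps, and the union of visited nodes defines the subgraph. The num_roots parameter and the random_walk_nodes helper are hypothetical; treating the batch size as the number of walk roots is an assumption.

```python
import random
from typing import Dict, List, Set


def random_walk_nodes(adj: Dict[int, List[int]], num_roots: int,
                      walk_length: int, rng: random.Random) -> Set[int]:
    """Collect all nodes visited by `num_roots` random walks."""
    visited: Set[int] = set()
    all_nodes = list(adj)
    for _root in range(num_roots):
        node = rng.choice(all_nodes)  # root chosen uniformly at random
        visited.add(node)
        for _step in range(walk_length):
            if not adj[node]:         # dead end: stop this walk early
                break
            node = rng.choice(adj[node])
            visited.add(node)
    return visited


# A directed ring with four nodes.
adj = {0: [1], 1: [2], 2: [3], 3: [0]}
visited = random_walk_nodes(adj, num_roots=2, walk_length=2,
                            rng=random.Random(0))
print(sorted(visited))
```

Each walk visits at most walk_length + 1 nodes, so the sampled node set is bounded by num_roots * (walk_length + 1).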

class ShaDowKHopSampler(data: torch_geometric.data.data.Data, depth: int, num_neighbors: int, node_idx: Optional[torch.Tensor] = None, replace: bool = False, **kwargs)[source]

The ShaDow \(k\)-hop sampler from the “Deep Graph Neural Networks with Shallow Subgraph Samplers” paper. Given a graph in a data object, the sampler will create shallow, localized subgraphs. A deep GNN on this local graph then smooths the informative local signals.

Parameters
  • data (torch_geometric.data.Data) – The graph data object.

  • depth (int) – The depth/number of hops of the localized subgraph.

  • num_neighbors (int) – The number of neighbors to sample for each node in each hop.

  • node_idx (LongTensor or BoolTensor, optional) – The nodes that should be considered for creating mini-batches. If set to None, all nodes will be considered.

  • replace (bool, optional) – If set to True, will sample neighbors with replacement. (default: False)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size or num_workers.

class RandomNodeSampler(data, num_parts: int, shuffle: bool = False, **kwargs)[source]

A data loader that randomly samples nodes within a graph and returns their induced subgraph.

Note

For an example of using RandomNodeSampler, see examples/ogbn_proteins_deepgcn.py.

Parameters
  • data (torch_geometric.data.Data) – The graph data object.

  • num_parts (int) – The number of partitions.

  • shuffle (bool, optional) – If set to True, the data will be reshuffled at every epoch. (default: False)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as num_workers.
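
Random node sampling with induced subgraphs can be sketched in plain Python. The sketch assumes that every node is assigned independently and uniformly to one of num_parts parts, and that each part keeps only the edges with both endpoints inside it. The random_node_partitions helper is hypothetical and not the library's implementation.

```python
import random
from typing import List, Tuple

Edge = Tuple[int, int]


def random_node_partitions(num_nodes: int, edges: List[Edge], num_parts: int,
                           rng: random.Random
                           ) -> List[Tuple[List[int], List[Edge]]]:
    """Assign each node to a random part and return each induced subgraph."""
    assignment = [rng.randrange(num_parts) for _ in range(num_nodes)]
    subgraphs = []
    for part in range(num_parts):
        nodes = [n for n in range(num_nodes) if assignment[n] == part]
        node_set = set(nodes)
        # Keep only edges whose endpoints both fall into this part.
        induced = [(s, t) for s, t in edges
                   if s in node_set and t in node_set]
        subgraphs.append((nodes, induced))
    return subgraphs


edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
subgraphs = random_node_partitions(6, edges, num_parts=2,
                                   rng=random.Random(0))
for nodes, induced in subgraphs:
    print(nodes, induced)
```

Edges that cross between two parts are dropped in this scheme, so a larger num_parts yields smaller but sparser subgraphs.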

class DataListLoader(dataset: Union[torch_geometric.data.dataset.Dataset, List[torch_geometric.data.data.Data], List[torch_geometric.data.hetero_data.HeteroData]], batch_size: int = 1, shuffle: bool = False, **kwargs)[source]

A data loader which batches data objects from a torch_geometric.data.Dataset to a Python list. Data objects can be either of type Data or HeteroData.

Note

This data loader should be used for multi-GPU support via torch_geometric.nn.DataParallel.

Parameters
  • dataset (Dataset) – The dataset from which to load the data.

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • shuffle (bool, optional) – If set to True, the data will be reshuffled at every epoch. (default: False)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as drop_last or num_workers.

class DenseDataLoader(dataset: Union[torch_geometric.data.dataset.Dataset, List[torch_geometric.data.data.Data]], batch_size: int = 1, shuffle: bool = False, **kwargs)[source]

A data loader which batches data objects from a torch_geometric.data.Dataset to a torch_geometric.data.Batch object by stacking all attributes in a new dimension.

Note

To make use of this data loader, all graph attributes in the dataset need to have the same shape. In particular, this data loader should only be used when working with dense adjacency matrices.

Parameters
  • dataset (Dataset) – The dataset from which to load the data.

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • shuffle (bool, optional) – If set to True, the data will be reshuffled at every epoch. (default: False)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as drop_last or num_workers.
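
Stacking in a new dimension, and why equal shapes are required, can be illustrated with nested lists. The sketch below is a simplified stand-in for tensor stacking: stack_dense and shape are hypothetical helpers, and each graph is represented as a dictionary holding a feature matrix x and a dense adjacency matrix adj.

```python
from typing import Dict, List


def shape(value) -> tuple:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0] if value else None
    return tuple(dims)


def stack_dense(graphs: List[Dict[str, list]]) -> Dict[str, list]:
    """Stack equally shaped attributes along a new leading dimension."""
    stacked = {}
    for key in graphs[0]:
        shapes = {shape(g[key]) for g in graphs}
        if len(shapes) != 1:
            raise ValueError(f"attribute {key!r} has mismatching shapes")
        stacked[key] = [g[key] for g in graphs]
    return stacked


g1 = {"x": [[1.0], [2.0]], "adj": [[0, 1], [1, 0]]}
g2 = {"x": [[3.0], [4.0]], "adj": [[0, 0], [0, 0]]}
batch = stack_dense([g1, g2])
print(shape(batch["x"]))    # (2, 2, 1)
print(shape(batch["adj"]))  # (2, 2, 2)
```

The leading dimension indexes the graphs in the batch; a graph with a different number of nodes would raise the ValueError, which is the constraint stated in the note above.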

class NeighborSampler(edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], sizes: List[int], node_idx: Optional[torch.Tensor] = None, num_nodes: Optional[int] = None, return_e_id: bool = True, transform: Optional[Callable] = None, **kwargs)[source]

The neighbor sampler from the “Inductive Representation Learning on Large Graphs” paper, which allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

Given a GNN with \(L\) layers and a specific mini-batch of nodes node_idx for which we want to compute embeddings, this module iteratively samples neighbors and constructs bipartite graphs that simulate the actual computation flow of GNNs.

More specifically, sizes denotes how many neighbors we want to sample for each node in each layer. This module then takes in these sizes and iteratively samples sizes[l] neighbors for each node involved in layer l. In the next layer, sampling is repeated for the union of nodes that were already encountered. The actual computation graphs are then returned in reverse-mode, meaning that we pass messages from a larger set of nodes to a smaller one, until we reach the nodes for which we originally wanted to compute embeddings.

Hence, an item returned by NeighborSampler holds the current batch_size, the IDs n_id of all nodes involved in the computation, and a list of bipartite graph objects via the tuple (edge_index, e_id, size), where edge_index represents the bipartite edges between source and target nodes, e_id denotes the IDs of original edges in the full graph, and size holds the shape of the bipartite graph. For each bipartite graph, target nodes are also included at the beginning of the list of source nodes so that one can easily apply skip-connections or add self-loops.
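
The structure of a returned item can be sketched in plain Python. This is an illustration under assumptions, not the module's implementation: the adjacency dictionary maps each target node to its candidate source nodes, e_id is omitted for brevity, and sample_bipartite is a hypothetical helper. It shows the two properties described above: target nodes come first in the source node list, and the bipartite graphs are returned outermost hop first.

```python
import random
from typing import Dict, List, Tuple

Adj = Tuple[List[Tuple[int, int]], Tuple[int, int]]  # (local edges, size)


def sample_bipartite(adj: Dict[int, List[int]], batch: List[int],
                     sizes: List[int], rng: random.Random):
    """Return (batch_size, n_id, adjs) with adjs ordered outermost hop first."""
    n_id = list(batch)
    adjs: List[Adj] = []
    for size in sizes:
        targets = list(n_id)
        # Targets keep their positions, so they lead the source node list.
        local = {node: i for i, node in enumerate(n_id)}
        edges = []
        for t in targets:
            neighbors = adj.get(t, [])
            k = len(neighbors) if size == -1 else min(size, len(neighbors))
            for s in rng.sample(neighbors, k):
                if s not in local:
                    local[s] = len(n_id)
                    n_id.append(s)
                edges.append((local[s], local[t]))  # source -> target
        # size holds (number of source nodes, number of target nodes).
        adjs.append((edges, (len(n_id), len(targets))))
    return len(batch), n_id, adjs[::-1]


adj = {0: [1, 2], 1: [3], 2: [3, 4], 3: [], 4: [0]}
batch_size, n_id, adjs = sample_bipartite(adj, batch=[0], sizes=[-1, -1],
                                          rng=random.Random(0))
print(batch_size)              # 1
print([a[1] for a in adjs])    # [(5, 3), (3, 1)]
```

Reading the sizes from left to right, messages flow from five nodes to three and then from three nodes to the single node for which the embedding is computed.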

Warning

NeighborSampler is deprecated and will be removed in a future release. Use torch_geometric.loader.NeighborLoader instead.

Note

For an example of using NeighborSampler, see examples/reddit.py or examples/ogbn_products_sage.py.

Parameters
  • edge_index (Tensor or SparseTensor) – A torch.LongTensor or a torch_sparse.SparseTensor that defines the underlying graph connectivity/message passing flow. edge_index holds the indices of a (sparse) symmetric adjacency matrix. If edge_index is of type torch.LongTensor, its shape must be defined as [2, num_edges], where messages from nodes edge_index[0] are sent to nodes in edge_index[1] (in case flow="source_to_target"). If edge_index is of type torch_sparse.SparseTensor, its sparse indices (row, col) should relate to row = edge_index[1] and col = edge_index[0]. The major difference between both formats is that we need to input the transposed sparse adjacency matrix.

  • sizes (List[int]) – The number of neighbors to sample for each node in each layer. If set to sizes[l] = -1, all neighbors are included in layer l.

  • node_idx (LongTensor, optional) – The nodes that should be considered for creating mini-batches. If set to None, all nodes will be considered.

  • num_nodes (int, optional) – The number of nodes in the graph. (default: None)

  • return_e_id (bool, optional) – If set to False, will not return the original edge indices of sampled edges. This is only useful for saving memory when operating on graphs without edge features. (default: True)

  • transform (callable, optional) – A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: None)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.