torch_geometric.data

class Batch(batch=None, ptr=None, **kwargs)

A plain old Python object modeling a batch of graphs as one big (disconnected) graph. With torch_geometric.data.Data being the base class, all its methods can also be used here. In addition, single graphs can be reconstructed via the assignment vector batch, which maps each node to its respective graph identifier.

__getitem__(idx)

If idx is a string, gets the data of the attribute idx.

classmethod from_data_list(data_list, follow_batch=[], exclude_keys=[])

Constructs a batch object from a Python list holding torch_geometric.data.Data objects. The assignment vector batch is created on the fly. Additionally, creates assignment batch vectors for each key in follow_batch and excludes any keys given in exclude_keys.
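The merge that from_data_list performs can be pictured with the dependency-free sketch below. It assumes each graph is a dict with node features under "x" and (source, target) pairs under "edges" as a stand-in for edge_index. These names, and the use of plain lists instead of tensors, are illustrative only and are not PyG's internals. Each graph's node indices are shifted by the number of nodes merged before it, and the batch list records which graph each node came from. The ptr list is our assumption of what the ptr argument of Batch holds, namely cumulative node counts.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def collate_graphs(graphs: List[Dict]) -> Dict:
    """Merge graphs into one disconnected graph (plain-list sketch)."""
    x: List = []
    edges: List[Edge] = []
    batch: List[int] = []
    ptr: List[int] = [0]  # graph i owns nodes ptr[i] .. ptr[i + 1] - 1
    for graph_id, g in enumerate(graphs):
        offset = ptr[-1]  # number of nodes merged so far
        x.extend(g["x"])
        # Shift node indices so they point into the merged node list.
        edges.extend((s + offset, t + offset) for s, t in g["edges"])
        # Assignment vector: every node of this graph maps to graph_id.
        batch.extend([graph_id] * len(g["x"]))
        ptr.append(offset + len(g["x"]))
    return {"x": x, "edges": edges, "batch": batch, "ptr": ptr}


g1 = {"x": [[1.0], [2.0]], "edges": [(0, 1), (1, 0)]}
g2 = {"x": [[3.0], [4.0], [5.0]], "edges": [(0, 2)]}
merged = collate_graphs([g1, g2])
print(merged["edges"])  # [(0, 1), (1, 0), (2, 4)]
print(merged["batch"])  # [0, 0, 1, 1, 1]
```

The edge (0, 2) of the second graph becomes (2, 4) because two nodes were merged before it. This shifting is what keeps the merged graph disconnected.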

get_example(idx: int) → torch_geometric.data.data.Data

Reconstructs the torch_geometric.data.Data object at index idx from the batch object. The batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

property num_graphs

Returns the number of graphs in the batch.

to_data_list() → List[torch_geometric.data.data.Data]

Reconstructs the list of torch_geometric.data.Data objects from the batch object. The batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.
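Reconstruction only works if the merge step recorded where each graph starts and ends. That is why the batch object must come from from_data_list(). The sketch below is a plain-Python illustration, not the library's code. It assumes the merged batch is described by a node list, a list of (source, target) pairs, and a ptr list of cumulative node counts. It cuts each graph back out and shifts its indices to start at 0 again.

```python
from typing import Dict, List, Tuple


def split_batch(x: List, edges: List[Tuple[int, int]], ptr: List[int]) -> List[Dict]:
    """Undo the merge: recover each graph from the node boundaries in ptr."""
    graphs = []
    for i in range(len(ptr) - 1):
        start, end = ptr[i], ptr[i + 1]
        # Graphs are disconnected, so an edge belongs to the graph that owns
        # its source node; shift indices back so they start at 0.
        local = [(s - start, t - start) for s, t in edges if start <= s < end]
        graphs.append({"x": x[start:end], "edges": local})
    return graphs


# A batch of two graphs with 2 and 3 nodes.
x = [[1.0], [2.0], [3.0], [4.0], [5.0]]
edges = [(0, 1), (1, 0), (2, 4)]
ptr = [0, 2, 5]
parts = split_batch(x, edges, ptr)
for g in parts:
    print(g)
```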

class ClusterData(data, num_parts: int, recursive: bool = False, save_dir: Optional[str] = None, log: bool = True)

Clusters/partitions a graph data object into multiple subgraphs, as motivated by the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper.

Parameters
  • data (torch_geometric.data.Data) – The graph data object.

  • num_parts (int) – The number of partitions.

  • recursive (bool, optional) – If set to True, will use multilevel recursive bisection instead of multilevel k-way partitioning. (default: False)

  • save_dir (string, optional) – If set, will save the partitioned data to the save_dir directory for faster re-use. (default: None)

  • log (bool, optional) – If set to False, will not log any progress. (default: True)

class ClusterLoader(cluster_data, **kwargs)

The data loader scheme from the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper, which merges partitioned subgraphs and their between-cluster links from a large-scale graph data object to form a mini-batch.

Note

Use torch_geometric.data.ClusterData and torch_geometric.data.ClusterLoader in conjunction to form mini-batches of clusters. For an example of using Cluster-GCN, see examples/cluster_gcn_reddit.py or examples/cluster_gcn_ppi.py.

Parameters
  • cluster_data (torch_geometric.data.ClusterData) – The already partitioned data object.

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.
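The merging step described above can be sketched without any dependencies. The sketch assumes the per-node partition IDs are already known; in PyG, ClusterData performs that partitioning. It also assumes the loader's batch_size counts how many clusters go into one mini-batch. cluster_batch and the list-based graph layout are hypothetical. An edge is kept whenever both endpoints fall into any of the selected clusters, so links between the chosen clusters survive.

```python
import random
from typing import List, Set, Tuple

Edge = Tuple[int, int]


def cluster_batch(
    part: List[int], edges: List[Edge], clusters: Set[int]
) -> Tuple[List[int], List[Edge]]:
    """Merge the selected clusters into one mini-batch subgraph."""
    nodes = [v for v, p in enumerate(part) if p in clusters]
    local = {v: i for i, v in enumerate(nodes)}  # global -> batch-local index
    # Keep an edge if both endpoints lie in *any* selected cluster, so links
    # between two selected clusters are kept as well.
    kept = [(local[s], local[t]) for s, t in edges if s in local and t in local]
    return nodes, kept


part = [0, 0, 1, 1, 2, 2]  # precomputed partition ID of each of six nodes
edges = [(0, 1), (1, 2), (3, 4), (4, 5), (5, 0)]
rng = random.Random(0)
chosen = set(rng.sample(range(3), k=2))  # two clusters per mini-batch
print(sorted(chosen), cluster_batch(part, edges, chosen))
```

Selecting clusters {0, 1} keeps the edge (1, 2), even though nodes 1 and 2 belong to different partitions.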

class Data(x=None, edge_index=None, edge_attr=None, y=None, pos=None, normal=None, face=None, **kwargs)

A plain old Python object modeling a single graph with various (optional) attributes:

Parameters
  • x (Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. (default: None)

  • edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]. (default: None)

  • edge_attr (Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)

  • y (Tensor, optional) – Graph or node targets with arbitrary shape. (default: None)

  • pos (Tensor, optional) – Node position matrix with shape [num_nodes, num_dimensions]. (default: None)

  • normal (Tensor, optional) – Normal vector matrix with shape [num_nodes, num_dimensions]. (default: None)

  • face (LongTensor, optional) – Face adjacency matrix with shape [3, num_faces]. (default: None)

The data object is not restricted to these attributes and can be extended with any additional data.

Example:

import torch
from torch_geometric.data import Data

data = Data(x=x, edge_index=edge_index)
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)

__call__(*keys)

Iterates over all attributes *keys in the data, yielding their attribute names and content. If *keys is not given, this method will iterate over all present attributes.

__cat_dim__(key, value)

Returns the dimension along which the value of attribute key will be concatenated when creating batches.

Note

This method is for internal use only, and should only be overridden if the batch concatenation process is corrupted for a specific data attribute.

__contains__(key)

Returns True if the attribute key is present in the data.

__getitem__(key)

Gets the data of the attribute key.

__inc__(key, value)

Returns the increment by which the value of attribute key is cumulatively increased for each subsequent graph when creating batches.

Note

This method is for internal use only, and should only be overridden if the batch concatenation process is corrupted for a specific data attribute.
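As we understand the defaults of this version, both methods decide based on the attribute name. Attributes whose name contains "index" or "face" are concatenated along the last dimension and shifted by the number of nodes. All other attributes are concatenated along dimension 0 without shifting. The pure-Python sketch below restates that rule. The functions cat_dim and inc are hypothetical stand-ins, an approximation of the behavior and not the library source.

```python
import re


def cat_dim(key: str) -> int:
    # Index-like attributes ([2, num_edges], [3, num_faces]) are joined along
    # the last dimension; everything else along the first (node) dimension.
    return -1 if re.search("(index|face)", key) else 0


def inc(key: str, num_nodes: int) -> int:
    # Index-like attributes are shifted by the node count of the preceding
    # graph(s); other attributes are concatenated unchanged.
    return num_nodes if re.search("(index|face)", key) else 0


for key in ["x", "edge_index", "face", "y"]:
    print(key, cat_dim(key), inc(key, num_nodes=4))
```

Overriding makes sense when an attribute's indices refer to something other than the object's own nodes, for example a second graph stored in the same data object. Such an attribute would need its own increment.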

__iter__()

Iterates over all present attributes in the data, yielding their attribute names and content.

__len__()

Returns the number of all present attributes.

__setitem__(key, value)

Sets the attribute key to value.

apply(func, *keys)

Applies the function func to all tensor attributes *keys. If *keys is not given, func is applied to all present attributes.

clone()

Performs a deep copy of the data object.

coalesce()

Orders edge indices and removes duplicate entries.
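Coalescing a COO edge list means sorting it by (row, col) and dropping repeated pairs. is_coalesced() checks exactly that property. The sketch below shows both ideas on plain tuples. The function names mirror the methods but are hypothetical stand-ins. Edge attributes are ignored here, although a real implementation must also merge the attributes of duplicate entries.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def coalesce(edges: List[Edge]) -> List[Edge]:
    # Sort by (row, col) and drop repeated pairs.
    return sorted(set(edges))


def is_coalesced(edges: List[Edge]) -> bool:
    # Already sorted and duplicate-free iff coalescing changes nothing.
    return edges == coalesce(edges)


edges = [(2, 0), (0, 1), (2, 0), (1, 2)]
print(coalesce(edges))      # [(0, 1), (1, 2), (2, 0)]
print(is_coalesced(edges))  # False
```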

contains_isolated_nodes()

Returns True if the graph contains isolated nodes.

contains_self_loops()

Returns True if the graph contains self-loops.
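Both checks can be expressed directly on the edge list, as in the sketch below. The function names are hypothetical. The isolated-node check needs the node count, because an isolated node never appears in edge_index. As a labeled assumption of this sketch, a node whose only edge is a self-loop also counts as isolated.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def contains_self_loops(edges: List[Edge]) -> bool:
    return any(s == t for s, t in edges)


def contains_isolated_nodes(edges: List[Edge], num_nodes: int) -> bool:
    # Ignore self-loops: a node that only points to itself has no neighbors
    # (an assumption of this sketch).
    touched = {v for s, t in edges if s != t for v in (s, t)}
    return any(v not in touched for v in range(num_nodes))


edges = [(0, 1), (1, 0), (2, 2)]
print(contains_self_loops(edges))         # True
print(contains_isolated_nodes(edges, 3))  # True: node 2 only has a self-loop
```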

contiguous(*keys)

Ensures a contiguous memory layout for all attributes *keys. If *keys is not given, all present attributes are ensured to have a contiguous memory layout.

cpu(*keys)

Copies all attributes *keys to CPU memory. If *keys is not given, the conversion is applied to all present attributes.

cuda(device=None, non_blocking=False, *keys)

Copies all attributes *keys to CUDA memory. If *keys is not given, the conversion is applied to all present attributes.

classmethod from_dict(dictionary)

Creates a data object from a Python dictionary.

is_coalesced()

Returns True if edge indices are ordered and do not contain duplicate entries.

is_directed()

Returns True if graph edges are directed.

is_undirected()

Returns True if graph edges are undirected.
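In edge_index form, an undirected graph stores every edge in both directions. A graph is therefore undirected exactly when its edge set equals its own transpose, and directed otherwise. The sketch below checks this on plain tuples. The functions are hypothetical stand-ins that ignore edge attributes.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def is_undirected(edges: List[Edge]) -> bool:
    # Every (s, t) must be matched by (t, s); edge attributes are ignored here.
    edge_set = set(edges)
    return edge_set == {(t, s) for s, t in edge_set}


def is_directed(edges: List[Edge]) -> bool:
    return not is_undirected(edges)


print(is_undirected([(0, 1), (1, 0), (1, 2), (2, 1)]))  # True
print(is_directed([(0, 1), (1, 2)]))                    # True
```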

property keys

Returns all names of graph attributes.

property num_edge_features

Returns the number of features per edge in the graph.

property num_edges

Returns the number of edges in the graph. For undirected graphs, each edge is stored in both directions, so this returns twice the number of unique undirected edges.

property num_faces

Returns the number of faces in the mesh.

property num_features

Alias for num_node_features.

property num_node_features

Returns the number of features per node in the graph.

property num_nodes

Returns or sets the number of nodes in the graph.

Note

The number of nodes in your data object is typically inferred automatically, e.g., when node features x are present. In some cases, however, a graph may only be given by its edge indices edge_index. PyTorch Geometric then guesses the number of nodes as edge_index.max().item() + 1. If there are isolated nodes with indices larger than edge_index.max(), this number is too small and can result in unexpected batching behavior. Thus, we recommend setting the number of nodes in your data object explicitly via data.num_nodes = .... You will be given a warning asking you to do so.
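The failure mode described above is easy to reproduce with plain integers. The sketch below uses a hypothetical helper inferred_num_nodes that applies the same "largest index plus one" rule. It shows how trailing isolated nodes make the guess too small. Because batching shifts the next graph's indices by this count, a wrong count makes the next graph's nodes collide with this graph's isolated ones.

```python
from typing import List, Tuple


def inferred_num_nodes(edges: List[Tuple[int, int]]) -> int:
    # Fallback when only edges are known: largest index + 1 (needs >= 1 edge).
    return max(max(s, t) for s, t in edges) + 1


# The graph really has 5 nodes; nodes 3 and 4 are isolated.
edges = [(0, 1), (1, 0), (1, 2)]
true_num_nodes = 5
guess = inferred_num_nodes(edges)
print(guess)  # 3

# When batching, the next graph's node indices are shifted by this count.
# With the guess, the next graph's node 0 becomes index 3, which is where
# this graph's isolated node 3 sits in the merged node list.
print("offset used:", guess, "correct offset:", true_num_nodes)
```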

pin_memory(*keys)

Copies all attributes *keys to pinned memory. If *keys is not given, the conversion is applied to all present attributes.

to(device, *keys, **kwargs)

Performs tensor dtype and/or device conversion to all attributes *keys. If *keys is not given, the conversion is applied to all present attributes.

class DataListLoader(dataset, batch_size=1, shuffle=False, **kwargs)

Data loader which merges data objects from a torch_geometric.data.dataset into a Python list.

Note

This data loader should be used for multi-GPU support via torch_geometric.nn.DataParallel.

Parameters
  • dataset (Dataset) – The dataset from which to load the data.

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • shuffle (bool, optional) – If set to True, the data will be reshuffled at every epoch. (default: False)

class DataLoader(dataset, batch_size=1, shuffle=False, follow_batch=[], exclude_keys=[], **kwargs)

Data loader which merges data objects from a torch_geometric.data.dataset into a mini-batch.

Parameters
  • dataset (Dataset) – The dataset from which to load the data.

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • shuffle (bool, optional) – If set to True, the data will be reshuffled at every epoch. (default: False)

  • follow_batch (list or tuple, optional) – Creates assignment batch vectors for each key in the list. (default: [])

  • exclude_keys (list or tuple, optional) – Will exclude each key in the list. (default: [])

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader.

class Dataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None)

Dataset base class for creating graph datasets. See the PyTorch Geometric documentation for the accompanying tutorial.

Parameters
  • root (string, optional) – Root directory where the dataset should be saved. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)

  • pre_filter (callable, optional) – A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

__getitem__(idx: Union[int, numpy.integer, slice, torch.Tensor, numpy.ndarray, collections.abc.Sequence]) → Union[torch_geometric.data.dataset.Dataset, torch_geometric.data.data.Data]

If idx is an integer, returns the data object at index idx, transformed if transform is set. If idx is a slice (e.g., [2:5]), a list, a tuple, a PyTorch LongTensor or BoolTensor, or a NumPy np.array, returns a subset of the dataset at the specified indices.
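The cases above can be summarized as "one example for an integer, a list of positions for everything else". The sketch below is hypothetical: resolve_indices is not a library function, and plain Python sequences stand in for tensors and arrays. It shows how each kind of index maps to dataset positions, with a boolean sequence treated as a mask in the way a BoolTensor is.

```python
from typing import List, Sequence, Union


def resolve_indices(idx: Union[int, slice, Sequence], length: int) -> Union[int, List[int]]:
    """Map an index argument to dataset positions (plain-Python stand-in)."""
    if isinstance(idx, int):
        return idx  # a single example
    if isinstance(idx, slice):
        return list(range(length))[idx]  # e.g. [2:5] -> [2, 3, 4]
    idx = list(idx)
    if idx and all(isinstance(i, bool) for i in idx):
        # Boolean mask (like a BoolTensor): keep positions marked True.
        return [i for i, keep in enumerate(idx) if keep]
    return idx  # explicit positions (list, tuple, LongTensor-like)


print(resolve_indices(3, 10))
print(resolve_indices(slice(2, 5), 10))
print(resolve_indices([True, False, True], 3))
```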

__len__() → int

The number of examples in the dataset.

download()

Downloads the dataset to the self.raw_dir folder.

get(idx: int) → torch_geometric.data.data.Data

Gets the data object at index idx.

property num_edge_features

Returns the number of features per edge in the dataset.

property num_features

Alias for num_node_features.

property num_node_features

Returns the number of features per node in the dataset.

process()

Processes the dataset to the self.processed_dir folder.

property processed_file_names

The names of the files to find in the self.processed_dir folder in order to skip the processing.

property processed_paths

The filepaths to find in the self.processed_dir folder in order to skip the processing.

property raw_file_names

The names of the files to find in the self.raw_dir folder in order to skip the download.

property raw_paths

The filepaths to find in order to skip the download.

shuffle(return_perm: bool = False) → Union[torch_geometric.data.dataset.Dataset, Tuple[torch_geometric.data.dataset.Dataset, torch.Tensor]]

Randomly shuffles the examples in the dataset.

Parameters

return_perm (bool, optional) – If set to True, will return the random permutation used to shuffle the dataset in addition. (default: False)

class DenseDataLoader(dataset, batch_size=1, shuffle=False, **kwargs)

Data loader which merges data objects from a torch_geometric.data.dataset into a mini-batch.

Note

To make use of this data loader, all graphs in the dataset need to have the same shape for each of their attributes. Therefore, this data loader should only be used when working with dense adjacency matrices.

Parameters
  • dataset (Dataset) – The dataset from which to load the data.

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • shuffle (bool, optional) – If set to True, the data will be reshuffled at every epoch. (default: False)
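The same-shape requirement exists because dense batching stacks examples along a new leading dimension instead of concatenating and re-indexing them. Graphs of different sizes must therefore be padded to a common node count first. In PyG this is commonly done with a transform such as torch_geometric.transforms.ToDense. The sketch below is a hypothetical plain-list version of padding followed by stacking.

```python
from typing import List

Matrix = List[List[int]]


def pad_adj(adj: Matrix, n: int) -> Matrix:
    """Zero-pad a square adjacency matrix to n x n."""
    size = len(adj)
    rows = [row + [0] * (n - size) for row in adj]
    return rows + [[0] * n for _ in range(n - size)]


def dense_collate(adjs: List[Matrix]) -> List[Matrix]:
    """Stack equally sized matrices along a new leading batch dimension."""
    n = len(adjs[0])
    if any(len(a) != n or any(len(r) != n for r in a) for a in adjs):
        raise ValueError("all graphs must have the same number of (padded) nodes")
    return [[row[:] for row in a] for a in adjs]  # shape [batch_size, n, n]


a = [[0, 1], [1, 0]]                   # 2 nodes
b = [[0, 1, 1], [1, 0, 0], [1, 0, 0]]  # 3 nodes
batch = dense_collate([pad_adj(a, 3), pad_adj(b, 3)])
print(len(batch), len(batch[0]), len(batch[0][0]))  # 2 3 3
```

Unlike sparse batching, no index offsets are needed: node i of every graph stays at position i in its own slice of the batch.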

class GraphSAINTEdgeSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)

The GraphSAINT edge sampler class (see torch_geometric.data.GraphSAINTSampler).

class GraphSAINTNodeSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)

The GraphSAINT node sampler class (see torch_geometric.data.GraphSAINTSampler).

class GraphSAINTRandomWalkSampler(data, batch_size: int, walk_length: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)

The GraphSAINT random walk sampler class (see torch_geometric.data.GraphSAINTSampler).

Parameters

walk_length (int) – The length of each random walk.
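A random-walk sampler of this kind starts one walk of walk_length steps from each of several root nodes. The mini-batch is the subgraph induced by every visited node. The sketch below illustrates that idea with an adjacency dict. random_walk_subgraph is hypothetical and treats batch_size as the number of roots. As a simplification, roots are drawn without replacement and a walk stops early at a dead end.

```python
import random
from typing import Dict, List, Set, Tuple


def random_walk_subgraph(
    adj: Dict[int, List[int]],
    batch_size: int,
    walk_length: int,
    rng: random.Random,
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Walk from batch_size random roots; keep the subgraph induced by visited nodes."""
    visited: Set[int] = set()
    for root in rng.sample(sorted(adj), k=batch_size):  # roots without replacement
        v = root
        visited.add(v)
        for _ in range(walk_length):
            if not adj[v]:
                break  # dead end: stop this walk early
            v = rng.choice(adj[v])
            visited.add(v)
    nodes = sorted(visited)
    local = {u: i for i, u in enumerate(nodes)}
    # Induced subgraph: every original edge between two visited nodes.
    edges = [(local[s], local[t]) for s in nodes for t in adj[s] if t in local]
    return nodes, edges


adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
nodes, edges = random_walk_subgraph(adj, batch_size=2, walk_length=2, rng=random.Random(0))
print(nodes, edges)
```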

class GraphSAINTSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)

The GraphSAINT sampler base class from the “GraphSAINT: Graph Sampling Based Inductive Learning Method” paper. Given a graph in a data object, this class samples nodes and constructs subgraphs that can be processed in a mini-batch fashion. Normalization coefficients for each mini-batch are given via node_norm and edge_norm data attributes.

Parameters
  • data (torch_geometric.data.Data) – The graph data object.

  • batch_size (int) – The approximate number of samples per batch.

  • num_steps (int, optional) – The number of iterations per epoch. (default: 1)

  • sample_coverage (int, optional) – How many samples per node should be used to compute normalization statistics. (default: 0)

  • save_dir (string, optional) – If set, will save normalization statistics to the save_dir directory for faster re-use. (default: None)

  • log (bool, optional) – If set to False, will not log any pre-processing progress. (default: True)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as num_workers.
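The normalization statistics follow the paper's idea: each node's contribution to the loss is scaled by the inverse of how often the sampler picks it. Those frequencies are estimated by pre-sampling subgraphs. The sketch below loosely follows that idea for node_norm only. A fixed num_samples stands in for sample_coverage, and the placeholder count of 0.1 for never-sampled nodes is this sketch's own choice. The library's exact formula, and its edge_norm, may differ.

```python
import itertools
from collections import Counter
from typing import Callable, List


def estimate_node_norm(
    sample: Callable[[], List[int]], num_nodes: int, num_samples: int
) -> List[float]:
    """Weight each node by the inverse of how often the sampler picks it."""
    counts: Counter = Counter()
    for _ in range(num_samples):
        counts.update(set(sample()))  # count each node at most once per subgraph
    # weight_v = num_samples / (count_v * num_nodes): frequently sampled nodes
    # are down-weighted. Never-sampled nodes get a count of 0.1 to avoid
    # dividing by zero (an arbitrary choice in this sketch).
    return [num_samples / (max(counts[v], 0.1) * num_nodes) for v in range(num_nodes)]


# A deliberately biased sampler: node 0 appears in every subgraph, node 3 never.
subgraphs = itertools.cycle([[0, 1], [0, 2]])
norm = estimate_node_norm(lambda: next(subgraphs), num_nodes=4, num_samples=2)
print(norm)
```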

class InMemoryDataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None)

Dataset base class for creating graph datasets which fit completely into CPU memory. See the PyTorch Geometric documentation for the accompanying tutorial.

Parameters
  • root (string, optional) – Root directory where the dataset should be saved. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)

  • pre_filter (callable, optional) – A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

static collate(data_list: List[torch_geometric.data.data.Data]) → Tuple[torch_geometric.data.data.Data, Dict[str, torch.Tensor]]

Collates a Python list of data objects into the internal storage format of torch_geometric.data.InMemoryDataset.
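The return type suggests what this storage format looks like: one big data object plus a dictionary of slice boundaries per attribute. The sketch below illustrates that layout with dicts of plain lists. collate and get are hypothetical stand-ins for the methods, and the exact layout is our assumption. Each attribute is concatenated across examples and its boundaries are recorded, so that example idx can be cut back out later. In this sketch indices are stored unshifted.

```python
from typing import Dict, List, Tuple


def collate(data_list: List[Dict[str, list]]) -> Tuple[Dict[str, list], Dict[str, List[int]]]:
    """Concatenate every attribute and remember where each example starts."""
    keys = data_list[0].keys()
    data: Dict[str, list] = {key: [] for key in keys}
    slices: Dict[str, List[int]] = {key: [0] for key in keys}
    for item in data_list:
        for key in keys:
            data[key].extend(item[key])
            slices[key].append(slices[key][-1] + len(item[key]))
    return data, slices


def get(data: Dict[str, list], slices: Dict[str, List[int]], idx: int) -> Dict[str, list]:
    """Cut example idx back out of the flat storage."""
    return {k: data[k][slices[k][idx]:slices[k][idx + 1]] for k in data}


data, slices = collate([{"x": [1, 2], "y": [0]}, {"x": [3], "y": [1]}])
print(data, slices)
print(get(data, slices, 1))
```

Storing one flat object instead of many small ones avoids per-example overhead, which is the point of keeping the whole dataset in memory.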

download()

Downloads the dataset to the self.raw_dir folder.

get(idx: int) → torch_geometric.data.data.Data

Gets the data object at index idx.

property num_classes

The number of classes in the dataset.

process()

Processes the dataset to the self.processed_dir folder.

property processed_file_names

The names of the files to find in the self.processed_dir folder in order to skip the processing.

property raw_file_names

The names of the files to find in the self.raw_dir folder in order to skip the download.

class NeighborSampler(edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], sizes: List[int], node_idx: Optional[torch.Tensor] = None, num_nodes: Optional[int] = None, return_e_id: bool = True, transform: Optional[Callable] = None, **kwargs)

The neighbor sampler from the “Inductive Representation Learning on Large Graphs” paper, which allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

Given a GNN with \(L\) layers and a specific mini-batch of nodes node_idx for which we want to compute embeddings, this module iteratively samples neighbors and constructs bipartite graphs that simulate the actual computation flow of GNNs.

More specifically, sizes denotes how many neighbors we want to sample for each node in each layer. This module then takes in these sizes and iteratively samples sizes[l] neighbors for each node involved in layer l. In the next layer, sampling is repeated for the union of nodes that were already encountered. The actual computation graphs are then returned in reverse order, meaning that we pass messages from a larger set of nodes to a smaller one, until we reach the nodes for which we originally wanted to compute embeddings.

Hence, an item returned by NeighborSampler holds the current batch_size, the IDs n_id of all nodes involved in the computation, and a list of bipartite graph objects via the tuple (edge_index, e_id, size), where edge_index represents the bipartite edges between source and target nodes, e_id denotes the IDs of original edges in the full graph, and size holds the shape of the bipartite graph. For each bipartite graph, target nodes are also included at the beginning of the list of source nodes so that one can easily apply skip-connections or add self-loops.
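The layer-wise procedure can be sketched in plain Python on an adjacency dict, which helps when reading the returned structure. sample_blocks is hypothetical and omits e_id. In each round, every node encountered so far becomes a target. Sampled neighbors are appended after the targets, so targets stay at the front of the source list. Each bipartite block records its (num_sources, num_targets) size. The blocks are reversed at the end, so the first block maps the largest node set onto a smaller one.

```python
import random
from typing import Dict, List, Tuple

Block = Tuple[List[Tuple[int, int]], Tuple[int, int]]  # (edges, (num_src, num_dst))


def sample_blocks(
    adj: Dict[int, List[int]],
    batch: List[int],
    sizes: List[int],
    rng: random.Random,
) -> Tuple[List[int], List[Block]]:
    """Layer-wise neighbor sampling that returns bipartite blocks."""
    n_id = list(batch)  # target nodes always come first
    blocks: List[Block] = []
    for size in sizes:
        num_targets = len(n_id)  # every node encountered so far is a target
        local = {v: i for i, v in enumerate(n_id)}
        edges = []
        for i in range(num_targets):
            neigh = adj[n_id[i]]
            if size != -1 and len(neigh) > size:
                neigh = rng.sample(neigh, size)  # subsample without replacement
            for u in neigh:
                if u not in local:  # new source node: append after the targets
                    local[u] = len(n_id)
                    n_id.append(u)
                edges.append((local[u], i))  # message flows source -> target
        blocks.append((edges, (len(n_id), num_targets)))
    # Reverse so the first block maps the largest node set onto a smaller one.
    return n_id, blocks[::-1]


adj = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}
n_id, blocks = sample_blocks(adj, batch=[0], sizes=[-1, -1], rng=random.Random(0))
for edges, size in blocks:
    print(size, edges)
```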

Note

For an example of using NeighborSampler, see examples/reddit.py or examples/ogbn_products_sage.py.

Parameters
  • edge_index (Tensor or SparseTensor) – A torch.LongTensor or a torch_sparse.SparseTensor that defines the underlying graph connectivity/message passing flow. edge_index holds the indices of a (sparse) symmetric adjacency matrix. If edge_index is of type torch.LongTensor, its shape must be defined as [2, num_edges], where messages from nodes edge_index[0] are sent to nodes in edge_index[1] (in case flow="source_to_target"). If edge_index is of type torch_sparse.SparseTensor, its sparse indices (row, col) should relate to row = edge_index[1] and col = edge_index[0]. The major difference between both formats is that we need to input the transposed sparse adjacency matrix.

  • sizes ([int]) – The number of neighbors to sample for each node in each layer. If set to sizes[l] = -1, all neighbors are included in layer l.

  • node_idx (LongTensor, optional) – The nodes that should be considered for creating mini-batches. If set to None, all nodes will be considered.

  • num_nodes (int, optional) – The number of nodes in the graph. (default: None)

  • return_e_id (bool, optional) – If set to False, will not return the original edge indices of sampled edges. Setting it to False is only useful for saving memory when operating on graphs without edge features. (default: True)

  • transform (callable, optional) – A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: None)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.

class RandomNodeSampler(data, num_parts: int, shuffle: bool = False, **kwargs)

A data loader that randomly samples nodes within a graph and returns their induced subgraph.

Note

For an example of using RandomNodeSampler, see examples/ogbn_proteins_deepgcn.py.

Parameters
  • data (torch_geometric.data.Data) – The graph data object.

  • num_parts (int) – The number of partitions.

  • shuffle (bool, optional) – If set to True, the data is reshuffled at every epoch (default: False).

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as num_workers.
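Random node sampling splits the node set into num_parts random groups and returns the subgraph induced by each group. Edges that cross groups are dropped. The sketch below is hypothetical. It splits nodes into near-equal groups via one random permutation; the library may assign nodes to parts differently, for example with unequal part sizes.

```python
import random
from typing import List, Tuple

Edge = Tuple[int, int]


def random_node_subgraphs(
    num_nodes: int, edges: List[Edge], num_parts: int, rng: random.Random
) -> List[Tuple[List[int], List[Edge]]]:
    """Split nodes at random into num_parts groups and induce a subgraph on each."""
    perm = list(range(num_nodes))
    rng.shuffle(perm)
    out = []
    for p in range(num_parts):
        nodes = sorted(perm[p::num_parts])  # every num_parts-th entry of the permutation
        local = {v: i for i, v in enumerate(nodes)}
        # Only edges with both endpoints inside the group survive.
        sub = [(local[s], local[t]) for s, t in edges if s in local and t in local]
        out.append((nodes, sub))
    return out


edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]
parts = random_node_subgraphs(6, edges, num_parts=3, rng=random.Random(0))
for nodes, sub in parts:
    print(nodes, sub)
```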

class ShaDowKHopSampler(data: torch_geometric.data.data.Data, depth: int, num_neighbors: int, node_idx: Optional[torch.Tensor] = None, replace: bool = False, **kwargs)

The ShaDow \(k\)-hop sampler from the “Deep Graph Neural Networks with Shallow Subgraph Samplers” paper. Given a graph in a data object, the sampler will create shallow, localized subgraphs. A deep GNN on this local graph then smooths the informative local signals.

Parameters
  • data (torch_geometric.data.Data) – The graph data object.

  • depth (int) – The depth/number of hops of the localized subgraph.

  • num_neighbors (int) – The number of neighbors to sample for each node in each hop.

  • node_idx (LongTensor or BoolTensor, optional) – The nodes that should be considered for creating mini-batches. If set to None, all nodes will be considered.

  • replace (bool, optional) – If set to True, will sample neighbors with replacement. (default: False)

  • **kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size or num_workers.
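The shallow, localized subgraph around one root can be sketched as follows. Up to num_neighbors neighbors are sampled per node for depth hops, with or without replacement, and the subgraph induced by the sampled nodes is kept. shadow_subgraph is hypothetical. It also returns the root's position inside the subgraph, on the assumption that a model would read the root's embedding from there. That position is not a documented return value of the library.

```python
import random
from typing import Dict, List, Tuple


def shadow_subgraph(
    adj: Dict[int, List[int]],
    root: int,
    depth: int,
    num_neighbors: int,
    rng: random.Random,
    replace: bool = False,
) -> Tuple[List[int], List[Tuple[int, int]], int]:
    """Sample a depth-hop neighborhood of root and return its induced subgraph."""
    nodes = {root}
    frontier = [root]
    for _ in range(depth):
        next_frontier = []
        for v in frontier:
            neigh = adj[v]
            if not neigh:
                continue
            if replace:
                picked = [rng.choice(neigh) for _ in range(num_neighbors)]
            else:
                picked = rng.sample(neigh, min(num_neighbors, len(neigh)))
            for u in picked:
                if u not in nodes:
                    nodes.add(u)
                    next_frontier.append(u)
        frontier = next_frontier
    ordered = sorted(nodes)
    local = {v: i for i, v in enumerate(ordered)}
    # Keep every original edge between sampled nodes (induced subgraph).
    sub = [(local[s], local[t]) for s in ordered for t in adj[s] if t in local]
    return ordered, sub, local[root]


adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
print(shadow_subgraph(adj, root=1, depth=1, num_neighbors=5, rng=random.Random(0)))
```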

download_url(url, folder, log=True)

Downloads the content of a URL to a specific folder.

Parameters
  • url (string) – The URL.

  • folder (string) – The folder to download to.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)

extract_tar(path, folder, mode='r:gz', log=True)

Extracts a tar archive to a specific folder.

Parameters
  • path (string) – The path to the tar archive.

  • folder (string) – The folder to extract to.

  • mode (string, optional) – The compression mode. (default: "r:gz")

  • log (bool, optional) – If False, will not print anything to the console. (default: True)

extract_zip(path, folder, log=True)

Extracts a zip archive to a specific folder.

Parameters
  • path (string) – The path to the zip archive.

  • folder (string) – The folder to extract to.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)