torch_geometric.data

Data

A data object describing a homogeneous graph.

HeteroData

A data object describing a heterogeneous graph, holding multiple node and/or edge types in disjoint storage objects.

TemporalData

Batch

A data object describing a batch of graphs as one big (disconnected) graph.

Dataset

Dataset base class for creating graph datasets.

InMemoryDataset

Dataset base class for creating graph datasets which easily fit into CPU memory.

download_url

Downloads the content of a URL to a specific folder.

extract_tar

Extracts a tar archive to a specific folder.

extract_zip

Extracts a zip archive to a specific folder.

extract_bz2

Extracts a bz2 archive to a specific folder.

extract_gz

Extracts a gz archive to a specific folder.

class Data(x: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None, edge_attr: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, pos: Optional[torch.Tensor] = None, **kwargs)

A data object describing a homogeneous graph. The data object can hold node-level, link-level and graph-level attributes. In general, Data tries to mimic the behaviour of a regular Python dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities. See here for the accompanying tutorial.

from torch_geometric.data import Data

data = Data(x=x, edge_index=edge_index, ...)

# Add additional attributes to `data`:
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)

# Analyzing the graph structure:
data.num_nodes
>>> 23

data.is_directed()
>>> False

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)

Parameters
  • x (Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. (default: None)

  • edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]. (default: None)

  • edge_attr (Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)

  • y (Tensor, optional) – Graph-level or node-level ground-truth labels with arbitrary shape. (default: None)

  • pos (Tensor, optional) – Node position matrix with shape [num_nodes, num_dimensions]. (default: None)

  • **kwargs (optional) – Additional attributes.

to_dict() → Dict[str, Any]

Returns a dictionary of stored key/value pairs.

to_namedtuple() → NamedTuple

Returns a NamedTuple of stored key/value pairs.

__cat_dim__(key: str, value: Any, *args, **kwargs) → Any

Returns the dimension along which the value of the attribute key is concatenated when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.

__inc__(key: str, value: Any, *args, **kwargs) → Any

Returns the incremental count by which the value of the attribute key is cumulatively increased when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.
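
To make the idea of concatenation with a cumulative increment more tangible, here is a plain-Python sketch of how edge lists from several graphs might be merged into one disconnected graph. It does not call PyG; using the node count of each graph as the increment for its edge indices is an assumption made for this illustration.

```python
# Two tiny graphs, each given as (num_nodes, list of (source, target) edges).
graphs = [
    (2, [(0, 1), (1, 0)]),
    (3, [(0, 1), (1, 2)]),
]

merged_edges = []
offset = 0  # running increment applied to the node indices of each graph
for num_nodes, edges in graphs:
    # Shift node indices so they do not collide with earlier graphs.
    merged_edges.extend((s + offset, t + offset) for s, t in edges)
    # Assumed increment for edge indices: the node count of this graph.
    offset += num_nodes

print(merged_edges)  # [(0, 1), (1, 0), (2, 3), (3, 4)]
```

Attributes such as node features would be concatenated without any shift, which corresponds to an increment of zero in this simplified picture.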

classmethod from_dict(mapping: Dict[str, Any])

Creates a Data object from a Python dictionary.

property num_node_features: int

Returns the number of features per node in the graph.

property num_features: int

Returns the number of features per node in the graph. Alias for num_node_features.

property num_edge_features: int

Returns the number of features per edge in the graph.

apply(func: Callable, *args: List[str])

Applies the function func, either to all attributes or only the ones given in *args.

apply_(func: Callable, *args: List[str])

Applies the in-place function func, either to all attributes or only the ones given in *args.

clone()

Performs a deep-copy of the data object.

coalesce()

Sorts and removes duplicated entries from edge indices edge_index.
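
The effect of coalescing can be pictured with a plain-Python sketch that treats each edge as a (source, target) pair. This illustrates the concept only and is not PyG's implementation; the ordering by source and then target is an assumption of the sketch.

```python
# Edge list with one duplicate entry, in arbitrary order.
edges = [(1, 0), (0, 1), (1, 2), (0, 1)]

# Coalescing: drop duplicates, then sort (here by source, then target).
coalesced = sorted(set(edges))
print(coalesced)  # [(0, 1), (1, 0), (1, 2)]

# An edge list counts as coalesced if coalescing leaves it unchanged.
was_coalesced = edges == coalesced
print(was_coalesced)  # False
```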

contiguous(*args: List[str])

Ensures a contiguous memory layout, either for all attributes or only the ones given in *args.

cpu(*args: List[str])

Copies attributes to CPU memory, either for all attributes or only the ones given in *args.

cuda(device: Optional[Union[int, str]] = None, *args: List[str], non_blocking: bool = False)

Copies attributes to CUDA memory, either for all attributes or only the ones given in *args.

detach(*args: List[str])

Detaches attributes from the computation graph by creating a new tensor, either for all attributes or only the ones given in *args.

detach_(*args: List[str])

Detaches attributes from the computation graph, either for all attributes or only the ones given in *args.

has_isolated_nodes() → bool

Returns True if the graph contains isolated nodes.

has_self_loops() → bool

Returns True if the graph contains self-loops.

is_coalesced() → bool

Returns True if edge indices edge_index are sorted and do not contain duplicate entries.

is_directed() → bool

Returns True if graph edges are directed.

is_undirected() → bool

Returns True if graph edges are undirected.

Returns True if graph edges are undirected.

property keys: List[str]

Returns a list of all graph attribute names.

property num_edges: int

Returns the number of edges in the graph. For undirected graphs, this will return the number of bi-directional edges, which is double the number of unique edges.

property num_nodes: Optional[int]

Returns the number of nodes in the graph.

Note

The number of nodes in the data object is automatically inferred if node-level attributes are present, e.g., data.x. In some cases, however, a graph may be given without any node-level attributes. PyG then guesses the number of nodes as edge_index.max().item() + 1. If the graph contains isolated nodes, however, this guess can be wrong, which can result in unexpected behaviour. We therefore recommend setting the number of nodes in your data object explicitly via data.num_nodes = .... PyG emits a warning requesting you to do so.

pin_memory(*args: List[str])

Copies attributes to pinned memory, either for all attributes or only the ones given in *args.

record_stream(stream: torch.cuda.streams.Stream, *args: List[str])

Ensures that the tensor memory is not reused for another tensor until all current work queued on stream has been completed, either for all attributes or only the ones given in *args.

requires_grad_(*args: List[str], requires_grad: bool = True)

Tracks gradient computation, either for all attributes or only the ones given in *args.

share_memory_(*args: List[str])

Moves attributes to shared memory, either for all attributes or only the ones given in *args.

size(dim: Optional[int] = None) → Optional[Union[Tuple[Optional[int], Optional[int]], int]]

Returns the size of the adjacency matrix induced by the graph.

to(device: Union[int, str], *args: List[str], non_blocking: bool = False)

Performs tensor device conversion, either for all attributes or only the ones given in *args.

property num_faces: Optional[int]

Returns the number of faces in the mesh.

class HeteroData(_mapping: Optional[Dict[str, Any]] = None, **kwargs)

A data object describing a heterogeneous graph, holding multiple node and/or edge types in disjoint storage objects. Storage objects can hold either node-level, link-level or graph-level attributes. In general, HeteroData tries to mimic the behaviour of a regular nested Python dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities.

from torch_geometric.data import HeteroData

data = HeteroData()

# Create two node types "paper" and "author" holding a feature matrix:
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['author'].x = torch.randn(num_authors, num_authors_features)

# Create an edge type "(author, writes, paper)" and build the
# graph connectivity:
data['author', 'writes', 'paper'].edge_index = ...  # [2, num_edges]

data['paper'].num_nodes
>>> 23

data['author', 'writes', 'paper'].num_edges
>>> 52

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)

Note that there exist multiple ways to create heterogeneous graph data, e.g.:

  • To initialize a node of type "paper" holding a node feature matrix x_paper named x:

    from torch_geometric.data import HeteroData
    
    data = HeteroData()
    data['paper'].x = x_paper
    
    data = HeteroData(paper={ 'x': x_paper })
    
    data = HeteroData({'paper': { 'x': x_paper }})
    
  • To initialize an edge from source node type "author" to destination node type "paper" with relation type "writes" holding a graph connectivity matrix edge_index_author_paper named edge_index:

    data = HeteroData()
    data['author', 'writes', 'paper'].edge_index = edge_index_author_paper
    
    data = HeteroData(author__writes__paper={
        'edge_index': edge_index_author_paper
    })
    
    data = HeteroData({
        ('author', 'writes', 'paper'):
        { 'edge_index': edge_index_author_paper }
    })
    
property stores: List[torch_geometric.data.storage.BaseStorage]

Returns a list of all storages of the graph.

property node_types: List[str]

Returns a list of all node types of the graph.

property node_stores: List[torch_geometric.data.storage.NodeStorage]

Returns a list of all node storages of the graph.

property edge_types: List[Tuple[str, str, str]]

Returns a list of all edge types of the graph.

property edge_stores: List[torch_geometric.data.storage.EdgeStorage]

Returns a list of all edge storages of the graph.

to_dict() → Dict[str, Any]

Returns a dictionary of stored key/value pairs.

to_namedtuple() → NamedTuple

Returns a NamedTuple of stored key/value pairs.

__cat_dim__(key: str, value: Any, store: Optional[Union[torch_geometric.data.storage.NodeStorage, torch_geometric.data.storage.EdgeStorage]] = None, *args, **kwargs) → Any

Returns the dimension along which the value of the attribute key is concatenated when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.

__inc__(key: str, value: Any, store: Optional[Union[torch_geometric.data.storage.NodeStorage, torch_geometric.data.storage.EdgeStorage]] = None, *args, **kwargs) → Any

Returns the incremental count by which the value of the attribute key is cumulatively increased when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.

property num_nodes: Optional[int]

Returns the number of nodes in the graph.

metadata() → Tuple[List[str], List[Tuple[str, str, str]]]

Returns the heterogeneous meta-data, i.e. its node and edge types.

data = HeteroData()
data['paper'].x = ...
data['author'].x = ...
data['author', 'writes', 'paper'].edge_index = ...

print(data.metadata())
>>> (['paper', 'author'], [('author', 'writes', 'paper')])

collect(key: str) → Dict[Union[str, Tuple[str, str, str]], Any]

Collects the attribute key from all node and edge types.

data = HeteroData()
data['paper'].x = ...
data['author'].x = ...

print(data.collect('x'))
>>> { 'paper': ..., 'author': ...}

Note

This is equivalent to writing data.x_dict.

get_node_store(key: str) → torch_geometric.data.storage.NodeStorage

Gets the NodeStorage object of a particular node type key. If the storage is not present yet, will create a new torch_geometric.data.storage.NodeStorage object for the given node type.

data = HeteroData()
node_storage = data.get_node_store('paper')

get_edge_store(src: str, rel: str, dst: str) → torch_geometric.data.storage.EdgeStorage

Gets the EdgeStorage object of a particular edge type given by the tuple (src, rel, dst). If the storage is not present yet, will create a new torch_geometric.data.storage.EdgeStorage object for the given edge type.

data = HeteroData()
edge_storage = data.get_edge_store('author', 'writes', 'paper')

to_homogeneous(node_attrs: Optional[List[str]] = None, edge_attrs: Optional[List[str]] = None, add_edge_type: bool = True) → torch_geometric.data.data.Data

Converts a HeteroData object to a homogeneous Data object. By default, all features with same feature dimensionality across different types will be merged into a single representation. Furthermore, an attribute named edge_type will be added to the returned Data object, denoting an edge-level vector holding the edge type of each edge as an integer.

Parameters
  • node_attrs (List[str], optional) – The node features to combine across all node types. These node features need to be of the same feature dimensionality. If set to None, will automatically determine which node features to combine. (default: None)

  • edge_attrs (List[str], optional) – The edge features to combine across all edge types. These edge features need to be of the same feature dimensionality. If set to None, will automatically determine which edge features to combine. (default: None)

  • add_edge_type (bool, optional) – If set to False, will not add the edge-level vector edge_type to the returned Data object. (default: True)

apply(func: Callable, *args: List[str])

Applies the function func, either to all attributes or only the ones given in *args.

apply_(func: Callable, *args: List[str])

Applies the in-place function func, either to all attributes or only the ones given in *args.

clone()

Performs a deep-copy of the data object.

coalesce()

Sorts and removes duplicated entries from edge indices edge_index.

contiguous(*args: List[str])

Ensures a contiguous memory layout, either for all attributes or only the ones given in *args.

cpu(*args: List[str])

Copies attributes to CPU memory, either for all attributes or only the ones given in *args.

cuda(device: Optional[Union[int, str]] = None, *args: List[str], non_blocking: bool = False)

Copies attributes to CUDA memory, either for all attributes or only the ones given in *args.

detach(*args: List[str])

Detaches attributes from the computation graph by creating a new tensor, either for all attributes or only the ones given in *args.

detach_(*args: List[str])

Detaches attributes from the computation graph, either for all attributes or only the ones given in *args.

has_isolated_nodes() → bool

Returns True if the graph contains isolated nodes.

has_self_loops() → bool

Returns True if the graph contains self-loops.

is_coalesced() → bool

Returns True if edge indices edge_index are sorted and do not contain duplicate entries.

is_directed() → bool

Returns True if graph edges are directed.

is_undirected() → bool

Returns True if graph edges are undirected.

Returns True if graph edges are undirected.

property keys: List[str]

Returns a list of all graph attribute names.

property num_edges: int

Returns the number of edges in the graph. For undirected graphs, this will return the number of bi-directional edges, which is double the number of unique edges.

pin_memory(*args: List[str])

Copies attributes to pinned memory, either for all attributes or only the ones given in *args.

record_stream(stream: torch.cuda.streams.Stream, *args: List[str])

Ensures that the tensor memory is not reused for another tensor until all current work queued on stream has been completed, either for all attributes or only the ones given in *args.

requires_grad_(*args: List[str], requires_grad: bool = True)

Tracks gradient computation, either for all attributes or only the ones given in *args.

share_memory_(*args: List[str])

Moves attributes to shared memory, either for all attributes or only the ones given in *args.

size(dim: Optional[int] = None) → Optional[Union[Tuple[Optional[int], Optional[int]], int]]

Returns the size of the adjacency matrix induced by the graph.

to(device: Union[int, str], *args: List[str], non_blocking: bool = False)

Performs tensor device conversion, either for all attributes or only the ones given in *args.

class Batch(*args, **kwargs)

A data object describing a batch of graphs as one big (disconnected) graph. Inherits from torch_geometric.data.Data or torch_geometric.data.HeteroData. In addition, single graphs can be identified via the assignment vector batch, which maps each node to its respective graph identifier.

classmethod from_data_list(data_list: Union[List[torch_geometric.data.data.Data], List[torch_geometric.data.hetero_data.HeteroData]], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None)

Constructs a Batch object from a Python list of Data or HeteroData objects. The assignment vector batch is created on the fly. In addition, creates assignment vectors for each key in follow_batch. Will exclude any keys given in exclude_keys.

get(idx: int) → Union[torch_geometric.data.data.Data, torch_geometric.data.hetero_data.HeteroData]

Gets the Data or HeteroData object at index idx. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial object.

index_select(idx: Union[slice, torch.Tensor, numpy.ndarray, collections.abc.Sequence]) → Union[List[torch_geometric.data.data.Data], List[torch_geometric.data.hetero_data.HeteroData]]

Creates a subset of Data or HeteroData objects from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

to_data_list() → Union[List[torch_geometric.data.data.Data], List[torch_geometric.data.hetero_data.HeteroData]]

Reconstructs the list of Data or HeteroData objects from the Batch object. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

property num_graphs: int

Returns the number of graphs in the batch.

class Dataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None)

Dataset base class for creating graph datasets. See here for the accompanying tutorial.

Parameters
  • root (string, optional) – Root directory where the dataset should be saved. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)

  • pre_filter (callable, optional) – A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

property raw_file_names: Union[str, List[str], Tuple]

The names of the files in the self.raw_dir folder that must be present in order to skip downloading.

property processed_file_names: Union[str, List[str], Tuple]

The names of the files in the self.processed_dir folder that must be present in order to skip processing.

download()

Downloads the dataset to the self.raw_dir folder.

process()

Processes the dataset to the self.processed_dir folder.
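
The rule that downloading or processing is skipped when the expected files are already present can be pictured with the plain-Python sketch below. It is a simplified model, not PyG's implementation; ensure_files and fake_download are hypothetical helpers introduced only for this illustration.

```python
import os
import tempfile


def ensure_files(folder, file_names, create):
    """Call create(folder) only if an expected file is missing.

    Returns True if create was called, False if the step was skipped.
    """
    paths = [os.path.join(folder, name) for name in file_names]
    if all(os.path.exists(path) for path in paths):
        return False  # everything is present, so the step is skipped
    os.makedirs(folder, exist_ok=True)
    create(folder)
    return True


def fake_download(folder):
    # Hypothetical stand-in for download(): writes an empty raw file.
    with open(os.path.join(folder, "graph.txt"), "w"):
        pass


with tempfile.TemporaryDirectory() as root:
    raw_dir = os.path.join(root, "raw")
    first = ensure_files(raw_dir, ["graph.txt"], fake_download)
    second = ensure_files(raw_dir, ["graph.txt"], fake_download)

print(first, second)  # True False
```

The same check would apply to processed_file_names and process(), with the processed directory in place of the raw one.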

len() → int

Returns the number of graphs stored in the dataset.

get(idx: int) → torch_geometric.data.data.Data

Gets the data object at index idx.

property num_node_features: int

Returns the number of features per node in the dataset.

property num_features: int

Returns the number of features per node in the dataset. Alias for num_node_features.

property num_edge_features: int

Returns the number of features per edge in the dataset.

property raw_paths: List[str]

The absolute filepaths that must be present in order to skip downloading.

property processed_paths: List[str]

The absolute filepaths that must be present in order to skip processing.

index_select(idx: Union[slice, torch.Tensor, numpy.ndarray, collections.abc.Sequence]) → torch_geometric.data.dataset.Dataset

Creates a subset of the dataset from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool.
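
The accepted index forms can be illustrated with a plain-Python sketch over a list of placeholder items. The select helper is hypothetical and only mimics the selection semantics described above (slice, integer indices, boolean mask); it does not use PyG or tensors.

```python
def select(items, idx):
    """Select a subset of items by slice, integer indices, or boolean mask."""
    if isinstance(idx, slice):
        return items[idx]
    idx = list(idx)
    # Check for bool first, since bool is a subclass of int in Python.
    if idx and all(isinstance(i, bool) for i in idx):
        return [item for item, keep in zip(items, idx) if keep]
    return [items[i] for i in idx]


graphs = ["g0", "g1", "g2", "g3", "g4", "g5"]

print(select(graphs, slice(2, 5)))  # ['g2', 'g3', 'g4']
print(select(graphs, [0, 5]))       # ['g0', 'g5']
print(select(graphs, (True, False, True, False, False, False)))  # ['g0', 'g2']
```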

shuffle(return_perm: bool = False) → Union[torch_geometric.data.dataset.Dataset, Tuple[torch_geometric.data.dataset.Dataset, torch.Tensor]]

Randomly shuffles the examples in the dataset.

Parameters
  • return_perm (bool, optional) – If set to True, will also return the random permutation used to shuffle the dataset. (default: False)

class InMemoryDataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None)

Dataset base class for creating graph datasets which easily fit into CPU memory. Inherits from torch_geometric.data.Dataset. See here for the accompanying tutorial.

Parameters
  • root (string, optional) – Root directory where the dataset should be saved. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)

  • pre_filter (callable, optional) – A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

property raw_file_names: Union[str, List[str], Tuple]

The names of the files in the self.raw_dir folder that must be present in order to skip downloading.

property processed_file_names: Union[str, List[str], Tuple]

The names of the files in the self.processed_dir folder that must be present in order to skip processing.

download()

Downloads the dataset to the self.raw_dir folder.

process()

Processes the dataset to the self.processed_dir folder.

property num_classes: int

Returns the number of classes in the dataset.

len() → int

Returns the number of graphs stored in the dataset.

get(idx: int) → torch_geometric.data.data.Data

Gets the data object at index idx.

static collate(data_list: List[torch_geometric.data.data.Data]) → Tuple[torch_geometric.data.data.Data, Optional[Dict[str, torch.Tensor]]]

Collates a Python list of torch_geometric.data.Data objects to the internal storage format of InMemoryDataset.

copy(idx: Optional[Union[slice, torch.Tensor, numpy.ndarray, collections.abc.Sequence]] = None) → torch_geometric.data.in_memory_dataset.InMemoryDataset

Performs a deep-copy of the dataset. If idx is not given, will clone the full dataset. Otherwise, will only clone a subset of the dataset from indices idx. Indices can be slices, lists, tuples, or a torch.Tensor or np.ndarray of type long or bool.

download_url(url: str, folder: str, log: bool = True)

Downloads the content of a URL to a specific folder.

Parameters
  • url (string) – The URL.

  • folder (string) – The destination folder.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)

extract_tar(path: str, folder: str, mode: str = 'r:gz', log: bool = True)

Extracts a tar archive to a specific folder.

Parameters
  • path (string) – The path to the tar archive.

  • folder (string) – The folder.

  • mode (string, optional) – The compression mode. (default: "r:gz")

  • log (bool, optional) – If False, will not print anything to the console. (default: True)
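
For readers unfamiliar with the mode string, "r:gz" is the read mode that Python's standard tarfile module uses for gzip-compressed archives. The sketch below uses tarfile directly rather than extract_tar, so it makes no claim about how extract_tar is implemented; the file names are made up for the example.

```python
import os
import tarfile
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Build a small gzip-compressed tar archive to work with.
    src = os.path.join(tmp, "hello.txt")
    with open(src, "w") as f:
        f.write("hello")
    archive = os.path.join(tmp, "data.tar.gz")
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(src, arcname="hello.txt")

    # Extract it into a target folder using the "r:gz" read mode.
    folder = os.path.join(tmp, "out")
    os.makedirs(folder)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(folder)

    extracted = sorted(os.listdir(folder))

print(extracted)  # ['hello.txt']
```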

extract_zip(path: str, folder: str, log: bool = True)

Extracts a zip archive to a specific folder.

Parameters
  • path (string) – The path to the zip archive.

  • folder (string) – The folder.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)

extract_bz2(path: str, folder: str, log: bool = True)

Extracts a bz2 archive to a specific folder.

Parameters
  • path (string) – The path to the bz2 archive.

  • folder (string) – The folder.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)

extract_gz(path: str, folder: str, log: bool = True)

Extracts a gz archive to a specific folder.

Parameters
  • path (string) – The path to the gz archive.

  • folder (string) – The folder.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)