torch_geometric.data

Data

A data object describing a homogeneous graph.

HeteroData

A data object describing a heterogeneous graph, holding multiple node and/or edge types in disjoint storage objects.

TemporalData

A data object composed of a stream of events describing a temporal graph.

Batch

A data object describing a batch of graphs as one big (disconnected) graph.

Dataset

Dataset base class for creating graph datasets.

InMemoryDataset

Dataset base class for creating graph datasets which easily fit into CPU memory.

LightningDataset

Converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU graph-level training via PyTorch Lightning.

LightningNodeData

Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU node-level training via PyTorch Lightning.

LightningLinkData

Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU link-level training (such as for link prediction) via PyTorch Lightning.

makedirs

Recursive directory creation function.

download_url

Downloads the content of a URL to a specific folder.

extract_tar

Extracts a tar archive to a specific folder.

extract_zip

Extracts a zip archive to a specific folder.

extract_bz2

Extracts a bz2 archive to a specific folder.

extract_gz

Extracts a gz archive to a specific folder.

class Data(x: Optional[Tensor] = None, edge_index: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, y: Optional[Tensor] = None, pos: Optional[Tensor] = None, **kwargs)[source]

A data object describing a homogeneous graph. The data object can hold node-level, link-level and graph-level attributes. In general, Data tries to mimic the behaviour of a regular Python dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities. See here for the accompanying tutorial.

from torch_geometric.data import Data

data = Data(x=x, edge_index=edge_index, ...)

# Add additional arguments to `data`:
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)

# Analyzing the graph structure:
data.num_nodes
>>> 23

data.is_directed()
>>> False

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)
Parameters
  • x (Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. (default: None)

  • edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]. (default: None)

  • edge_attr (Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)

  • y (Tensor, optional) – Graph-level or node-level ground-truth labels with arbitrary shape. (default: None)

  • pos (Tensor, optional) – Node position matrix with shape [num_nodes, num_dimensions]. (default: None)

  • **kwargs (optional) – Additional attributes.

to_dict() Dict[str, Any][source]

Returns a dictionary of stored key/value pairs.

to_namedtuple() NamedTuple[source]

Returns a NamedTuple of stored key/value pairs.

__cat_dim__(key: str, value: Any, *args, **kwargs) Any[source]

Returns the dimension for which the value value of the attribute key will get concatenated when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.

__inc__(key: str, value: Any, *args, **kwargs) Any[source]

Returns the incremental count to cumulatively increase the value value of the attribute key when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.
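
A hedged sketch of how these hooks are typically overridden: a Data subclass stores an additional node-referencing attribute (the name face_index and its semantics are hypothetical, not part of the API) and adjusts both the concatenation dimension and the increment:

from torch_geometric.data import Data

class MyData(Data):
    def __cat_dim__(self, key, value, *args, **kwargs):
        if key == 'face_index':
            # Concatenate along the last dimension, like `edge_index`:
            return 1
        return super().__cat_dim__(key, value, *args, **kwargs)

    def __inc__(self, key, value, *args, **kwargs):
        if key == 'face_index':
            # Shift node references by the number of nodes of each graph:
            return self.num_nodes
        return super().__inc__(key, value, *args, **kwargs)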

validate(raise_on_error: bool = True) bool[source]

Validates the correctness of the data.

is_node_attr(key: str) bool[source]

Returns True if the object at key key denotes a node-level attribute.

is_edge_attr(key: str) bool[source]

Returns True if the object at key key denotes an edge-level attribute.

subgraph(subset: Tensor)[source]

Returns the induced subgraph given by the node indices subset.

Parameters

subset (LongTensor or BoolTensor) – The nodes to keep.
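
A minimal usage sketch, assuming data is an existing Data object; edges whose end points are removed are dropped and the remaining nodes are re-indexed:

subset = torch.tensor([0, 1, 2, 3])  # Node indices to keep (a boolean mask works as well).
sub_data = data.subgraph(subset)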

to_heterogeneous(node_type: Optional[Tensor] = None, edge_type: Optional[Tensor] = None, node_type_names: Optional[List[str]] = None, edge_type_names: Optional[List[Tuple[str, str, str]]] = None)[source]

Converts a Data object to a heterogeneous HeteroData object. For this, node and edge attributes are split according to the node-level and edge-level vectors node_type and edge_type, respectively. node_type_names and edge_type_names can be used to give meaningful node and edge type names, respectively. That is, the node_type 0 is given by node_type_names[0]. If the Data object was constructed via to_homogeneous(), the object can be reconstructed without any need to pass in additional arguments.

Parameters
  • node_type (Tensor, optional) – A node-level vector denoting the type of each node. (default: None)

  • edge_type (Tensor, optional) – An edge-level vector denoting the type of each edge. (default: None)

  • node_type_names (List[str], optional) – The names of node types. (default: None)

  • edge_type_names (List[Tuple[str, str, str]], optional) – The names of edge types. (default: None)
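
A hedged sketch of the conversion, assuming data is an existing Data object; the type vectors and names below are made up for illustration and must have lengths num_nodes and num_edges, respectively:

node_type = torch.tensor([0, 0, 1, 1, 1])  # e.g., 0 = 'paper', 1 = 'author'
edge_type = torch.tensor([0, 0, 1])

hetero_data = data.to_heterogeneous(
    node_type=node_type,
    edge_type=edge_type,
    node_type_names=['paper', 'author'],
    edge_type_names=[('paper', 'cites', 'paper'),
                     ('author', 'writes', 'paper')],
)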

classmethod from_dict(mapping: Dict[str, Any])[source]

Creates a Data object from a Python dictionary.

property num_node_features: int

Returns the number of features per node in the graph.

property num_features: int

Returns the number of features per node in the graph. Alias for num_node_features.

property num_edge_features: int

Returns the number of features per edge in the graph.

property num_faces: Optional[int]

Returns the number of faces in the mesh.

items()[source]

Returns an ItemsView over the stored attributes in the Data object.

get_all_tensor_attrs() List[TensorAttr][source]

Obtains all feature attributes stored in Data.

apply(func: Callable, *args: List[str])

Applies the function func, either to all attributes or only the ones given in *args.

apply_(func: Callable, *args: List[str])

Applies the in-place function func, either to all attributes or only the ones given in *args.

clone(*args: List[str])

Performs cloning of tensors, either for all attributes or only the ones given in *args.

coalesce()

Sorts and removes duplicated entries from edge indices edge_index.

contiguous(*args: List[str])

Ensures a contiguous memory layout, either for all attributes or only the ones given in *args.

coo(edge_types: Optional[List[Any]] = None, store: bool = False) Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, Optional[Tensor]]]

Converts the edge indices in the graph store to COO format, optionally storing the converted edge indices in the graph store.

cpu(*args: List[str])

Copies attributes to CPU memory, either for all attributes or only the ones given in *args.

csc(edge_types: Optional[List[Any]] = None, store: bool = False) Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, Optional[Tensor]]]

Converts the edge indices in the graph store to CSC format, optionally storing the converted edge indices in the graph store.

csr(edge_types: Optional[List[Any]] = None, store: bool = False) Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, Optional[Tensor]]]

Converts the edge indices in the graph store to CSR format, optionally storing the converted edge indices in the graph store.

cuda(device: Optional[Union[int, str]] = None, *args: List[str], non_blocking: bool = False)

Copies attributes to CUDA memory, either for all attributes or only the ones given in *args.

detach(*args: List[str])

Detaches attributes from the computation graph by creating a new tensor, either for all attributes or only the ones given in *args.

detach_(*args: List[str])

Detaches attributes from the computation graph, either for all attributes or only the ones given in *args.

get_edge_index(*args, **kwargs) Tuple[Tensor, Tensor]

Synchronously gets an edge_index tensor from the materialized graph.

Parameters

**attr (EdgeAttr) – the edge attributes.

Returns

an edge_index tensor corresponding to the provided attributes, or None if there is no such tensor.

Return type

EdgeTensorType

Raises

KeyError – if the edge index corresponding to attr was not found.

get_tensor(*args, **kwargs) Union[Tensor, ndarray]

Synchronously obtains a FeatureTensorType object from the feature store. Feature store implementors guarantee that the call get_tensor(put_tensor(tensor, attr), attr) = tensor holds.

Parameters

**attr (TensorAttr) – Any relevant tensor attributes that correspond to the feature tensor. See the TensorAttr documentation for required and optional attributes. It is the job of implementations of a FeatureStore to store this metadata in a meaningful way that allows for tensor retrieval from a TensorAttr object.

Returns

a Tensor of the same type as the index.

Return type

FeatureTensorType

Raises
  • KeyError – if the tensor corresponding to attr was not found.

  • ValueError – if the input TensorAttr is not fully specified.

get_tensor_size(*args, **kwargs) Tuple

Obtains the size of a tensor given its attributes, or None if the tensor does not exist.

has_isolated_nodes() bool

Returns True if the graph contains isolated nodes.

has_self_loops() bool

Returns True if the graph contains self-loops.

is_coalesced() bool

Returns True if edge indices edge_index are sorted and do not contain duplicate entries.

property is_cuda: bool

Returns True if any torch.Tensor attribute is stored on the GPU, False otherwise.

is_directed() bool

Returns True if graph edges are directed.

is_undirected() bool

Returns True if graph edges are undirected.

property keys: List[str]

Returns a list of all graph attribute names.

multi_get_tensor(attrs: List[TensorAttr]) List[Union[Tensor, ndarray]]

Synchronously obtains a FeatureTensorType object from the feature store for each tensor associated with the attributes in attrs.

Parameters

attrs (List[TensorAttr]) – a list of TensorAttr attributes that identify the tensors to get.

Returns

a Tensor of the same type as the index for each attribute.

Return type

List[FeatureTensorType]

Raises
  • KeyError – if a tensor corresponding to an attr was not found.

  • ValueError – if any input TensorAttr is not fully specified.

property num_edges: int

Returns the number of edges in the graph. For undirected graphs, this will return the number of bi-directional edges, which is twice the number of unique edges.

property num_nodes: Optional[int]

Returns the number of nodes in the graph.

Note

The number of nodes in the data object is automatically inferred in case node-level attributes are present, e.g., data.x. In some cases, however, a graph may only be given without any node-level attributes. PyG then guesses the number of nodes according to edge_index.max().item() + 1. However, in case there exist isolated nodes, this number is not necessarily correct, which can result in unexpected behaviour. Thus, we recommend setting the number of nodes in your data object explicitly via data.num_nodes = .... You will be given a warning that requests you to do so.
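
A minimal sketch of the recommended pattern, for a graph given only by its connectivity in which the trailing nodes are isolated:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1], [1, 2]])
data = Data(edge_index=edge_index)

# Only nodes 0-2 can be inferred from `edge_index`; declare the
# isolated nodes 3 and 4 explicitly:
data.num_nodes = 5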

pin_memory(*args: List[str])

Copies attributes to pinned memory, either for all attributes or only the ones given in *args.

put_edge_index(edge_index: Tuple[Tensor, Tensor], *args, **kwargs) bool

Synchronously adds an edge_index tensor to the graph store.

Parameters
  • tensor (EdgeTensorType) – an edge_index tensor in the format specified in attr.

  • **attr (EdgeAttr) – the edge attributes.

put_tensor(tensor: Union[Tensor, ndarray], *args, **kwargs) bool

Synchronously adds a FeatureTensorType object to the feature store.

Parameters
  • tensor (FeatureTensorType) – The feature tensor to be added.

  • **attr (TensorAttr) – Any relevant tensor attributes that correspond to the feature tensor. See the TensorAttr documentation for required and optional attributes. It is the job of implementations of a FeatureStore to store this metadata in a meaningful way that allows for tensor retrieval from a TensorAttr object.

Returns

Whether insertion was successful.

Return type

bool

Raises

ValueError – if the input TensorAttr is not fully specified.

record_stream(stream: Stream, *args: List[str])

Ensures that the tensor memory is not reused for another tensor until all current work queued on stream has been completed, either for all attributes or only the ones given in *args.

remove_tensor(*args, **kwargs) bool

Removes a FeatureTensorType object from the feature store.

Parameters

**attr (TensorAttr) – Any relevant tensor attributes that correspond to the feature tensor. See the TensorAttr documentation for required and optional attributes. It is the job of implementations of a FeatureStore to store this metadata in a meaningful way that allows for tensor retrieval from a TensorAttr object.

Returns

Whether deletion was successful.

Return type

bool

Raises

ValueError – if the input TensorAttr is not fully specified.

requires_grad_(*args: List[str], requires_grad: bool = True)

Tracks gradient computation, either for all attributes or only the ones given in *args.

share_memory_(*args: List[str])

Moves attributes to shared memory, either for all attributes or only the ones given in *args.

size(dim: Optional[int] = None) Optional[Union[Tuple[Optional[int], Optional[int]], int]]

Returns the size of the adjacency matrix induced by the graph.

to(device: Union[int, str], *args: List[str], non_blocking: bool = False)

Performs tensor device conversion, either for all attributes or only the ones given in *args.

update_tensor(tensor: Union[Tensor, ndarray], *args, **kwargs) bool

Updates a FeatureTensorType object with a new value. Implementor classes can choose to define more efficient update methods; the default performs a removal and insertion.

Parameters
  • tensor (FeatureTensorType) – The feature tensor to be updated.

  • **attr (TensorAttr) – Any relevant tensor attributes that correspond to the feature tensor. See the TensorAttr documentation for required and optional attributes. It is the job of implementations of a FeatureStore to store this metadata in a meaningful way that allows for tensor retrieval from a TensorAttr object.

Returns

Whether the update was successful.

Return type

bool

view(*args, **kwargs) AttrView

Returns an AttrView of the feature store, with the defined attributes set.

get_all_edge_attrs() List[EdgeAttr][source]

Returns a list of EdgeAttr objects corresponding to the edge indices stored in Data and their layouts.

class HeteroData(_mapping: Optional[Dict[str, Any]] = None, **kwargs)[source]

A data object describing a heterogeneous graph, holding multiple node and/or edge types in disjoint storage objects. Storage objects can hold either node-level, link-level or graph-level attributes. In general, HeteroData tries to mimic the behaviour of a regular nested Python dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities.

from torch_geometric.data import HeteroData

data = HeteroData()

# Create two node types "paper" and "author" holding a feature matrix:
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['author'].x = torch.randn(num_authors, num_authors_features)

# Create an edge type "(author, writes, paper)" and build the
# graph connectivity:
data['author', 'writes', 'paper'].edge_index = ...  # [2, num_edges]

data['paper'].num_nodes
>>> 23

data['author', 'writes', 'paper'].num_edges
>>> 52

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)

Note that there exist multiple ways to create a heterogeneous graph, e.g.:

  • To initialize a node of type "paper" holding a node feature matrix x_paper named x:

    from torch_geometric.data import HeteroData
    
    data = HeteroData()
    data['paper'].x = x_paper
    
    data = HeteroData(paper={ 'x': x_paper })
    
    data = HeteroData({'paper': { 'x': x_paper }})
    
  • To initialize an edge from source node type "author" to destination node type "paper" with relation type "writes" holding a graph connectivity matrix edge_index_author_paper named edge_index:

    data = HeteroData()
    data['author', 'writes', 'paper'].edge_index = edge_index_author_paper
    
    data = HeteroData(author__writes__paper={
        'edge_index': edge_index_author_paper
    })
    
    data = HeteroData({
        ('author', 'writes', 'paper'):
        { 'edge_index': edge_index_author_paper }
    })
    
property stores: List[BaseStorage]

Returns a list of all storages of the graph.

property node_types: List[str]

Returns a list of all node types of the graph.

property node_stores: List[NodeStorage]

Returns a list of all node storages of the graph.

property edge_types: List[Tuple[str, str, str]]

Returns a list of all edge types of the graph.

property edge_stores: List[EdgeStorage]

Returns a list of all edge storages of the graph.

node_items() List[Tuple[str, NodeStorage]][source]

Returns a list of node type and node storage pairs.

edge_items() List[Tuple[Tuple[str, str, str], EdgeStorage]][source]

Returns a list of edge type and edge storage pairs.

to_dict() Dict[str, Any][source]

Returns a dictionary of stored key/value pairs.

to_namedtuple() NamedTuple[source]

Returns a NamedTuple of stored key/value pairs.

__cat_dim__(key: str, value: Any, store: Optional[Union[NodeStorage, EdgeStorage]] = None, *args, **kwargs) Any[source]

Returns the dimension for which the value value of the attribute key will get concatenated when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.

__inc__(key: str, value: Any, store: Optional[Union[NodeStorage, EdgeStorage]] = None, *args, **kwargs) Any[source]

Returns the incremental count to cumulatively increase the value value of the attribute key when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.

property num_nodes: Optional[int]

Returns the number of nodes in the graph.

property num_node_features: Dict[str, int]

Returns the number of features per node type in the graph.

property num_features: Dict[str, int]

Returns the number of features per node type in the graph. Alias for num_node_features.

property num_edge_features: Dict[Tuple[str, str, str], int]

Returns the number of features per edge type in the graph.

has_isolated_nodes() bool[source]

Returns True if the graph contains isolated nodes.

is_undirected() bool[source]

Returns True if graph edges are undirected.

validate(raise_on_error: bool = True) bool[source]

Validates the correctness of the data.

metadata() Tuple[List[str], List[Tuple[str, str, str]]][source]

Returns the heterogeneous meta-data, i.e. its node and edge types.

data = HeteroData()
data['paper'].x = ...
data['author'].x = ...
data['author', 'writes', 'paper'].edge_index = ...

print(data.metadata())
>>> (['paper', 'author'], [('author', 'writes', 'paper')])
collect(key: str) Dict[Union[str, Tuple[str, str, str]], Any][source]

Collects the attribute key from all node and edge types.

data = HeteroData()
data['paper'].x = ...
data['author'].x = ...

print(data.collect('x'))
>>> { 'paper': ..., 'author': ...}

Note

This is equivalent to writing data.x_dict.

get_node_store(key: str) NodeStorage[source]

Gets the NodeStorage object of a particular node type key. If the storage is not present yet, will create a new torch_geometric.data.storage.NodeStorage object for the given node type.

data = HeteroData()
node_storage = data.get_node_store('paper')
get_edge_store(src: str, rel: str, dst: str) EdgeStorage[source]

Gets the EdgeStorage object of a particular edge type given by the tuple (src, rel, dst). If the storage is not present yet, will create a new torch_geometric.data.storage.EdgeStorage object for the given edge type.

data = HeteroData()
edge_storage = data.get_edge_store('author', 'writes', 'paper')
rename(name: str, new_name: str) HeteroData[source]

Renames the node type name to new_name in-place.

subgraph(subset_dict: Dict[str, Tensor]) HeteroData[source]

Returns the induced subgraph containing the node types and corresponding nodes in subset_dict.

data = HeteroData()
data['paper'].x = ...
data['author'].x = ...
data['conference'].x = ...
data['paper', 'cites', 'paper'].edge_index = ...
data['author', 'paper'].edge_index = ...
data['paper', 'conference'].edge_index = ...
print(data)
>>> HeteroData(
    paper={ x=[10, 16] },
    author={ x=[5, 32] },
    conference={ x=[5, 8] },
    (paper, cites, paper)={ edge_index=[2, 50] },
    (author, to, paper)={ edge_index=[2, 30] },
    (paper, to, conference)={ edge_index=[2, 25] }
)

subset_dict = {
    'paper': torch.tensor([3, 4, 5, 6]),
    'author': torch.tensor([0, 2]),
}

print(data.subgraph(subset_dict))
>>> HeteroData(
    paper={ x=[4, 16] },
    author={ x=[2, 32] },
    (paper, cites, paper)={ edge_index=[2, 24] },
    (author, to, paper)={ edge_index=[2, 5] }
)
Parameters

subset_dict (Dict[str, LongTensor or BoolTensor]) – A dictionary holding the nodes to keep for each node type.

node_type_subgraph(node_types: List[str]) HeteroData[source]

Returns the subgraph induced by the given node_types, i.e. the returned HeteroData object only contains the node types which are included in node_types, and only contains the edge types where both end points are included in node_types.

edge_type_subgraph(edge_types: List[Tuple[str, str, str]]) HeteroData[source]

Returns the subgraph induced by the given edge_types, i.e. the returned HeteroData object only contains the edge types which are included in edge_types, and only contains the node types which appear as end points of these edge types.
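
A usage sketch, assuming data is the HeteroData object from the subgraph() example above:

# Keep only 'paper' and 'author' nodes, and only the edge types whose
# source and destination node types are both kept:
sub_data = data.node_type_subgraph(['paper', 'author'])

# Keep only the ('paper', 'cites', 'paper') edges and the 'paper' nodes:
sub_data = data.edge_type_subgraph([('paper', 'cites', 'paper')])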

to_homogeneous(node_attrs: Optional[List[str]] = None, edge_attrs: Optional[List[str]] = None, add_node_type: bool = True, add_edge_type: bool = True, dummy_values: bool = True) Data[source]

Converts a HeteroData object to a homogeneous Data object. By default, all features with the same feature dimensionality across different types will be merged into a single representation, unless otherwise specified via the node_attrs and edge_attrs arguments. Furthermore, attributes named node_type and edge_type will be added to the returned Data object, denoting node-level and edge-level vectors holding the node and edge type as integers, respectively.

Parameters
  • node_attrs (List[str], optional) – The node features to combine across all node types. These node features need to be of the same feature dimensionality. If set to None, will automatically determine which node features to combine. (default: None)

  • edge_attrs (List[str], optional) – The edge features to combine across all edge types. These edge features need to be of the same feature dimensionality. If set to None, will automatically determine which edge features to combine. (default: None)

  • add_node_type (bool, optional) – If set to False, will not add the node-level vector node_type to the returned Data object. (default: True)

  • add_edge_type (bool, optional) – If set to False, will not add the edge-level vector edge_type to the returned Data object. (default: True)

  • dummy_values (bool, optional) – If set to True, will fill attributes of remaining types with dummy values. Dummy values are NaN for floating point attributes, and -1 for integers. (default: True)
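
A minimal usage sketch, assuming data is a HeteroData object whose node features share the same dimensionality across types:

homogeneous_data = data.to_homogeneous()

homogeneous_data.node_type  # Integer vector mapping each node to its original node type.
homogeneous_data.edge_type  # Integer vector mapping each edge to its original edge type.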

get_all_tensor_attrs() List[TensorAttr][source]

Obtains all tensor attributes stored in this feature store.

get_all_edge_attrs() List[EdgeAttr][source]

Returns a list of EdgeAttr objects corresponding to the edge indices stored in HeteroData and their layouts.

apply(func: Callable, *args: List[str])

Applies the function func, either to all attributes or only the ones given in *args.

apply_(func: Callable, *args: List[str])

Applies the in-place function func, either to all attributes or only the ones given in *args.

clone(*args: List[str])

Performs cloning of tensors, either for all attributes or only the ones given in *args.

coalesce()

Sorts and removes duplicated entries from edge indices edge_index.

contiguous(*args: List[str])

Ensures a contiguous memory layout, either for all attributes or only the ones given in *args.

coo(edge_types: Optional[List[Any]] = None, store: bool = False) Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, Optional[Tensor]]]

Converts the edge indices in the graph store to COO format, optionally storing the converted edge indices in the graph store.

cpu(*args: List[str])

Copies attributes to CPU memory, either for all attributes or only the ones given in *args.

csc(edge_types: Optional[List[Any]] = None, store: bool = False) Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, Optional[Tensor]]]

Converts the edge indices in the graph store to CSC format, optionally storing the converted edge indices in the graph store.

csr(edge_types: Optional[List[Any]] = None, store: bool = False) Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, Optional[Tensor]]]

Converts the edge indices in the graph store to CSR format, optionally storing the converted edge indices in the graph store.

cuda(device: Optional[Union[int, str]] = None, *args: List[str], non_blocking: bool = False)

Copies attributes to CUDA memory, either for all attributes or only the ones given in *args.

detach(*args: List[str])

Detaches attributes from the computation graph by creating a new tensor, either for all attributes or only the ones given in *args.

detach_(*args: List[str])

Detaches attributes from the computation graph, either for all attributes or only the ones given in *args.

get_edge_index(*args, **kwargs) Tuple[Tensor, Tensor]

Synchronously gets an edge_index tensor from the materialized graph.

Parameters

**attr (EdgeAttr) – the edge attributes.

Returns

an edge_index tensor corresponding to the provided attributes, or None if there is no such tensor.

Return type

EdgeTensorType

Raises

KeyError – if the edge index corresponding to attr was not found.

get_tensor(*args, **kwargs) Union[Tensor, ndarray]

Synchronously obtains a FeatureTensorType object from the feature store. Feature store implementors guarantee that the call get_tensor(put_tensor(tensor, attr), attr) = tensor holds.

Parameters

**attr (TensorAttr) – Any relevant tensor attributes that correspond to the feature tensor. See the TensorAttr documentation for required and optional attributes. It is the job of implementations of a FeatureStore to store this metadata in a meaningful way that allows for tensor retrieval from a TensorAttr object.

Returns

a Tensor of the same type as the index.

Return type

FeatureTensorType

Raises
  • KeyError – if the tensor corresponding to attr was not found.

  • ValueError – if the input TensorAttr is not fully specified.

get_tensor_size(*args, **kwargs) Tuple

Obtains the size of a tensor given its attributes, or None if the tensor does not exist.

has_self_loops() bool

Returns True if the graph contains self-loops.

is_coalesced() bool

Returns True if edge indices edge_index are sorted and do not contain duplicate entries.

property is_cuda: bool

Returns True if any torch.Tensor attribute is stored on the GPU, False otherwise.

is_directed() bool

Returns True if graph edges are directed.

property keys: List[str]

Returns a list of all graph attribute names.

multi_get_tensor(attrs: List[TensorAttr]) List[Union[Tensor, ndarray]]

Synchronously obtains a FeatureTensorType object from the feature store for each tensor associated with the attributes in attrs.

Parameters

attrs (List[TensorAttr]) – a list of TensorAttr attributes that identify the tensors to get.

Returns

a Tensor of the same type as the index for each attribute.

Return type

List[FeatureTensorType]

Raises
  • KeyError – if a tensor corresponding to an attr was not found.

  • ValueError – if any input TensorAttr is not fully specified.

property num_edges: int

Returns the number of edges in the graph. For undirected graphs, this will return the number of bi-directional edges, which is twice the number of unique edges.

pin_memory(*args: List[str])

Copies attributes to pinned memory, either for all attributes or only the ones given in *args.

put_edge_index(edge_index: Tuple[Tensor, Tensor], *args, **kwargs) bool

Synchronously adds an edge_index tensor to the graph store.

Parameters
  • tensor (EdgeTensorType) – an edge_index tensor in the format specified in attr.

  • **attr (EdgeAttr) – the edge attributes.

put_tensor(tensor: Union[Tensor, ndarray], *args, **kwargs) bool

Synchronously adds a FeatureTensorType object to the feature store.

Parameters
  • tensor (FeatureTensorType) – The feature tensor to be added.

  • **attr (TensorAttr) – Any relevant tensor attributes that correspond to the feature tensor. See the TensorAttr documentation for required and optional attributes. It is the job of implementations of a FeatureStore to store this metadata in a meaningful way that allows for tensor retrieval from a TensorAttr object.

Returns

Whether insertion was successful.

Return type

bool

Raises

ValueError – if the input TensorAttr is not fully specified.

record_stream(stream: Stream, *args: List[str])

Ensures that the tensor memory is not reused for another tensor until all current work queued on stream has been completed, either for all attributes or only the ones given in *args.

remove_tensor(*args, **kwargs) bool

Removes a FeatureTensorType object from the feature store.

Parameters

**attr (TensorAttr) – Any relevant tensor attributes that correspond to the feature tensor. See the TensorAttr documentation for required and optional attributes. It is the job of implementations of a FeatureStore to store this metadata in a meaningful way that allows for tensor retrieval from a TensorAttr object.

Returns

Whether deletion was successful.

Return type

bool

Raises

ValueError – if the input TensorAttr is not fully specified.

requires_grad_(*args: List[str], requires_grad: bool = True)

Tracks gradient computation, either for all attributes or only the ones given in *args.

share_memory_(*args: List[str])

Moves attributes to shared memory, either for all attributes or only the ones given in *args.

size(dim: Optional[int] = None) Optional[Union[Tuple[Optional[int], Optional[int]], int]]

Returns the size of the adjacency matrix induced by the graph.

to(device: Union[int, str], *args: List[str], non_blocking: bool = False)

Performs tensor device conversion, either for all attributes or only the ones given in *args.

update_tensor(tensor: Union[Tensor, ndarray], *args, **kwargs) bool

Updates a FeatureTensorType object with a new value. Implementor classes can choose to define more efficient update methods; the default performs a removal and insertion.

Parameters
  • tensor (FeatureTensorType) – The feature tensor to be updated.

  • **attr (TensorAttr) – Any relevant tensor attributes that correspond to the feature tensor. See the TensorAttr documentation for required and optional attributes. It is the job of implementations of a FeatureStore to store this metadata in a meaningful way that allows for tensor retrieval from a TensorAttr object.

Returns

Whether the update was successful.

Return type

bool

view(*args, **kwargs) AttrView

Returns an AttrView of the feature store, with the defined attributes set.

class TemporalData(src: Optional[Tensor] = None, dst: Optional[Tensor] = None, t: Optional[Tensor] = None, msg: Optional[Tensor] = None, **kwargs)[source]

A data object composed of a stream of events describing a temporal graph. The TemporalData object can hold a list of events (that can be understood as temporal edges in a graph) with structured messages. An event is composed of a source node, a destination node, a timestamp and a message. Any Continuous-Time Dynamic Graph (CTDG) can be represented with these four values.

In general, TemporalData tries to mimic the behaviour of a regular Python dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities.

from torch import Tensor
from torch_geometric.data import TemporalData

events = TemporalData(
    src=Tensor([1,2,3,4]),
    dst=Tensor([2,3,4,5]),
    t=Tensor([1000,1010,1100,2000]),
    msg=Tensor([1,1,0,0])
)

# Add additional arguments to `events`:
events.y = Tensor([1,1,0,0])

# It is also possible to set additional arguments in the constructor
events = TemporalData(
    ...,
    y=Tensor([1,1,0,0])
)

# Get the number of events:
events.num_events
>>> 4

# Analyzing the graph structure:
events.num_nodes
>>> 5

# PyTorch tensor functionality:
events = events.pin_memory()
events = events.to('cuda:0', non_blocking=True)
Parameters
  • src (Tensor, optional) – A list of source nodes for the events with shape [num_events]. (default: None)

  • dst (Tensor, optional) – A list of destination nodes for the events with shape [num_events]. (default: None)

  • t (Tensor, optional) – The timestamps for each event with shape [num_events]. (default: None)

  • msg (Tensor, optional) – Messages feature matrix with shape [num_events, num_msg_features]. (default: None)

  • **kwargs (optional) – Additional attributes.

Note

The shapes of src, dst and t and the first dimension of msg should all match (i.e., num_events).

to_dict() Dict[str, Any][source]

Returns a dictionary of stored key/value pairs.

to_namedtuple() NamedTuple[source]

Returns a NamedTuple of stored key/value pairs.

property num_nodes: int

Returns the number of nodes in the graph.

property num_events: int

Returns the number of events loaded.

Note

In a TemporalData, each row denotes an event. Thus, they can also be understood as edges.

property num_edges: int

Alias for num_events().

size(dim: Optional[int] = None) Optional[Union[Tuple[Optional[int], Optional[int]], int]][source]

Returns the size of the adjacency matrix induced by the graph.

__cat_dim__(key: str, value: Any, *args, **kwargs) Any[source]

Returns the dimension for which the value value of the attribute key will get concatenated when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.

__inc__(key: str, value: Any, *args, **kwargs) Any[source]

Returns the incremental count to cumulatively increase the value value of the attribute key when creating mini-batches using torch_geometric.loader.DataLoader.

Note

This method is for internal use only, and should only be overridden in case the mini-batch creation process is corrupted for a specific attribute.

train_val_test_split(val_ratio: float = 0.15, test_ratio: float = 0.15)[source]

Splits the data into training, validation and test sets according to time.

Parameters
  • val_ratio (float, optional) – The proportion of the dataset to include in the validation split. (default: 0.15)

  • test_ratio (float, optional) – The proportion of the dataset to include in the test split. (default: 0.15)
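
A usage sketch, assuming events is the TemporalData object from above; the split is chronological:

train_data, val_data, test_data = events.train_val_test_split(
    val_ratio=0.15, test_ratio=0.15)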

apply(func: Callable, *args: List[str])

Applies the function func, either to all attributes or only the ones given in *args.

apply_(func: Callable, *args: List[str])

Applies the in-place function func, either to all attributes or only the ones given in *args.

clone(*args: List[str])

Performs cloning of tensors, either for all attributes or only the ones given in *args.

contiguous(*args: List[str])

Ensures a contiguous memory layout, either for all attributes or only the ones given in *args.

cpu(*args: List[str])

Copies attributes to CPU memory, either for all attributes or only the ones given in *args.

cuda(device: Optional[Union[int, str]] = None, *args: List[str], non_blocking: bool = False)

Copies attributes to CUDA memory, either for all attributes or only the ones given in *args.

detach(*args: List[str])

Detaches attributes from the computation graph by creating a new tensor, either for all attributes or only the ones given in *args.

detach_(*args: List[str])

Detaches attributes from the computation graph, either for all attributes or only the ones given in *args.

is_coalesced() bool

Returns True if edge indices edge_index are sorted and do not contain duplicate entries.

property is_cuda: bool

Returns True if any torch.Tensor attribute is stored on the GPU, False otherwise.

property keys: List[str]

Returns a list of all graph attribute names.

pin_memory(*args: List[str])

Copies attributes to pinned memory, either for all attributes or only the ones given in *args.

record_stream(stream: Stream, *args: List[str])

Ensures that the tensor memory is not reused for another tensor until all current work queued on stream has been completed, either for all attributes or only the ones given in *args.

requires_grad_(*args: List[str], requires_grad: bool = True)

Tracks gradient computation, either for all attributes or only the ones given in *args.

share_memory_(*args: List[str])

Moves attributes to shared memory, either for all attributes or only the ones given in *args.

to(device: Union[int, str], *args: List[str], non_blocking: bool = False)

Performs tensor device conversion, either for all attributes or only the ones given in *args.

class Batch(*args, **kwargs)[source]

A data object describing a batch of graphs as one big (disconnected) graph. Inherits from torch_geometric.data.Data or torch_geometric.data.HeteroData. In addition, single graphs can be identified via the assignment vector batch, which maps each node to its respective graph identifier.

classmethod from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None)[source]

Constructs a Batch object from a Python list of Data or HeteroData objects. The assignment vector batch is created on the fly. In addition, creates assignment vectors for each key in follow_batch. Will exclude any keys given in exclude_keys.
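
A minimal sketch of collating two small graphs:

import torch
from torch_geometric.data import Batch, Data

data1 = Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]]))
data2 = Data(x=torch.randn(2, 16), edge_index=torch.tensor([[0], [1]]))

batch = Batch.from_data_list([data1, data2])

batch.num_graphs
>>> 2

batch.batch  # Node-to-graph assignment vector.
>>> tensor([0, 0, 0, 1, 1])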

get_example(idx: int) BaseData[source]

Gets the Data or HeteroData object at index idx. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial object.

index_select(idx: Union[slice, Tensor, ndarray, Sequence]) List[BaseData][source]

Creates a subset of Data or HeteroData objects from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

to_data_list() List[BaseData][source]

Reconstructs the list of Data or HeteroData objects from the Batch object. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

property num_graphs: int

Returns the number of graphs in the batch.

class Dataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, log: bool = True)[source]

Dataset base class for creating graph datasets. See here for the accompanying tutorial.

Parameters
  • root (string, optional) – Root directory where the dataset should be saved. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)

  • pre_filter (callable, optional) – A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

  • log (bool, optional) – Whether to print any console output while downloading and processing the dataset. (default: True)
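
A hedged sketch of a typical subclass, following the dataset creation tutorial; the file names, URL, and parsing logic are placeholders:

import os.path as osp

import torch
from torch_geometric.data import Data, Dataset, download_url

class MyOwnDataset(Dataset):
    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2']

    @property
    def processed_file_names(self):
        return ['data_0.pt', 'data_1.pt']

    def download(self):
        # Download the raw files to `self.raw_dir` (placeholder URL):
        download_url('https://example.com/some_file_1', self.raw_dir)

    def process(self):
        for idx, raw_path in enumerate(self.raw_paths):
            data = Data()  # Placeholder: parse `raw_path` into a `Data` object.
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        return torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))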

property raw_file_names: Union[str, List[str], Tuple]

The names of the files in the self.raw_dir folder that must be present in order to skip downloading.

property processed_file_names: Union[str, List[str], Tuple]

The names of the files in the self.processed_dir folder that must be present in order to skip processing.

download()[source]

Downloads the dataset to the self.raw_dir folder.

process()[source]

Processes the dataset to the self.processed_dir folder.

len() int[source]

Returns the number of graphs stored in the dataset.

get(idx: int) Data[source]

Gets the data object at index idx.

property num_node_features: int

Returns the number of features per node in the dataset.

property num_features: int

Returns the number of features per node in the dataset. Alias for num_node_features.

property num_edge_features: int

Returns the number of features per edge in the dataset.

property num_classes: int

Returns the number of classes in the dataset.

property raw_paths: List[str]

The absolute filepaths that must be present in order to skip downloading.

property processed_paths: List[str]

The absolute filepaths that must be present in order to skip processing.

index_select(idx: Union[slice, Tensor, ndarray, Sequence]) Dataset[source]

Creates a subset of the dataset from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool.

shuffle(return_perm: bool = False) Union[Dataset, Tuple[Dataset, Tensor]][source]

Randomly shuffles the examples in the dataset.

Parameters

return_perm (bool, optional) – If set to True, will also return the random permutation used to shuffle the dataset. (default: False)
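
A minimal usage sketch:

dataset = dataset.shuffle()

# Or, to also obtain the permutation that was applied:
dataset, perm = dataset.shuffle(return_perm=True)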

get_summary()[source]

Collects summary statistics for the dataset.

print_summary()[source]

Prints summary statistics of the dataset to the console.

class InMemoryDataset(root: Optional[str] = None, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, log: bool = True)[source]

Dataset base class for creating graph datasets which easily fit into CPU memory. Inherits from torch_geometric.data.Dataset. See here for the accompanying tutorial.

Parameters
  • root (string, optional) – Root directory where the dataset should be saved. (default: None)

  • transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_transform (callable, optional) – A function/transform that takes in a torch_geometric.data.Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)

  • pre_filter (callable, optional) – A function that takes in a torch_geometric.data.Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

  • log (bool, optional) – Whether to print any console output while downloading and processing the dataset. (default: True)
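
A hedged sketch of a typical subclass; building data_list from the raw files is left as a placeholder:

import torch
from torch_geometric.data import InMemoryDataset

class MyOwnInMemoryDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        data_list = [...]  # Placeholder: build a list of `Data` objects from `self.raw_paths`.
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])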

property raw_file_names: Union[str, List[str], Tuple]

The names of the files in the self.raw_dir folder that must be present in order to skip downloading.

property processed_file_names: Union[str, List[str], Tuple]

The names of the files in the self.processed_dir folder that must be present in order to skip processing.

property num_classes: int

Returns the number of classes in the dataset.

len() int[source]

Returns the number of graphs stored in the dataset.

get(idx: int) Data[source]

Gets the data object at index idx.

static collate(data_list: List[Data]) Tuple[Data, Optional[Dict[str, Tensor]]][source]

Collates a Python list of torch_geometric.data.Data objects to the internal storage format of InMemoryDataset.

copy(idx: Optional[Union[slice, Tensor, ndarray, Sequence]] = None) InMemoryDataset[source]

Performs a deep-copy of the dataset. If idx is not given, will clone the full dataset. Otherwise, will only clone a subset of the dataset from indices idx. Indices can be slices, lists, tuples, or a torch.Tensor or np.ndarray of type long or bool.

class LightningDataset(train_dataset: Dataset, val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, batch_size: int = 1, num_workers: int = 0, **kwargs)[source]

Converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU graph-level training via PyTorch Lightning. LightningDataset will take care of providing mini-batches via DataLoader.

Note

Currently only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPSpawnStrategy training strategies of PyTorch Lightning are supported in order to correctly share data across all devices/processes:

import pytorch_lightning as pl
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)
Parameters
  • train_dataset (Dataset) – The training dataset.

  • val_dataset (Dataset, optional) – The validation dataset. (default: None)

  • test_dataset (Dataset, optional) – The test dataset. (default: None)

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • num_workers – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • **kwargs (optional) – Additional arguments of torch_geometric.loader.DataLoader.
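
A usage sketch, assuming train_dataset and val_dataset are existing graph-level Dataset objects and model is a pytorch_lightning.LightningModule:

import pytorch_lightning as pl
from torch_geometric.data import LightningDataset

datamodule = LightningDataset(train_dataset, val_dataset,
                              batch_size=32, num_workers=4)
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4)
trainer.fit(model, datamodule)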

class LightningNodeData(data: Union[Data, HeteroData], input_train_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_val_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_test_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_pred_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, loader: str = 'neighbor', node_sampler: Optional[BaseSampler] = None, batch_size: int = 1, num_workers: int = 0, **kwargs)[source]

Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU node-level training via PyTorch Lightning. LightningNodeData will take care of providing mini-batches via NeighborLoader.

Note

Currently only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPSpawnStrategy training strategies of PyTorch Lightning are supported in order to correctly share data across all devices/processes:

import pytorch_lightning as pl
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)
Parameters
  • data (Data or HeteroData) – The Data or HeteroData graph object.

  • input_train_nodes (torch.Tensor or str or (str, torch.Tensor)) – The indices of training nodes. If not given, will try to automatically infer them from the data object by searching for train_mask, train_idx, or train_index attributes. (default: None)

  • input_val_nodes (torch.Tensor or str or (str, torch.Tensor)) – The indices of validation nodes. If not given, will try to automatically infer them from the data object by searching for val_mask, valid_mask, val_idx, valid_idx, val_index, or valid_index attributes. (default: None)

  • input_test_nodes (torch.Tensor or str or (str, torch.Tensor)) – The indices of test nodes. If not given, will try to automatically infer them from the data object by searching for test_mask, test_idx, or test_index attributes. (default: None)

  • input_pred_nodes (torch.Tensor or str or (str, torch.Tensor)) – The indices of prediction nodes. If not given, will try to automatically infer them from the data object by searching for pred_mask, pred_idx, or pred_index attributes. (default: None)

  • loader (str) – The scalability technique to use ("full", "neighbor"). (default: "neighbor")

  • node_sampler (BaseSampler, optional) – A custom sampler object to generate mini-batches. If set, will ignore the loader option. (default: None)

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • num_workers – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • **kwargs (optional) – Additional arguments of torch_geometric.loader.NeighborLoader.
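
A usage sketch, assuming data carries train_mask/val_mask attributes and model is a pytorch_lightning.LightningModule; num_neighbors is forwarded to NeighborLoader:

import pytorch_lightning as pl
from torch_geometric.data import LightningNodeData

datamodule = LightningNodeData(
    data,
    input_train_nodes=data.train_mask,
    input_val_nodes=data.val_mask,
    loader='neighbor',
    num_neighbors=[25, 10],
    batch_size=1024,
    num_workers=4,
)
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4)
trainer.fit(model, datamodule)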

class LightningLinkData(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], input_train_edges: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, input_train_labels: Optional[Tensor] = None, input_train_time: Optional[Tensor] = None, input_val_edges: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, input_val_labels: Optional[Tensor] = None, input_val_time: Optional[Tensor] = None, input_test_edges: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, input_test_labels: Optional[Tensor] = None, input_test_time: Optional[Tensor] = None, loader: str = 'neighbor', link_sampler: Optional[BaseSampler] = None, batch_size: int = 1, num_workers: int = 0, **kwargs)[source]

Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU link-level training (such as for link prediction) via PyTorch Lightning. LightningLinkData will take care of providing mini-batches via LinkNeighborLoader.

Note

Currently only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPSpawnStrategy training strategies of PyTorch Lightning are supported in order to correctly share data across all devices/processes:

import pytorch_lightning as pl
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)
Parameters
  • data (Data or HeteroData or Tuple[FeatureStore, GraphStore]) – The Data or HeteroData graph object, or a tuple of a FeatureStore and GraphStore objects.

  • input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]) – The training edges. (default: None)

  • input_train_labels (Tensor, optional) – The labels of train edges. (default: None)

  • input_train_time (Tensor, optional) – The timestamp of train edges. (default: None)

  • input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]) – The validation edges. (default: None)

  • input_val_labels (Tensor, optional) – The labels of validation edges. (default: None)

  • input_val_time (Tensor, optional) – The timestamp of validation edges. (default: None)

  • input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]) – The test edges. (default: None)

  • input_test_labels (Tensor, optional) – The labels of test edges. (default: None)

  • input_test_time (Tensor, optional) – The timestamp of test edges. (default: None)

  • loader (str) – The scalability technique to use ("full", "neighbor"). (default: "neighbor")

  • link_sampler (BaseSampler, optional) – A custom sampler object to generate mini-batches. If set, will ignore the loader option. (default: None)

  • batch_size (int, optional) – How many samples per batch to load. (default: 1)

  • num_workers – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • **kwargs (optional) – Additional arguments of torch_geometric.loader.LinkNeighborLoader.

makedirs(path: str)[source]

Recursive directory creation function.

download_url(url: str, folder: str, log: bool = True, filename: Optional[str] = None)[source]

Downloads the content of a URL to a specific folder.

Parameters
  • url (string) – The URL.

  • folder (string) – The folder.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)
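
A usage sketch inside a custom dataset's download() method (the URL is a placeholder); download_url() returns the path of the downloaded file, which can then be handed to one of the extraction helpers below:

from torch_geometric.data import download_url, extract_zip

def download(self):
    path = download_url('https://example.com/archive.zip', self.raw_dir)
    extract_zip(path, self.raw_dir)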

extract_tar(path: str, folder: str, mode: str = 'r:gz', log: bool = True)[source]

Extracts a tar archive to a specific folder.

Parameters
  • path (string) – The path to the tar archive.

  • folder (string) – The folder.

  • mode (string, optional) – The compression mode. (default: "r:gz")

  • log (bool, optional) – If False, will not print anything to the console. (default: True)

extract_zip(path: str, folder: str, log: bool = True)[source]

Extracts a zip archive to a specific folder.

Parameters
  • path (string) – The path to the zip archive.

  • folder (string) – The folder.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)

extract_bz2(path: str, folder: str, log: bool = True)[source]

Extracts a bz2 archive to a specific folder.

Parameters
  • path (string) – The path to the bz2 archive.

  • folder (string) – The folder.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)

extract_gz(path: str, folder: str, log: bool = True)[source]

Extracts a gz archive to a specific folder.

Parameters
  • path (string) – The path to the gz archive.

  • folder (string) – The folder.

  • log (bool, optional) – If False, will not print anything to the console. (default: True)