import copy
from collections.abc import Mapping, Sequence
from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
Tuple, Union)
import numpy as np
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch_geometric.data.storage import (BaseStorage, EdgeStorage,
GlobalStorage, NodeStorage)
from torch_geometric.deprecation import deprecated
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.utils import subgraph
class BaseData(object):
def __getattr__(self, key: str) -> Any:
raise NotImplementedError
def __setattr__(self, key: str, value: Any):
raise NotImplementedError
def __delattr__(self, key: str):
raise NotImplementedError
def __getitem__(self, key: str) -> Any:
raise NotImplementedError
def __setitem__(self, key: str, value: Any):
raise NotImplementedError
def __delitem__(self, key: str):
raise NotImplementedError
def __copy__(self):
raise NotImplementedError
def __deepcopy__(self, memo):
raise NotImplementedError
def __repr__(self) -> str:
raise NotImplementedError
def stores_as(self, data: 'BaseData'):
raise NotImplementedError
@property
def stores(self) -> List[BaseStorage]:
raise NotImplementedError
@property
def node_stores(self) -> List[NodeStorage]:
raise NotImplementedError
@property
def edge_stores(self) -> List[EdgeStorage]:
raise NotImplementedError
def to_dict(self) -> Dict[str, Any]:
r"""Returns a dictionary of stored key/value pairs."""
raise NotImplementedError
def to_namedtuple(self) -> NamedTuple:
r"""Returns a :obj:`NamedTuple` of stored key/value pairs."""
raise NotImplementedError
def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
r"""Returns the dimension for which the value :obj:`value` of the
attribute :obj:`key` will get concatenated when creating mini-batches
using :class:`torch_geometric.loader.DataLoader`.
.. note::
This method is for internal use only, and should only be overridden
in case the mini-batch creation process is corrupted for a specific
attribute.
"""
raise NotImplementedError
def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
r"""Returns the incremental count to cumulatively increase the value
:obj:`value` of the attribute :obj:`key` when creating mini-batches
using :class:`torch_geometric.loader.DataLoader`.
.. note::
This method is for internal use only, and should only be overridden
in case the mini-batch creation process is corrupted for a specific
attribute.
"""
raise NotImplementedError
def debug(self):
raise NotImplementedError
###########################################################################
@property
def keys(self) -> List[str]:
r"""Returns a list of all graph attribute names."""
out = []
for store in self.stores:
out += list(store.keys())
return list(set(out))
def __len__(self) -> int:
r"""Returns the number of graph attributes."""
return len(self.keys)
def __contains__(self, key: str) -> bool:
r"""Returns :obj:`True` if the attribute :obj:`key` is present in the
data."""
return key in self.keys
def __getstate__(self) -> Dict[str, Any]:
return self.__dict__
def __setstate__(self, mapping: Dict[str, Any]):
for key, value in mapping.items():
self.__dict__[key] = value
@property
def num_nodes(self) -> Optional[int]:
r"""Returns the number of nodes in the graph.
.. note::
The number of nodes in the data object is automatically inferred
in case node-level attributes are present, *e.g.*, :obj:`data.x`.
In some cases, however, a graph may only be given without any
node-level attributes.
PyG then *guesses* the number of nodes according to
:obj:`edge_index.max().item() + 1`.
However, in case there exists isolated nodes, this number does not
have to be correct which can result in unexpected behaviour.
Thus, we recommend to set the number of nodes in your data object
explicitly via :obj:`data.num_nodes = ...`.
You will be given a warning that requests you to do so.
"""
try:
return sum([v.num_nodes for v in self.node_stores])
except TypeError:
return None
def size(
self, dim: Optional[int] = None
) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]:
r"""Returns the size of the adjacency matrix induced by the graph."""
size = (self.num_nodes, self.num_nodes)
return size if dim is None else size[dim]
@property
def num_edges(self) -> int:
r"""Returns the number of edges in the graph.
For undirected graphs, this will return the number of bi-directional
edges, which is double the amount of unique edges."""
return sum([v.num_edges for v in self.edge_stores])
def is_coalesced(self) -> bool:
r"""Returns :obj:`True` if edge indices :obj:`edge_index` are sorted
and do not contain duplicate entries."""
return all([store.is_coalesced() for store in self.edge_stores])
def coalesce(self):
r"""Sorts and removes duplicated entries from edge indices
:obj:`edge_index`."""
for store in self.edge_stores:
store.coalesce()
return self
def has_isolated_nodes(self) -> bool:
r"""Returns :obj:`True` if the graph contains isolated nodes."""
return any([store.has_isolated_nodes() for store in self.edge_stores])
def has_self_loops(self) -> bool:
"""Returns :obj:`True` if the graph contains self-loops."""
return any([store.has_self_loops() for store in self.edge_stores])
def is_undirected(self) -> bool:
r"""Returns :obj:`True` if graph edges are undirected."""
return all([store.is_undirected() for store in self.edge_stores])
def is_directed(self) -> bool:
r"""Returns :obj:`True` if graph edges are directed."""
return not self.is_undirected()
def apply_(self, func: Callable, *args: List[str]):
r"""Applies the in-place function :obj:`func`, either to all attributes
or only the ones given in :obj:`*args`."""
for store in self.stores:
store.apply_(func, *args)
return self
def apply(self, func: Callable, *args: List[str]):
r"""Applies the function :obj:`func`, either to all attributes or only
the ones given in :obj:`*args`."""
for store in self.stores:
store.apply(func, *args)
return self
def clone(self, *args: List[str]):
r"""Performs cloning of tensors, either for all attributes or only the
ones given in :obj:`*args`."""
return copy.copy(self).apply(lambda x: x.clone(), *args)
def contiguous(self, *args: List[str]):
r"""Ensures a contiguous memory layout, either for all attributes or
only the ones given in :obj:`*args`."""
return self.apply(lambda x: x.contiguous(), *args)
def to(self, device: Union[int, str], *args: List[str],
non_blocking: bool = False):
r"""Performs tensor device conversion, either for all attributes or
only the ones given in :obj:`*args`."""
return self.apply(
lambda x: x.to(device=device, non_blocking=non_blocking), *args)
def cpu(self, *args: List[str]):
r"""Copies attributes to CPU memory, either for all attributes or only
the ones given in :obj:`*args`."""
return self.apply(lambda x: x.cpu(), *args)
def cuda(self, device: Optional[Union[int, str]] = None, *args: List[str],
non_blocking: bool = False):
r"""Copies attributes to CUDA memory, either for all attributes or only
the ones given in :obj:`*args`."""
# Some PyTorch tensor like objects require a default value for `cuda`:
device = 'cuda' if device is None else device
return self.apply(lambda x: x.cuda(device, non_blocking=non_blocking),
*args)
def pin_memory(self, *args: List[str]):
r"""Copies attributes to pinned memory, either for all attributes or
only the ones given in :obj:`*args`."""
return self.apply(lambda x: x.pin_memory(), *args)
def share_memory_(self, *args: List[str]):
r"""Moves attributes to shared memory, either for all attributes or
only the ones given in :obj:`*args`."""
return self.apply_(lambda x: x.share_memory_(), *args)
def detach_(self, *args: List[str]):
r"""Detaches attributes from the computation graph, either for all
attributes or only the ones given in :obj:`*args`."""
return self.apply_(lambda x: x.detach_(), *args)
def detach(self, *args: List[str]):
r"""Detaches attributes from the computation graph by creating a new
tensor, either for all attributes or only the ones given in
:obj:`*args`."""
return self.apply(lambda x: x.detach(), *args)
def requires_grad_(self, *args: List[str], requires_grad: bool = True):
r"""Tracks gradient computation, either for all attributes or only the
ones given in :obj:`*args`."""
return self.apply_(
lambda x: x.requires_grad_(requires_grad=requires_grad), *args)
def record_stream(self, stream: torch.cuda.Stream, *args: List[str]):
r"""Ensures that the tensor memory is not reused for another tensor
until all current work queued on :obj:`stream` has been completed,
either for all attributes or only the ones given in :obj:`*args`."""
return self.apply_(lambda x: x.record_stream(stream), *args)
@property
def is_cuda(self) -> bool:
r"""Returns :obj:`True` if any :class:`torch.Tensor` attribute is
stored on the GPU, :obj:`False` otherwise."""
for store in self.stores:
for value in store.values():
if isinstance(value, Tensor) and value.is_cuda:
return True
return False
# Deprecated functions ####################################################
@deprecated(details="use 'has_isolated_nodes' instead")
def contains_isolated_nodes(self) -> bool:
return self.has_isolated_nodes()
@deprecated(details="use 'has_self_loops' instead")
def contains_self_loops(self) -> bool:
return self.has_self_loops()
###############################################################################
[docs]class Data(BaseData):
r"""A data object describing a homogeneous graph.
The data object can hold node-level, link-level and graph-level attributes.
In general, :class:`~torch_geometric.data.Data` tries to mimic the
behaviour of a regular Python dictionary.
In addition, it provides useful functionality for analyzing graph
structures, and provides basic PyTorch tensor functionalities.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
introduction.html#data-handling-of-graphs>`__ for the accompanying
tutorial.
.. code-block:: python
from torch_geometric.data import Data
data = Data(x=x, edge_index=edge_index, ...)
# Add additional arguments to `data`:
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)
# Analyzing the graph structure:
data.num_nodes
>>> 23
data.is_directed()
>>> False
# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)
Args:
x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
num_node_features]`. (default: :obj:`None`)
edge_index (LongTensor, optional): Graph connectivity in COO format
with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
edge_attr (Tensor, optional): Edge feature matrix with shape
:obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
y (Tensor, optional): Graph-level or node-level ground-truth labels
with arbitrary shape. (default: :obj:`None`)
pos (Tensor, optional): Node position matrix with shape
:obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
**kwargs (optional): Additional attributes.
"""
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
edge_attr: OptTensor = None, y: OptTensor = None,
pos: OptTensor = None, **kwargs):
super().__init__()
self.__dict__['_store'] = GlobalStorage(_parent=self)
if x is not None:
self.x = x
if edge_index is not None:
self.edge_index = edge_index
if edge_attr is not None:
self.edge_attr = edge_attr
if y is not None:
self.y = y
if pos is not None:
self.pos = pos
for key, value in kwargs.items():
setattr(self, key, value)
def __getattr__(self, key: str) -> Any:
if '_store' not in self.__dict__:
raise RuntimeError(
"The 'data' object was created by an older version of PyG. "
"If this error occurred while loading an already existing "
"dataset, remove the 'processed/' directory in the dataset's "
"root folder and try again.")
return getattr(self._store, key)
def __setattr__(self, key: str, value: Any):
setattr(self._store, key, value)
def __delattr__(self, key: str):
delattr(self._store, key)
def __getitem__(self, key: str) -> Any:
return self._store[key]
def __setitem__(self, key: str, value: Any):
self._store[key] = value
def __delitem__(self, key: str):
if key in self._store:
del self._store[key]
def __copy__(self):
out = self.__class__.__new__(self.__class__)
for key, value in self.__dict__.items():
out.__dict__[key] = value
out.__dict__['_store'] = copy.copy(self._store)
out._store._parent = out
return out
def __deepcopy__(self, memo):
out = self.__class__.__new__(self.__class__)
for key, value in self.__dict__.items():
out.__dict__[key] = copy.deepcopy(value, memo)
out._store._parent = out
return out
def __repr__(self) -> str:
cls = self.__class__.__name__
has_dict = any([isinstance(v, Mapping) for v in self._store.values()])
if not has_dict:
info = [size_repr(k, v) for k, v in self._store.items()]
info = ', '.join(info)
return f'{cls}({info})'
else:
info = [size_repr(k, v, indent=2) for k, v in self._store.items()]
info = ',\n'.join(info)
return f'{cls}(\n{info}\n)'
def stores_as(self, data: 'Data'):
return self
@property
def stores(self) -> List[BaseStorage]:
return [self._store]
@property
def node_stores(self) -> List[NodeStorage]:
return [self._store]
@property
def edge_stores(self) -> List[EdgeStorage]:
return [self._store]
[docs] def to_dict(self) -> Dict[str, Any]:
return self._store.to_dict()
[docs] def to_namedtuple(self) -> NamedTuple:
return self._store.to_namedtuple()
[docs] def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
if isinstance(value, SparseTensor) and 'adj' in key:
return (0, 1)
elif 'index' in key or 'face' in key:
return -1
else:
return 0
[docs] def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
if 'batch' in key:
return int(value.max()) + 1
elif 'index' in key or 'face' in key:
return self.num_nodes
else:
return 0
def debug(self):
pass # TODO
[docs] def is_node_attr(self, key: str) -> bool:
r"""Returns :obj:`True` if the object at key :obj:`key` denotes a
node-level attribute."""
return self._store.is_node_attr(key)
[docs] def is_edge_attr(self, key: str) -> bool:
r"""Returns :obj:`True` if the object at key :obj:`key` denotes an
edge-level attribute."""
return self._store.is_edge_attr(key)
[docs] def subgraph(self, subset: Tensor):
r"""Returns the induced subgraph given by the node indices
:obj:`subset`.
Args:
subset (LongTensor or BoolTensor): The nodes to keep.
"""
out = subgraph(subset, self.edge_index, relabel_nodes=True,
num_nodes=self.num_nodes, return_edge_mask=True)
edge_index, _, edge_mask = out
if subset.dtype == torch.bool:
num_nodes = int(subset.sum())
else:
num_nodes = subset.size(0)
data = copy.copy(self)
for key, value in data:
if key == 'edge_index':
data.edge_index = edge_index
elif key == 'num_nodes':
data.num_nodes = num_nodes
elif isinstance(value, Tensor):
if self.is_node_attr(key):
data[key] = value[subset]
elif self.is_edge_attr(key):
data[key] = value[edge_mask]
return data
[docs] def to_heterogeneous(self, node_type: Optional[Tensor] = None,
edge_type: Optional[Tensor] = None,
node_type_names: Optional[List[NodeType]] = None,
edge_type_names: Optional[List[EdgeType]] = None):
r"""Converts a :class:`~torch_geometric.data.Data` object to a
heterogeneous :class:`~torch_geometric.data.HeteroData` object.
For this, node and edge attributes are splitted according to the
node-level and edge-level vectors :obj:`node_type` and
:obj:`edge_type`, respectively.
:obj:`node_type_names` and :obj:`edge_type_names` can be used to give
meaningful node and edge type names, respectively.
That is, the node_type :obj:`0` is given by :obj:`node_type_names[0]`.
If the :class:`~torch_geometric.data.Data` object was constructed via
:meth:`~torch_geometric.data.HeteroData.to_homogeneous`, the object can
be reconstructed without any need to pass in additional arguments.
Args:
node_type (Tensor, optional): A node-level vector denoting the type
of each node. (default: :obj:`None`)
edge_type (Tensor, optional): An edge-level vector denoting the
type of each edge. (default: :obj:`None`)
node_type_names (List[str], optional): The names of node types.
(default: :obj:`None`)
edge_type_names (List[Tuple[str, str, str]], optional): The names
of edge types. (default: :obj:`None`)
"""
from torch_geometric.data import HeteroData
if node_type is None:
node_type = self._store.get('node_type', None)
if node_type is None:
node_type = torch.zeros(self.num_nodes, dtype=torch.long)
if node_type_names is None:
store = self._store
node_type_names = store.__dict__.get('_node_type_names', None)
if node_type_names is None:
node_type_names = [str(i) for i in node_type.unique().tolist()]
if edge_type is None:
edge_type = self._store.get('edge_type', None)
if edge_type is None:
edge_type = torch.zeros(self.num_edges, dtype=torch.long)
if edge_type_names is None:
store = self._store
edge_type_names = store.__dict__.get('_edge_type_names', None)
if edge_type_names is None:
edge_type_names = []
edge_index = self.edge_index
for i in edge_type.unique().tolist():
src, dst = edge_index[:, edge_type == i]
src_types = node_type[src].unique().tolist()
dst_types = node_type[dst].unique().tolist()
if len(src_types) != 1 and len(dst_types) != 1:
raise ValueError(
"Could not construct a 'HeteroData' object from the "
"'Data' object because single edge types span over "
"multiple node types")
edge_type_names.append((node_type_names[src_types[0]], str(i),
node_type_names[dst_types[0]]))
# We iterate over node types to find the local node indices belonging
# to each node type. Furthermore, we create a global `index_map` vector
# that maps global node indices to local ones in the final
# heterogeneous graph:
node_ids, index_map = {}, torch.empty_like(node_type)
for i, key in enumerate(node_type_names):
node_ids[i] = (node_type == i).nonzero(as_tuple=False).view(-1)
index_map[node_ids[i]] = torch.arange(len(node_ids[i]))
# We iterate over edge types to find the local edge indices:
edge_ids = {}
for i, key in enumerate(edge_type_names):
edge_ids[i] = (edge_type == i).nonzero(as_tuple=False).view(-1)
data = HeteroData()
for i, key in enumerate(node_type_names):
for attr, value in self.items():
if attr == 'node_type' or attr == 'edge_type':
continue
elif isinstance(value, Tensor) and self.is_node_attr(attr):
data[key][attr] = value[node_ids[i]]
if len(data[key]) == 0:
data[key].num_nodes = node_ids[i].size(0)
for i, key in enumerate(edge_type_names):
src, _, dst = key
for attr, value in self.items():
if attr == 'node_type' or attr == 'edge_type':
continue
elif attr == 'edge_index':
edge_index = value[:, edge_ids[i]]
edge_index[0] = index_map[edge_index[0]]
edge_index[1] = index_map[edge_index[1]]
data[key].edge_index = edge_index
elif isinstance(value, Tensor) and self.is_edge_attr(attr):
data[key][attr] = value[edge_ids[i]]
# Add global attributes.
keys = set(data.keys) | {'node_type', 'edge_type', 'num_nodes'}
for attr, value in self.items():
if attr in keys:
continue
if len(data.node_stores) == 1:
data.node_stores[0][attr] = value
else:
data[attr] = value
return data
###########################################################################
[docs] @classmethod
def from_dict(cls, mapping: Dict[str, Any]):
r"""Creates a :class:`~torch_geometric.data.Data` object from a Python
dictionary."""
return cls(**mapping)
@property
def num_node_features(self) -> int:
r"""Returns the number of features per node in the graph."""
return self._store.num_node_features
@property
def num_features(self) -> int:
r"""Returns the number of features per node in the graph.
Alias for :py:attr:`~num_node_features`."""
return self._store.num_features
@property
def num_edge_features(self) -> int:
r"""Returns the number of features per edge in the graph."""
return self._store.num_edge_features
def __iter__(self) -> Iterable:
r"""Iterates over all attributes in the data, yielding their attribute
names and values."""
for key, value in self._store.items():
yield key, value
def __call__(self, *args: List[str]) -> Iterable:
r"""Iterates over all attributes :obj:`*args` in the data, yielding
their attribute names and values.
If :obj:`*args` is not given, will iterate over all attributes."""
for key, value in self._store.items(*args):
yield key, value
@property
def x(self) -> Any:
return self['x'] if 'x' in self._store else None
@property
def edge_index(self) -> Any:
return self['edge_index'] if 'edge_index' in self._store else None
@property
def edge_weight(self) -> Any:
return self['edge_weight'] if 'edge_weight' in self._store else None
@property
def edge_attr(self) -> Any:
return self['edge_attr'] if 'edge_attr' in self._store else None
@property
def y(self) -> Any:
return self['y'] if 'y' in self._store else None
@property
def pos(self) -> Any:
return self['pos'] if 'pos' in self._store else None
@property
def batch(self) -> Any:
return self['batch'] if 'batch' in self._store else None
# Deprecated functions ####################################################
@property
@deprecated(details="use 'data.face.size(-1)' instead")
def num_faces(self) -> Optional[int]:
r"""Returns the number of faces in the mesh."""
if 'face' in self._store and isinstance(self.face, Tensor):
return self.face.size(self.__cat_dim__('face', self.face))
return None
###############################################################################
def size_repr(key: Any, value: Any, indent: int = 0) -> str:
pad = ' ' * indent
if isinstance(value, Tensor) and value.dim() == 0:
out = value.item()
elif isinstance(value, Tensor):
out = str(list(value.size()))
elif isinstance(value, np.ndarray):
out = str(list(value.shape))
elif isinstance(value, SparseTensor):
out = str(value.sizes())[:-1] + f', nnz={value.nnz()}]'
elif isinstance(value, str):
out = f"'{value}'"
elif isinstance(value, Sequence):
out = str([len(value)])
elif isinstance(value, Mapping) and len(value) == 0:
out = '{}'
elif (isinstance(value, Mapping) and len(value) == 1
and not isinstance(list(value.values())[0], Mapping)):
lines = [size_repr(k, v, 0) for k, v in value.items()]
out = '{ ' + ', '.join(lines) + ' }'
elif isinstance(value, Mapping):
lines = [size_repr(k, v, indent + 2) for k, v in value.items()]
out = '{\n' + ',\n'.join(lines) + '\n' + pad + '}'
else:
out = str(value)
key = str(key).replace("'", '')
if isinstance(value, BaseStorage):
return f'{pad}\033[1m{key}\033[0m={out}'
else:
return f'{pad}{key}={out}'