import inspect
from collections.abc import Sequence
from typing import Any, List, Optional, Union
import numpy as np
import torch
from torch import Tensor
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData, Data
from torch_geometric.data.dataset import IndexType
from torch_geometric.data.separate import separate
class DynamicInheritance(type):
    # A meta class that sets the base class of a `Batch` object, e.g.:
    # * `Batch(Data)` in case `Data` objects are batched together
    # * `Batch(HeteroData)` in case `HeteroData` objects are batched together
    def __call__(cls, *args, **kwargs):
        """Instantiate `cls` combined with the class given as `_base_cls`.

        The private keyword `_base_cls` (default: `Data`) is removed from
        `kwargs` before instantiation. Unless `_base_cls` is already a
        `Batch` subclass, a class named `<base_cls name><cls name>` deriving
        from `(cls, base_cls)` is created once, cached in this module's
        globals, and instantiated with the remaining arguments.
        """
        base_cls = kwargs.pop('_base_cls', Data)

        if issubclass(base_cls, Batch):
            new_cls = base_cls
        else:
            name = f'{base_cls.__name__}{cls.__name__}'

            # Only build the combined class (and its helper metaclass) on a
            # cache miss; later calls reuse the class stored in `globals()`.
            if name not in globals():
                # NOTE `MetaResolver` is necessary to resolve metaclass
                # conflict problems between `DynamicInheritance` and the
                # metaclass of `base_cls`. In particular, it creates a new
                # common metaclass from the defined metaclasses.
                class MetaResolver(type(cls), type(base_cls)):
                    pass

                globals()[name] = MetaResolver(name, (cls, base_cls), {})
            new_cls = globals()[name]

        # Required parameters of `base_cls.__init__` that the caller did not
        # supply are filled with `None`, so the combined class can be
        # instantiated without arguments.
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        params = list(inspect.signature(base_cls.__init__).parameters.items())
        # `params[1:]` drops the first parameter (`self`).
        for i, (k, v) in enumerate(params[1:]):
            # Skip `*args`/`**kwargs`-style parameters. The name check is
            # kept for backward compatibility; the kind check also covers
            # variadics with other names (e.g. `*rest`), which must never be
            # passed as a keyword.
            if k == 'args' or k == 'kwargs' or v.kind in variadic_kinds:
                continue
            # Already provided positionally or by keyword.
            if i < len(args) or k in kwargs:
                continue
            # Has its own default value.
            if v.default is not inspect.Parameter.empty:
                continue
            kwargs[k] = None

        return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)
class DynamicInheritanceGetter:
    # Callable helper that builds an object by invoking `cls` with the
    # private `_base_cls` keyword set to `base_cls`. `Batch.__reduce__`
    # returns an instance of this class as its reconstruction callable.
    def __call__(self, cls, base_cls):
        factory_kwargs = {'_base_cls': base_cls}
        return cls(**factory_kwargs)
class Batch(metaclass=DynamicInheritance):
    r"""A data object describing a batch of graphs as one big (disconnected)
    graph.
    Inherits from :class:`torch_geometric.data.Data` or
    :class:`torch_geometric.data.HeteroData`.
    In addition, single graphs can be identified via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """
    @classmethod
    def from_data_list(cls, data_list: List[BaseData],
                       follow_batch: Optional[List[str]] = None,
                       exclude_keys: Optional[List[str]] = None):
        r"""Constructs a :class:`~torch_geometric.data.Batch` object from a
        Python list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects.
        The assignment vector :obj:`batch` is created on the fly.
        In addition, creates assignment vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`.

        Args:
            data_list: The objects to batch together. Must be non-empty,
                since its first element is inspected.
            follow_batch: Keys forwarded to :obj:`collate` as
                :obj:`follow_batch`.
            exclude_keys: Keys forwarded to :obj:`collate` as
                :obj:`exclude_keys`.

        Returns:
            The batch object returned by :obj:`collate`, with the number of
            graphs and the slice/increment dictionaries attached.
        """
        batch, slice_dict, inc_dict = collate(
            cls,
            data_list=data_list,
            increment=True,
            # No new `batch` vector is requested when the inputs are
            # themselves `Batch` objects.
            add_batch=not isinstance(data_list[0], Batch),
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )

        # Bookkeeping that `num_graphs` and `get_example` read back later.
        batch._num_graphs = len(data_list)
        batch._slice_dict = slice_dict
        batch._inc_dict = inc_dict

        return batch

    def get_example(self, idx: int) -> BaseData:
        r"""Gets the :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` object at index :obj:`idx`.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial object.

        Raises:
            RuntimeError: If the batch has no :obj:`_slice_dict` attribute,
                i.e. it was not created via :meth:`from_data_list`.
        """
        if not hasattr(self, '_slice_dict'):
            raise RuntimeError(
                ("Cannot reconstruct 'Data' object from 'Batch' because "
                 "'Batch' was not created via 'Batch.from_data_list()'"))

        data = separate(
            # The last base of the dynamically created batch class is the
            # underlying data class (see `DynamicInheritance.__call__`).
            cls=self.__class__.__bases__[-1],
            batch=self,
            idx=idx,
            slice_dict=self._slice_dict,
            inc_dict=self._inc_dict,
            decrement=True,
        )

        return data

    def index_select(self, idx: IndexType) -> List[BaseData]:
        r"""Creates a subset of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects from specified
        indices :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
        long or bool.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial objects.

        Raises:
            IndexError: If :obj:`idx` is none of the supported index types.
        """
        # Normalize every supported index type to an iterable of positions.
        if isinstance(idx, slice):
            idx = list(range(self.num_graphs)[idx])

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            idx = idx.flatten().tolist()

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            # Boolean mask: keep the positions that are `True`.
            idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()

        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            idx = idx.flatten().tolist()

        elif isinstance(idx, np.ndarray) and idx.dtype == bool:
            # Boolean mask: keep the positions that are `True`.
            idx = idx.flatten().nonzero()[0].flatten().tolist()

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            # Lists, tuples and other sequences are used as-is.
            pass

        else:
            raise IndexError(
                f"Only slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')")

        return [self.get_example(i) for i in idx]

    def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
        """Dispatch on the type of `idx`.

        * Scalar index (Python/NumPy integer, 0-dim tensor, 0-dim array):
          returns the single example via `get_example`.
        * String, or tuple whose first item is a string: delegates to the
          parent class' `__getitem__`.
        * Anything else: returns a list of examples via `index_select`.
        """
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)):
            return self.get_example(idx)
        elif isinstance(idx, np.ndarray) and idx.ndim == 0:
            # `np.isscalar` is always `False` for `np.ndarray` inputs, so a
            # 0-dim array has to be detected via `ndim` and unwrapped into a
            # Python scalar.
            return self.get_example(idx.item())
        elif isinstance(idx, str) or (isinstance(idx, tuple)
                                      and isinstance(idx[0], str)):
            # Accessing attributes or node/edge types:
            return super().__getitem__(idx)
        else:
            return self.index_select(idx)

    def to_data_list(self) -> List[BaseData]:
        r"""Reconstructs the list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects from the
        :class:`~torch_geometric.data.Batch` object.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial objects."""
        return [self.get_example(i) for i in range(self.num_graphs)]

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch.

        Uses the stored `_num_graphs` if present, otherwise derives it from
        `ptr` (number of elements minus one) or from `batch` (maximum value
        plus one).

        Raises:
            ValueError: If none of `_num_graphs`, `ptr` or `batch` exists.
        """
        if hasattr(self, '_num_graphs'):
            return self._num_graphs
        elif hasattr(self, 'ptr'):
            return self.ptr.numel() - 1
        elif hasattr(self, 'batch'):
            return int(self.batch.max()) + 1
        else:
            raise ValueError("Can not infer the number of graphs")

    def __len__(self) -> int:
        """Returns `num_graphs`."""
        return self.num_graphs

    def __reduce__(self):
        """Pickle support for the dynamically created batch class.

        Returns a `DynamicInheritanceGetter` instance as the reconstruction
        callable, the bases of this object's class as its arguments, and a
        shallow copy of `__dict__` as the state.
        """
        state = self.__dict__.copy()
        return DynamicInheritanceGetter(), self.__class__.__bases__, state