# Source code for the module that defines ``Batch`` (module path lost in extraction).

from typing import Optional, List, Union, Any

import inspect
from import Sequence

import torch
import numpy as np
from torch import Tensor

from import Data, HeteroData
from import collate
from import separate
from import IndexType

class DynamicInheritance(type):
    # A meta class that sets the base class of a `Batch` object, e.g.:
    # * `Batch(Data)` in case `Data` objects are batched together
    # * `Batch(HeteroData)` in case `HeteroData` objects are batched together
    def __call__(cls, *args, **kwargs):
        """Instantiate ``cls`` with a dynamically chosen base class.

        The base class is taken from the ``_base_cls`` keyword argument
        (removed from ``kwargs`` before construction; defaults to ``Data``).

        * If ``_base_cls`` already is a ``Batch`` subclass, it is instantiated
          directly.
        * Otherwise a class named ``f'{base_cls.__name__}{cls.__name__}'``
          deriving from ``(cls, base_cls)`` is created once, cached in this
          module's ``globals()``, and instantiated.

        Every parameter of ``base_cls.__init__`` that has no default and was
        not supplied by the caller is passed as ``None``, so the object can be
        constructed without the caller knowing the base class signature.
        """
        base_cls = kwargs.pop('_base_cls', Data)

        if issubclass(base_cls, Batch):
            # Already a batch class: no extra wrapper class is needed.
            new_cls = base_cls
        else:
            name = f'{base_cls.__name__}{cls.__name__}'
            # Create the combined class only once and keep it as a
            # module-level global so that later calls reuse the same type.
            if name not in globals():
                globals()[name] = type(name, (cls, base_cls), {})
            new_cls = globals()[name]

        params = list(inspect.signature(base_cls.__init__).parameters.items())
        # `params[0]` is `self`; `i` therefore lines up with the position of
        # the corresponding entry in `args`.
        for i, (k, v) in enumerate(params[1:]):
            if k in ('args', 'kwargs'):
                continue  # Variadic catch-alls never need a placeholder.
            if i < len(args) or k in kwargs:
                continue  # Explicitly supplied by the caller.
            if v.default is not inspect.Parameter.empty:
                continue  # Has its own default value.
            kwargs[k] = None

        return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)

class DynamicInheritanceGetter(object):
    # Callable helper returned as the reconstructor from `Batch.__reduce__`.
    def __call__(self, cls, base_cls):
        """Call ``cls`` with ``base_cls`` passed as the ``_base_cls`` keyword
        argument and return the result."""
        rebuilt = cls(_base_cls=base_cls)
        return rebuilt

class Batch(metaclass=DynamicInheritance):
    r"""A data object describing a batch of graphs as one big (disconnected)
    graph.
    Inherits from :class:`Data` or :class:`HeteroData`.
    In addition, single graphs can be identified via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """
    @classmethod
    def from_data_list(cls, data_list: Union[List[Data], List[HeteroData]],
                       follow_batch: Optional[List[str]] = None,
                       exclude_keys: Optional[List[str]] = None):
        r"""Constructs a :class:`Batch` object from a Python list of
        :class:`Data` or :class:`HeteroData` objects.
        The assignment vector :obj:`batch` is created on the fly.
        In addition, creates assignment vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`."""

        batch, slice_dict, inc_dict = collate(
            cls,
            data_list=data_list,
            increment=True,
            # No new `batch` vector is requested when the inputs are
            # themselves `Batch` objects.
            add_batch=not isinstance(data_list[0], Batch),
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )

        # Bookkeeping that `get_example` needs to undo the collation.
        batch._num_graphs = len(data_list)
        batch._slice_dict = slice_dict
        batch._inc_dict = inc_dict

        return batch

    def get_example(self, idx: int) -> Union[Data, HeteroData]:
        r"""Gets the :class:`Data` or :class:`HeteroData` object at index
        :obj:`idx`.
        The :class:`Batch` object must have been created via
        :meth:`from_data_list` in order to be able to reconstruct the initial
        object.

        Raises:
            RuntimeError: If the batch has no :obj:`_slice_dict` attribute,
                *i.e.* it was not created via :meth:`from_data_list`.
        """

        if not hasattr(self, '_slice_dict'):
            raise RuntimeError(
                ("Cannot reconstruct 'Data' object from 'Batch' because "
                 "'Batch' was not created via 'Batch.from_data_list()'"))

        data = separate(
            # The dynamically created batch class derives from
            # `(Batch, base_cls)`, so the last base is the original data class.
            cls=self.__class__.__bases__[-1],
            batch=self,
            idx=idx,
            slice_dict=self._slice_dict,
            inc_dict=self._inc_dict,
            decrement=True,
        )

        return data

    def index_select(self,
                     idx: IndexType) -> Union[List[Data], List[HeteroData]]:
        r"""Creates a subset of :class:`Data` or :class:`HeteroData` objects
        from specified indices :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
        long or bool.
        The :class:`Batch` object must have been created via
        :meth:`from_data_list` in order to be able to reconstruct the initial
        objects.

        Raises:
            IndexError: If :obj:`idx` is of an unsupported type.
        """
        # Normalize every supported index type to an iterable of positions.
        if isinstance(idx, slice):
            idx = list(range(self.num_graphs)[idx])

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            idx = idx.flatten().tolist()

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            # Boolean mask: keep the positions of the `True` entries.
            idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()

        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            idx = idx.flatten().tolist()

        elif isinstance(idx, np.ndarray) and idx.dtype == bool:
            idx = idx.flatten().nonzero()[0].flatten().tolist()

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            pass

        else:
            raise IndexError(
                f"Only slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')")

        return [self.get_example(i) for i in idx]

    def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
        # Dispatches on the index type:
        # * scalar index                 -> a single example
        # * string / tuple of strings    -> attribute or node/edge type access
        # * anything else                -> a list of examples
        #
        # `np.isscalar` is `False` for every `np.ndarray` (including
        # 0-dimensional ones), so 0-dim integer arrays are detected via `ndim`
        # and unwrapped to a plain Python int.
        if (isinstance(idx, np.ndarray) and idx.ndim == 0
                and np.issubdtype(idx.dtype, np.integer)):
            idx = int(idx)

        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)):
            return self.get_example(idx)
        elif isinstance(idx, str) or (isinstance(idx, tuple)
                                      and isinstance(idx[0], str)):
            # Accessing attributes or node/edge types:
            return super().__getitem__(idx)
        else:
            return self.index_select(idx)

    def to_data_list(self) -> Union[List[Data], List[HeteroData]]:
        r"""Reconstructs the list of :class:`Data` or :class:`HeteroData`
        objects from the :class:`Batch` object.
        The :class:`Batch` object must have been created via
        :meth:`from_data_list` in order to be able to reconstruct the initial
        objects."""
        return [self.get_example(i) for i in range(self.num_graphs)]

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch.

        Uses the stored :obj:`_num_graphs` if present, otherwise falls back to
        ``ptr.numel() - 1`` and then to ``batch.max() + 1``.

        Raises:
            ValueError: If none of :obj:`_num_graphs`, :obj:`ptr` or
                :obj:`batch` is available.
        """
        if hasattr(self, '_num_graphs'):
            return self._num_graphs
        elif hasattr(self, 'ptr'):
            return self.ptr.numel() - 1
        elif hasattr(self, 'batch'):
            return int(self.batch.max()) + 1
        else:
            raise ValueError("Can not infer the number of graphs")

    def __reduce__(self):
        # Pickle support: return `(callable, args, state)`. The callable is a
        # `DynamicInheritanceGetter` instance, which is invoked with the bases
        # of this instance's class as positional arguments.
        state = self.__dict__.copy()
        return DynamicInheritanceGetter(), self.__class__.__bases__, state