Source code for torch_geometric.data.batch

from typing import Optional, List, Union, Any

from collections.abc import Sequence

import torch
import numpy as np
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.collate import collate
from torch_geometric.data.separate import separate
from torch_geometric.data.dataset import IndexType


class DynamicInheritance(type):
    # Metaclass that decides, per instantiation, which data class a `Batch`
    # object derives from:
    # * `Batch(Data)` when `Data` objects are batched together
    # * `Batch(HeteroData)` when `HeteroData` objects are batched together
    # The base is chosen via the `_base_cls` keyword (default: `Data`).
    def __call__(cls, *args, **kwargs):
        # Consume `_base_cls` here so it is not forwarded to the constructor.
        parent = kwargs.pop('_base_cls', Data)
        bases = (cls, parent)
        # Build a fresh class with the same name that inherits from both the
        # requested class and the selected data class.
        derived = type(cls.__name__, bases, {})
        # Two-argument `super` is needed: instantiation has to be dispatched
        # on `derived`, not on `cls`.
        instantiate = super(DynamicInheritance, derived).__call__
        return instantiate(*args, **kwargs)


class Batch(metaclass=DynamicInheritance):
    r"""A data object describing a batch of graphs as one big (disconnected)
    graph.
    Inherits from :class:`torch_geometric.data.Data` or
    :class:`torch_geometric.data.HeteroData`.
    In addition, single graphs can be identified via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """
    @classmethod
    def from_data_list(cls, data_list: Union[List[Data], List[HeteroData]],
                       follow_batch: Optional[List[str]] = None,
                       exclude_keys: Optional[List[str]] = None):
        r"""Constructs a :class:`~torch_geometric.data.Batch` object from a
        Python list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects.
        The assignment vector :obj:`batch` is created on the fly.
        In addition, creates assignment vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`.

        Args:
            data_list: The objects to batch together.
            follow_batch: Keys for which additional assignment vectors are
                created. (default: :obj:`None`)
            exclude_keys: Keys to exclude from the batch.
                (default: :obj:`None`)

        Returns:
            The batch object produced by :func:`collate`, annotated with the
            bookkeeping attributes :obj:`_num_graphs`, :obj:`_slice_dict` and
            :obj:`_inc_dict`.
        """
        batch, slice_dict, inc_dict = collate(
            cls,
            data_list=data_list,
            increment=True,
            add_batch=True,
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )

        # Stored so that `get()` can later undo the collation and
        # `num_graphs` does not have to be inferred.
        batch._num_graphs = len(data_list)
        batch._slice_dict = slice_dict
        batch._inc_dict = inc_dict

        return batch

    def get(self, idx: int) -> Union[Data, HeteroData]:
        r"""Gets the :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` object at index
        :obj:`idx`.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial object.

        Args:
            idx: Index of the graph to extract from the batch.

        Returns:
            The object returned by :func:`separate` for index :obj:`idx`.

        Raises:
            RuntimeError: If the batch carries no :obj:`_slice_dict`, *i.e.*
                it was not created via :meth:`from_data_list`.
        """
        if not hasattr(self, '_slice_dict'):
            raise RuntimeError(
                ("Cannot reconstruct 'Data' object from 'Batch' because "
                 "'Batch' was not created via 'Batch.from_data_list()'"))

        data = separate(
            # `DynamicInheritance` builds the concrete class with bases
            # `(Batch, base_cls)`, so the last base is the class of the
            # objects that were batched.
            cls=self.__class__.__bases__[-1],
            batch=self,
            idx=idx,
            slice_dict=self._slice_dict,
            inc_dict=self._inc_dict,
            decrement=True,
        )

        return data

    def index_select(self,
                     idx: IndexType) -> Union[List[Data], List[HeteroData]]:
        r"""Creates a subset of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects from specified
        indices :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
        long or bool.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial objects.

        Args:
            idx: The indices of the graphs to extract.

        Returns:
            A list holding the result of :meth:`get` for every selected index.

        Raises:
            IndexError: If :obj:`idx` is of an unsupported type or dtype.
        """
        if isinstance(idx, slice):
            idx = list(range(self.num_graphs)[idx])

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            idx = idx.flatten().tolist()

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            # Boolean mask -> positions of the `True` entries.
            idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()

        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            idx = idx.flatten().tolist()

        # `np.bool_` rather than the `np.bool` alias: the alias was deprecated
        # in NumPy 1.20 and removed in 1.24, where merely evaluating it raises
        # `AttributeError` for every non-int64 array reaching this branch.
        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
            idx = idx.flatten().nonzero()[0].flatten().tolist()

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            # Lists, tuples and other sequences are iterated as they are.
            pass

        else:
            raise IndexError(
                f"Only slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')")

        return [self.get(i) for i in idx]

    def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
        r"""Dispatches indexing to :meth:`get`, attribute/type access of the
        parent class, or :meth:`index_select`, depending on the type of
        :obj:`idx`.

        Args:
            idx: An integer (or 0-dim tensor) selecting a single graph, a
                string or a tuple starting with a string for attribute or
                node/edge type access, or anything accepted by
                :meth:`index_select`.

        Returns:
            A single object, the value returned by the parent class'
            :meth:`__getitem__`, or a list of objects, respectively.
        """
        # NOTE(review): `np.isscalar` returns `False` for every `np.ndarray`
        # (0-dim arrays included), so the last clause never matches and such
        # arrays are routed to `index_select()`. Kept as-is to preserve the
        # current return type for those inputs.
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)
                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
            return self.get(idx)
        elif isinstance(idx, str) or (isinstance(idx, tuple)
                                      and isinstance(idx[0], str)):
            # Accessing attributes or node/edge types:
            return super().__getitem__(idx)
        else:
            return self.index_select(idx)

    def to_data_list(self) -> Union[List[Data], List[HeteroData]]:
        r"""Reconstructs the list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects from the
        :class:`~torch_geometric.data.Batch` object.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial objects.

        Returns:
            A list with the result of :meth:`get` for each index in
            :obj:`range(self.num_graphs)`.
        """
        return [self.get(i) for i in range(self.num_graphs)]

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch.

        The value is taken from the first source that is available:
        the stored :obj:`_num_graphs`, then :obj:`ptr` (number of elements
        minus one), then :obj:`batch` (largest entry plus one).

        Raises:
            ValueError: If none of these attributes is present.
        """
        if hasattr(self, '_num_graphs'):
            return self._num_graphs
        elif hasattr(self, 'ptr'):
            return self.ptr.numel() - 1
        elif hasattr(self, 'batch'):
            return int(self.batch.max()) + 1
        else:
            raise ValueError("Can not infer the number of graphs")