torch_geometric.data.Batch
- class Batch(*args, **kwargs)
Bases: object
A data object describing a batch of graphs as one big (disconnected) graph. Depending on the type of the batched objects, it inherits from either torch_geometric.data.Data or torch_geometric.data.HeteroData. In addition, individual graphs can be identified via the assignment vector batch, which maps each node to its respective graph identifier.
- classmethod from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None)
Constructs a Batch object from a Python list of Data or HeteroData objects. The assignment vector batch is created on the fly. In addition, an assignment vector is created for each key in follow_batch (stored as <key>_batch, e.g., x_batch for the key x). Any keys given in exclude_keys are left out of the batch.
- get_example(idx: int) → BaseData
Returns the Data or HeteroData object at index idx. The Batch object must have been created via from_data_list(); otherwise, the initial object cannot be reconstructed.
- index_select(idx: Union[slice, Tensor, ndarray, Sequence]) → List[BaseData]
Returns a Python list of the Data or HeteroData objects at the specified indices idx. The indices idx can be a slicing object (e.g., [2:5]), a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool. The Batch object must have been created via from_data_list(); otherwise, the initial objects cannot be reconstructed.
- to_data_list() → List[BaseData]
Reconstructs the list of Data or HeteroData objects from the Batch object. The Batch object must have been created via from_data_list(); otherwise, the initial objects cannot be reconstructed.