torch_geometric.data.Batch
- class Batch(*args: Any, **kwargs: Any)
Bases: object
A data object describing a batch of graphs as one big (disconnected) graph. Inherits from torch_geometric.data.Data or torch_geometric.data.HeteroData, depending on the type of the batched graphs. In addition, single graphs can be identified via the assignment vector batch, which maps each node to its respective graph identifier.
PyG allows modifying the underlying batching procedure by overriding the __inc__() and __cat_dim__() methods. The __inc__() method defines by how much the values of an attribute are incremented between two consecutive graphs. By default, PyG increments an attribute by the number of nodes whenever its name contains the substring index (for historical reasons), which comes in handy for attributes such as edge_index or node_index. However, this may lead to unexpected behavior for attributes whose names contain the substring index but should not be incremented. To be safe, always double-check the output of batching. Furthermore, __cat_dim__() defines the dimension along which the tensors of the same attribute from different graphs are concatenated.
- classmethod from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None) → Self
Constructs a Batch object from a list of Data or HeteroData objects. The assignment vector batch is created on the fly. In addition, an assignment vector is created for each key in follow_batch and stored as <key>_batch. Any keys given in exclude_keys are left out of the batch.
- get_example(idx: int) → BaseData
Gets the Data or HeteroData object at index idx. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial object.
- index_select(idx: Union[slice, Tensor, ndarray, Sequence]) → List[BaseData]
Creates a subset of Data or HeteroData objects from the specified indices idx. The indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.
- to_data_list() → List[BaseData]
Reconstructs the list of Data or HeteroData objects from the Batch object. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.
- property batch_size: int
Alias for num_graphs.