torch_geometric.data.Batch

class Batch(*args: Any, **kwargs: Any)

Bases: object

A data object describing a batch of graphs as one big (disconnected) graph. Inherits from torch_geometric.data.Data or torch_geometric.data.HeteroData. In addition, single graphs can be identified via the assignment vector batch, which maps each node to its respective graph identifier.

PyG allows modification of the underlying batching procedure by overriding the __inc__() and __cat_dim__() methods. The __inc__() method defines the increment applied to an attribute between two consecutive graphs. By default, PyG increments attributes by the number of nodes whenever their attribute names contain the substring index (for historical reasons), which comes in handy for attributes such as edge_index or node_index. However, this may lead to unexpected behavior for attributes whose names contain the substring index but should not be incremented, so it is best practice to always double-check the output of batching. Furthermore, __cat_dim__() defines the dimension along which graph tensors of the same attribute are concatenated.

classmethod from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None) → Self

Constructs a Batch object from a list of Data or HeteroData objects. The assignment vector batch is created on the fly. In addition, it creates assignment vectors for each key in follow_batch and excludes any keys given in exclude_keys.

get_example(idx: int) → BaseData

Gets the Data or HeteroData object at index idx. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial object.

index_select(idx: Union[slice, Tensor, ndarray, Sequence]) → List[BaseData]

Creates a subset of Data or HeteroData objects from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

to_data_list() → List[BaseData]

Reconstructs the list of Data or HeteroData objects from the Batch object. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

property num_graphs: int

Returns the number of graphs in the batch.

property batch_size: int

Alias for num_graphs.