torch_geometric.data.Batch

class Batch(*args, **kwargs)

Bases: object

A data object describing a batch of graphs as one big (disconnected) graph. Depending on the type of the objects being batched, it inherits from either torch_geometric.data.Data or torch_geometric.data.HeteroData. In addition, individual graphs can be identified via the assignment vector batch, which maps each node to its respective graph identifier.

classmethod from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None)

Constructs a Batch object from a Python list of Data or HeteroData objects. The assignment vector batch is created on the fly. In addition, an assignment vector is created for each key in follow_batch, and any keys given in exclude_keys are left out of the batch.

get_example(idx: int) → BaseData

Gets the Data or HeteroData object at index idx. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial object.

index_select(idx: Union[slice, Tensor, ndarray, Sequence]) → List[BaseData]

Returns a list containing the Data or HeteroData objects at the specified indices idx. Indices idx can be a slice object (e.g., slice(2, 5), which corresponds to [2:5]), a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

to_data_list() → List[BaseData]

Reconstructs the list of Data or HeteroData objects from the Batch object. The Batch object must have been created via from_data_list() in order to be able to reconstruct the initial objects.

property num_graphs: int

Returns the number of graphs in the batch.