Source code for torch_geometric.data.batch

import torch
from torch import Tensor
from torch_sparse import SparseTensor, cat
import torch_geometric
from torch_geometric.data import Data


class Batch(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """
    def __init__(self, batch=None, **kwargs):
        super(Batch, self).__init__(**kwargs)

        self.batch = batch
        self.__data_class__ = Data
        self.__slices__ = None
        self.__cumsum__ = None
        self.__cat_dims__ = None
        self.__num_nodes_list__ = None
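    # Illustration (not part of the upstream source): for two graphs with 3
    # and 2 nodes respectively, the assignment vector reads
    #
    #     batch = [0, 0, 0, 1, 1]
    #
    # i.e. the first three rows of every node-level attribute belong to
    # graph 0 and the remaining two rows to graph 1.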
    @staticmethod
    def from_data_list(data_list, follow_batch=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`."""

        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__

        for key in keys + ['batch']:
            batch[key] = []

        slices = {key: [0] for key in keys}
        cumsum = {key: [0] for key in keys}
        cat_dims = {}
        num_nodes_list = []
        for i, data in enumerate(data_list):
            for key in keys:
                item = data[key]

                # Increase values by `cumsum` value.
                cum = cumsum[key][-1]
                if isinstance(item, Tensor) and item.dtype != torch.bool:
                    item = item + cum if cum != 0 else item
                elif isinstance(item, SparseTensor):
                    value = item.storage.value()
                    if value is not None and value.dtype != torch.bool:
                        value = value + cum if cum != 0 else value
                        item = item.set_value(value, layout='coo')
                elif isinstance(item, (int, float)):
                    item = item + cum

                # Treat 0-dimensional tensors as 1-dimensional.
                if isinstance(item, Tensor) and item.dim() == 0:
                    item = item.unsqueeze(0)

                batch[key].append(item)

                # Gather the size of the `cat` dimension.
                size = 1
                cat_dim = data.__cat_dim__(key, data[key])
                cat_dims[key] = cat_dim
                if isinstance(item, Tensor):
                    size = item.size(cat_dim)
                elif isinstance(item, SparseTensor):
                    size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]

                slices[key].append(size + slices[key][-1])

                inc = data.__inc__(key, item)
                if isinstance(inc, (tuple, list)):
                    inc = torch.tensor(inc)
                cumsum[key].append(inc + cumsum[key][-1])

                if key in follow_batch:
                    if isinstance(size, Tensor):
                        for j, size in enumerate(size.tolist()):
                            tmp = f'{key}_{j}_batch'
                            batch[tmp] = [] if i == 0 else batch[tmp]
                            batch[tmp].append(
                                torch.full((size, ), i, dtype=torch.long))
                    else:
                        tmp = f'{key}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long))

            if hasattr(data, '__num_nodes__'):
                num_nodes_list.append(data.__num_nodes__)
            else:
                num_nodes_list.append(None)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        # Fix initial slice values:
        for key in keys:
            slices[key][0] = slices[key][1] - slices[key][1]

        batch.batch = None if len(batch.batch) == 0 else batch.batch
        batch.__slices__ = slices
        batch.__cumsum__ = cumsum
        batch.__cat_dims__ = cat_dims
        batch.__num_nodes_list__ = num_nodes_list

        ref_data = data_list[0]
        for key in batch.keys:
            items = batch[key]
            item = items[0]
            if isinstance(item, Tensor):
                batch[key] = torch.cat(items, ref_data.__cat_dim__(key, item))
            elif isinstance(item, SparseTensor):
                batch[key] = cat(items, ref_data.__cat_dim__(key, item))
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(items)

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
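    # Usage sketch (illustrative values only, assuming two toy `Data` objects
    # with node features `x` and a COO `edge_index`; not taken from the
    # upstream source):
    #
    #     >>> d1 = Data(x=torch.randn(3, 8),
    #     ...           edge_index=torch.tensor([[0, 1], [1, 2]]))
    #     >>> d2 = Data(x=torch.randn(2, 8),
    #     ...           edge_index=torch.tensor([[0], [1]]))
    #     >>> batch = Batch.from_data_list([d1, d2], follow_batch=['x'])
    #     >>> batch.batch                # one graph id per node
    #     tensor([0, 0, 0, 1, 1])
    #     >>> batch.edge_index           # second graph's indices shifted by 3
    #     tensor([[0, 1, 3],
    #             [1, 2, 4]])
    #     >>> batch.x_batch              # extra vector created via follow_batch
    #     tensor([0, 0, 0, 1, 1])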
    def to_data_list(self):
        r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
        from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able to reconstruct the initial objects."""

        if self.__slices__ is None:
            raise RuntimeError(
                ('Cannot reconstruct data list from batch because the batch '
                 'object was not created using `Batch.from_data_list()`.'))

        data_list = []
        for i in range(len(list(self.__slices__.values())[0]) - 1):
            data = self.__data_class__()
            for key in self.__slices__.keys():
                item = self[key]

                # Narrow the item based on the values in `__slices__`.
                if isinstance(item, Tensor):
                    dim = self.__cat_dims__[key]
                    start = self.__slices__[key][i]
                    end = self.__slices__[key][i + 1]
                    item = item.narrow(dim, start, end - start)
                elif isinstance(item, SparseTensor):
                    for j, dim in enumerate(self.__cat_dims__[key]):
                        start = self.__slices__[key][i][j].item()
                        end = self.__slices__[key][i + 1][j].item()
                        item = item.narrow(dim, start, end - start)
                else:
                    item = item[self.__slices__[key][i]:
                                self.__slices__[key][i + 1]]
                    item = item[0] if len(item) == 1 else item

                # Decrease its value by `cumsum` value:
                cum = self.__cumsum__[key][i]
                if isinstance(item, Tensor):
                    item = item - cum if cum != 0 else item
                elif isinstance(item, SparseTensor):
                    value = item.storage.value()
                    if value is not None and value.dtype != torch.bool:
                        value = value - cum if cum != 0 else value
                        item = item.set_value(value, layout='coo')
                elif isinstance(item, (int, float)):
                    item = item - cum

                data[key] = item

            if self.__num_nodes_list__[i] is not None:
                data.num_nodes = self.__num_nodes_list__[i]

            data_list.append(data)

        return data_list
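    # Round-trip sketch (continuing the hypothetical `batch` built above; not
    # part of the upstream source):
    #
    #     >>> data_list = batch.to_data_list()
    #     >>> len(data_list)
    #     2
    #     >>> data_list[1].edge_index    # local node indices are restored
    #     tensor([[0],
    #             [1]])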
    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        return self.batch[-1].item() + 1
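
# Minimal end-to-end sketch (illustrative only, not part of the upstream
# module): batch two toy graphs, check `num_graphs`, and recover the
# individual graphs again.
if __name__ == '__main__':
    d1 = Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 1], [1, 2]]))
    d2 = Data(x=torch.randn(2, 8), edge_index=torch.tensor([[0], [1]]))

    batch = Batch.from_data_list([d1, d2])
    assert batch.num_graphs == 2                   # one id per input graph
    assert batch.batch.tolist() == [0, 0, 0, 1, 1]

    # `to_data_list()` undoes the concatenation and the index shifts.
    d1_rec, d2_rec = batch.to_data_list()
    assert d2_rec.edge_index.tolist() == [[0], [1]]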