Source code for torch_geometric.data.batch

from typing import List

import torch
from torch import Tensor
from torch_sparse import SparseTensor, cat
import torch_geometric
from torch_geometric.data import Data


class Batch(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """
    def __init__(self, batch=None, ptr=None, **kwargs):
        super(Batch, self).__init__(**kwargs)

        for key, item in kwargs.items():
            if key == 'num_nodes':
                self.__num_nodes__ = item
            else:
                self[key] = item

        self.batch = batch
        self.ptr = ptr
        self.__data_class__ = Data
        self.__slices__ = None
        self.__cumsum__ = None
        self.__cat_dims__ = None
        self.__num_nodes_list__ = None
        self.__num_graphs__ = None
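    # A sketch of the layout this class produces (illustrative, not part of
    # the original source): for two graphs with 3 and 2 nodes respectively,
    # `from_data_list` yields
    #
    #     batch.batch = tensor([0, 0, 0, 1, 1])  # node -> graph assignment
    #     batch.ptr   = tensor([0, 3, 5])        # cumulative node counts
    #
    # so the nodes of graph `i` occupy the slice `ptr[i]:ptr[i + 1]`.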
    @classmethod
    def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`."""

        keys = list(set(data_list[0].keys) - set(exclude_keys))
        assert 'batch' not in keys and 'ptr' not in keys

        batch = cls()
        for key in data_list[0].__dict__.keys():
            if key[:2] != '__' and key[-2:] != '__':
                batch[key] = None

        batch.__num_graphs__ = len(data_list)
        batch.__data_class__ = data_list[0].__class__
        for key in keys + ['batch']:
            batch[key] = []
        batch['ptr'] = [0]

        device = None
        slices = {key: [0] for key in keys}
        cumsum = {key: [0] for key in keys}
        cat_dims = {}
        num_nodes_list = []
        for i, data in enumerate(data_list):
            for key in keys:
                item = data[key]

                # Increase values by `cumsum` value.
                cum = cumsum[key][-1]
                if isinstance(item, Tensor) and item.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        item = item + cum
                elif isinstance(item, SparseTensor):
                    value = item.storage.value()
                    if value is not None and value.dtype != torch.bool:
                        if not isinstance(cum, int) or cum != 0:
                            value = value + cum
                        item = item.set_value(value, layout='coo')
                elif isinstance(item, (int, float)):
                    item = item + cum

                # Gather the size of the `cat` dimension.
                size = 1
                cat_dim = data.__cat_dim__(key, data[key])
                # 0-dimensional tensors have no dimension along which to
                # concatenate, so we set `cat_dim` to `None`.
                if isinstance(item, Tensor) and item.dim() == 0:
                    cat_dim = None
                cat_dims[key] = cat_dim

                # Add a batch dimension to items whose `cat_dim` is `None`:
                if isinstance(item, Tensor) and cat_dim is None:
                    cat_dim = 0  # Concatenate along this new batch dimension.
                    item = item.unsqueeze(0)
                    device = item.device
                elif isinstance(item, Tensor):
                    size = item.size(cat_dim)
                    device = item.device
                elif isinstance(item, SparseTensor):
                    size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
                    device = item.device()

                batch[key].append(item)  # Append item to the attribute list.

                slices[key].append(size + slices[key][-1])
                inc = data.__inc__(key, item)
                if isinstance(inc, (tuple, list)):
                    inc = torch.tensor(inc)
                cumsum[key].append(inc + cumsum[key][-1])

                if key in follow_batch:
                    if isinstance(size, Tensor):
                        for j, size in enumerate(size.tolist()):
                            tmp = f'{key}_{j}_batch'
                            batch[tmp] = [] if i == 0 else batch[tmp]
                            batch[tmp].append(
                                torch.full((size, ), i, dtype=torch.long,
                                           device=device))
                    else:
                        tmp = f'{key}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long,
                                       device=device))

            if hasattr(data, '__num_nodes__'):
                num_nodes_list.append(data.__num_nodes__)
            else:
                num_nodes_list.append(None)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long,
                                  device=device)
                batch.batch.append(item)
                batch.ptr.append(batch.ptr[-1] + num_nodes)

        batch.batch = None if len(batch.batch) == 0 else batch.batch
        batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
        batch.__slices__ = slices
        batch.__cumsum__ = cumsum
        batch.__cat_dims__ = cat_dims
        batch.__num_nodes_list__ = num_nodes_list

        ref_data = data_list[0]
        for key in batch.keys:
            items = batch[key]
            item = items[0]
            cat_dim = ref_data.__cat_dim__(key, item)
            cat_dim = 0 if cat_dim is None else cat_dim
            if isinstance(item, Tensor):
                batch[key] = torch.cat(items, cat_dim)
            elif isinstance(item, SparseTensor):
                batch[key] = cat(items, cat_dim)
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(items)

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
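    # A hedged sketch of the cumulative-shift behaviour above (assuming two
    # graphs with 3 and 2 nodes): `Data.__inc__('edge_index', item)` returns
    # the node count, so the second graph's `edge_index` is offset by 3 when
    # concatenated:
    #
    #     >>> d1 = Data(x=torch.zeros(3, 1),
    #     ...           edge_index=torch.tensor([[0, 1], [1, 2]]))
    #     >>> d2 = Data(x=torch.zeros(2, 1),
    #     ...           edge_index=torch.tensor([[0], [1]]))
    #     >>> Batch.from_data_list([d1, d2]).edge_index
    #     tensor([[0, 1, 3],
    #             [1, 2, 4]])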
    def get_example(self, idx: int) -> Data:
        r"""Reconstructs the :class:`torch_geometric.data.Data` object at
        index :obj:`idx` from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able to reconstruct the initial objects."""
        if self.__slices__ is None:
            raise RuntimeError(
                ('Cannot reconstruct data list from batch because the batch '
                 'object was not created using `Batch.from_data_list()`.'))

        data = self.__data_class__()

        for key in self.__slices__.keys():
            item = self[key]
            if self.__cat_dims__[key] is None:
                # The item was concatenated along a new batch dimension,
                # so just index in that dimension:
                item = item[idx]
            else:
                # Narrow the item based on the values in `__slices__`.
                if isinstance(item, Tensor):
                    dim = self.__cat_dims__[key]
                    start = self.__slices__[key][idx]
                    end = self.__slices__[key][idx + 1]
                    item = item.narrow(dim, start, end - start)
                elif isinstance(item, SparseTensor):
                    for j, dim in enumerate(self.__cat_dims__[key]):
                        start = self.__slices__[key][idx][j].item()
                        end = self.__slices__[key][idx + 1][j].item()
                        item = item.narrow(dim, start, end - start)
                else:
                    start = self.__slices__[key][idx]
                    end = self.__slices__[key][idx + 1]
                    item = item[start:end]
                    item = item[0] if len(item) == 1 else item

            # Decrease its value by `cumsum` value:
            cum = self.__cumsum__[key][idx]
            if isinstance(item, Tensor):
                if not isinstance(cum, int) or cum != 0:
                    item = item - cum
            elif isinstance(item, SparseTensor):
                value = item.storage.value()
                if value is not None and value.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        value = value - cum
                    item = item.set_value(value, layout='coo')
            elif isinstance(item, (int, float)):
                item = item - cum

            data[key] = item

        if self.__num_nodes_list__[idx] is not None:
            data.num_nodes = self.__num_nodes_list__[idx]

        return data
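    # Round-trip sketch (same `d1`/`d2` assumptions as above): `get_example`
    # undoes both the narrowing and the cumulative shift, so
    #
    #     >>> Batch.from_data_list([d1, d2]).get_example(1).edge_index
    #     tensor([[0],
    #             [1]])
    #
    # recovers the second graph with its original, un-shifted indices.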
    def index_select(self, idx: Tensor) -> List[Data]:
        if isinstance(idx, slice):
            idx = list(range(self.num_graphs)[idx])
        elif torch.is_tensor(idx):
            if idx.dtype == torch.bool:
                idx = idx.nonzero(as_tuple=False).view(-1)
            idx = idx.tolist()
        elif isinstance(idx, (list, tuple)):
            pass
        else:
            raise IndexError(
                'Only integers, slices (`:`), list, tuples, and long or bool '
                'tensors are valid indices (got {}).'.format(
                    type(idx).__name__))

        return [self.get_example(i) for i in idx]
    def __getitem__(self, idx):
        if isinstance(idx, str):
            return super(Batch, self).__getitem__(idx)
        elif isinstance(idx, int):
            return self.get_example(idx)
        else:
            return self.index_select(idx)
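    # Indexing sketch (illustrative): string keys fall back to plain
    # :class:`Data` attribute access, integers select a single example, and
    # anything else is forwarded to :meth:`index_select`:
    #
    #     >>> batch['x']   # the concatenated node feature tensor
    #     >>> batch[0]     # a single reconstructed `Data` object
    #     >>> batch[1:]    # a list of `Data` objects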
    def to_data_list(self) -> List[Data]:
        r"""Reconstructs the list of :class:`torch_geometric.data.Data`
        objects from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able to reconstruct the initial objects."""
        return [self.get_example(i) for i in range(self.num_graphs)]
    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch."""
        if self.__num_graphs__ is not None:
            return self.__num_graphs__
        elif self.ptr is not None:
            # `ptr` stores `num_graphs + 1` cumulative node counts.
            return self.ptr.numel() - 1
        elif self.batch is not None:
            return int(self.batch.max()) + 1
        else:
            raise ValueError
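
# A minimal, self-contained usage sketch (not part of the original module;
# the graph sizes and features below are illustrative assumptions only).
if __name__ == '__main__':
    d1 = Data(x=torch.zeros(3, 1),
              edge_index=torch.tensor([[0, 1], [1, 2]]))
    d2 = Data(x=torch.zeros(2, 1),
              edge_index=torch.tensor([[0], [1]]))

    batch = Batch.from_data_list([d1, d2])
    assert batch.num_graphs == 2
    # `ptr` holds `num_graphs + 1` cumulative node counts.
    assert batch.ptr.tolist() == [0, 3, 5]
    assert batch.batch.tolist() == [0, 0, 0, 1, 1]

    # Reconstruct the original objects from the batch.
    out = batch.to_data_list()
    assert out[1].edge_index.tolist() == [[0], [1]]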