Source code for torch_geometric.utils.unbatch

from typing import List, Optional

import torch
from torch import Tensor

from torch_geometric.utils import degree


def unbatch(
    src: Tensor,
    batch: Tensor,
    dim: int = 0,
    batch_size: Optional[int] = None,
) -> List[Tensor]:
    r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
    :obj:`dim`.

    Args:
        src (Tensor): The source tensor.
        batch (LongTensor): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            entry in :obj:`src` to a specific example. Must be ordered.
        dim (int, optional): The dimension along which to split the
            :obj:`src` tensor. (default: :obj:`0`)
        batch_size (int, optional): The number of examples :math:`B`. If
            given, examples at the end of the batch that own no entries are
            returned as empty tensors. If :obj:`None`, :math:`B` is inferred
            as :obj:`batch.max() + 1`. (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`batch_size` is smaller than the number of
            examples referenced by :obj:`batch`.

    :rtype: :class:`List[Tensor]`
    """
    # Number of entries per example. ``minlength`` pads the counts with zeros
    # for trailing examples that have no entries, which cannot be inferred
    # from ``batch`` alone.
    min_examples = 0 if batch_size is None else batch_size
    sizes = torch.bincount(batch, minlength=min_examples)
    if batch_size is not None and sizes.numel() > batch_size:
        raise ValueError(
            f"'batch_size' ({batch_size}) is smaller than the number of "
            f"examples referenced by 'batch' ({sizes.numel()})")
    return src.split(sizes.tolist(), dim)
def unbatch_edge_index(
    edge_index: Tensor,
    batch: Tensor,
    batch_size: Optional[int] = None,
) -> List[Tensor]:
    r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.

    Each returned chunk is re-indexed so that node indices are local to its
    example. Exactly one chunk is returned per example, including examples
    without any edges, so the result lines up with :meth:`unbatch` applied
    to node features.

    Args:
        edge_index (Tensor): The edge_index tensor. Must be ordered.
        batch (LongTensor): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered.
        batch_size (int, optional): The number of examples :math:`B`. If
            :obj:`None`, :math:`B` is inferred as :obj:`batch.max() + 1`.
            (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`batch_size` is smaller than the number of
            examples referenced by :obj:`batch`.

    :rtype: :class:`List[Tensor]`
    """
    min_examples = 0 if batch_size is None else batch_size
    # Number of nodes per example.
    deg = torch.bincount(batch, minlength=min_examples)
    if batch_size is not None and deg.numel() > batch_size:
        raise ValueError(
            f"'batch_size' ({batch_size}) is smaller than the number of "
            f"examples referenced by 'batch' ({deg.numel()})")
    num_examples = deg.numel()

    # Exclusive prefix sum: ptr[g] is the global index of example g's first
    # node, subtracted below to make node indices local to each example.
    ptr = deg.cumsum(dim=0) - deg

    # Example that owns each edge, read off its source node.
    edge_batch = batch[edge_index[0]]
    edge_index = edge_index - ptr[edge_batch]

    # Size the edge counts by the number of examples so that examples
    # without edges (in particular trailing ones) still yield an empty chunk.
    sizes = torch.bincount(edge_batch, minlength=num_examples).cpu().tolist()
    return edge_index.split(sizes, dim=1)