from typing import List
import torch
from torch import Tensor
from torch_geometric.utils import degree
[docs]def unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
:obj:`dim`.
Args:
src (Tensor): The source tensor.
batch (LongTensor): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
entry in :obj:`src` to a specific example. Must be ordered.
dim (int, optional): The dimension along which to split the :obj:`src`
tensor. (default: :obj:`0`)
:rtype: :class:`List[Tensor]`
Example:
>>> src = torch.arange(7)
>>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])
>>> unbatch(src, batch)
(tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
"""
sizes = degree(batch, dtype=torch.long).tolist()
return src.split(sizes, dim)
[docs]def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.
Args:
edge_index (Tensor): The edge_index tensor. Must be ordered.
batch (LongTensor): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. Must be ordered.
:rtype: :class:`List[Tensor]`
Example:
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6],
... [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]])
>>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])
>>> unbatch_edge_index(edge_index, batch)
(tensor([[0, 1, 1, 2, 2, 3],
[1, 0, 2, 1, 3, 2]]),
tensor([[0, 1, 1, 2],
[1, 0, 2, 1]]))
"""
deg = degree(batch, dtype=torch.int64)
ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0)
edge_batch = batch[edge_index[0]]
edge_index = edge_index - ptr[edge_batch]
sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist()
return edge_index.split(sizes, dim=1)