# Source code for torch_geometric.utils.nested

from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.utils import scatter

def to_nested_tensor(
    x: Tensor,
    batch: Optional[Tensor] = None,
    ptr: Optional[Tensor] = None,
    batch_size: Optional[int] = None,
) -> torch.Tensor:
    r"""Given a contiguous batch of tensors
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}`
    (with :math:`N_i` indicating the number of elements in example :math:`i`),
    creates a `nested PyTorch tensor
    <https://pytorch.org/docs/stable/nested.html>`__.
    Reverse operation of :meth:`from_nested_tensor`.

    Args:
        x (torch.Tensor): The input tensor
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}`.
        batch (torch.Tensor, optional): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            element to a specific example. Must be ordered.
            (default: :obj:`None`)
        ptr (torch.Tensor, optional): Alternative representation of
            :obj:`batch` in compressed format. Takes precedence over
            :obj:`batch` if both are given. (default: :obj:`None`)
        batch_size (int, optional): The batch size :math:`B`. Only used
            together with :obj:`batch`. (default: :obj:`None`)

    If neither :obj:`batch` nor :obj:`ptr` is given, the whole input
    :obj:`x` is treated as a single example.
    """
    if ptr is not None:
        # Consecutive differences of the compressed pointer give the number
        # of rows per example.
        sizes: List[int] = (ptr[1:] - ptr[:-1]).tolist()
    elif batch is not None:
        # Per-example element counts via scattering ones over `batch`.
        sizes = scatter(torch.ones_like(batch), batch,
                        dim_size=batch_size).tolist()
    else:
        return torch.nested.as_nested_tensor([x])

    xs = list(torch.split(x, sizes, dim=0))

    # This currently copies the data, although `x` is already contiguous.
    # Sadly, there does not exist any (public) API to prevent this :(
    return torch.nested.as_nested_tensor(xs)
def from_nested_tensor(
    x: Tensor,
    return_batch: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    r"""Flattens a `nested PyTorch tensor
    <https://pytorch.org/docs/stable/nested.html>`__ back into a single
    contiguous batch of tensors
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times *}`,
    optionally together with a batch vector that maps every element to its
    example.
    Reverse operation of :meth:`to_nested_tensor`.

    Args:
        x (torch.Tensor): The nested input tensor. All nested components
            must share the same size in every dimension except the first.
        return_batch (bool, optional): If set to :obj:`True`, additionally
            returns the batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`.
            (default: :obj:`False`)

    Raises:
        ValueError: If :obj:`x` is not nested, or if its components differ
            in any dimension other than the first.
    """
    if not x.is_nested:
        raise ValueError("Input tensor in 'from_nested_tensor' is not nested")

    # One row per nested component, one column per dimension.
    shapes = x._nested_tensor_size()
    reference = shapes[0, 1:]  # trailing sizes of the first component

    # Compare every component's trailing sizes against the first one at once;
    # `mismatch[d]` is True if any component differs in dimension `d + 1`.
    mismatch = (shapes.t()[1:] != reference.unsqueeze(1)).any(dim=1)
    if bool(mismatch.any()):
        dim = int(mismatch.nonzero()[0, 0])
        raise ValueError(f"Not all nested tensors have the same size "
                         f"in dimension {dim + 1} "
                         f"(expected size {reference[dim].item()} for all "
                         f"tensors)")

    # `values()` exposes the flat buffer; reshape it to (N, *trailing).
    flat = x.contiguous().values().view(-1, *reference.tolist())

    if return_batch:
        arange = torch.arange(x.size(0), device=x.device)
        lengths = shapes[:, 0].to(arange.device)
        return flat, arange.repeat_interleave(lengths)
    return flat