# Source code for torch_geometric.utils.to_dense_batch

from typing import Optional, Tuple

import torch
from torch import Tensor

[docs]def to_dense_batch(x: Tensor, batch: Optional[Tensor] = None,
fill_value: float = 0., max_num_nodes: Optional[int] = None,
batch_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
r"""Given a sparse batch of node features
:math:\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F} (with
:math:N_i indicating the number of nodes in graph :math:i), creates a
dense node feature tensor
:math:\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F} (with
:math:N_{\max} = \max_i^B N_i).
In addition, a mask of shape :math:\mathbf{M} \in \{ 0, 1 \}^{B \times
N_{\max}} is returned, holding information about the existence of
fake-nodes in the dense representation.

Args:
x (Tensor): Node feature matrix
:math:\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}.
batch (LongTensor, optional): Batch vector
:math:\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N, which assigns each
node to a specific example. Must be ordered. (default: :obj:None)
fill_value (float, optional): The value for invalid entries in the
resulting dense output tensor. (default: :obj:0)
max_num_nodes (int, optional): The size of the output node dimension.
(default: :obj:None)
batch_size (int, optional) The batch size. (default: :obj:None)

:rtype: (:class:Tensor, :class:BoolTensor)
"""
if batch is None and max_num_nodes is None:
mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)

if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)

if batch_size is None:
batch_size = int(batch.max()) + 1

dim_size=batch_size)
cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

if max_num_nodes is None:
max_num_nodes = int(num_nodes.max())

idx = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
idx = (idx - cum_nodes[batch]) + (batch * max_num_nodes)

size = [batch_size * max_num_nodes] + list(x.size())[1:]
out = x.new_full(size, fill_value)
out[idx] = x
out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])

mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
device=x.device)