from typing import Optional, Tuple
import torch
from torch import Tensor
from torch_geometric.experimental import (
disable_dynamic_shapes,
is_experimental_mode_enabled,
)
from torch_geometric.utils import cumsum, scatter
@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])
def to_dense_batch(
    x: Tensor,
    batch: Optional[Tensor] = None,
    fill_value: float = 0.0,
    max_num_nodes: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Given a sparse batch of node features
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
    dense node feature tensor
    :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with
    :math:`N_{\max} = \max_i^B N_i`).
    In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times
    N_{\max}}` is returned, where :obj:`True` marks an entry holding a real
    node and :obj:`False` marks a padding ("fake") node in the dense
    representation.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered. (default: :obj:`None`)
        fill_value (float, optional): The value for invalid entries in the
            resulting dense output tensor. (default: :obj:`0.0`)
        max_num_nodes (int, optional): The size of the output node dimension.
            If a graph has more nodes than this, its surplus nodes are
            dropped. No dropping is performed while the
            :obj:`"disable_dynamic_shapes"` experimental mode is enabled, so
            the caller must then ensure no graph exceeds this size.
            (default: :obj:`None`)
        batch_size (int, optional): The batch size. If not given, it is
            inferred as :obj:`batch.max() + 1`, as :obj:`1` when
            :obj:`batch` is :obj:`None`, and as :obj:`0` when :obj:`batch`
            is empty. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`BoolTensor`)

    Examples:
        >>> x = torch.arange(12).view(6, 2)
        >>> x
        tensor([[ 0,  1],
                [ 2,  3],
                [ 4,  5],
                [ 6,  7],
                [ 8,  9],
                [10, 11]])

        >>> out, mask = to_dense_batch(x)
        >>> mask
        tensor([[True, True, True, True, True, True]])

        >>> batch = torch.tensor([0, 0, 1, 2, 2, 2])
        >>> out, mask = to_dense_batch(x, batch)
        >>> out
        tensor([[[ 0,  1],
                 [ 2,  3],
                 [ 0,  0]],

                [[ 4,  5],
                 [ 0,  0],
                 [ 0,  0]],

                [[ 6,  7],
                 [ 8,  9],
                 [10, 11]]])
        >>> mask
        tensor([[ True,  True, False],
                [ True, False, False],
                [ True,  True,  True]])

        >>> out, mask = to_dense_batch(x, batch, max_num_nodes=4)
        >>> out
        tensor([[[ 0,  1],
                 [ 2,  3],
                 [ 0,  0],
                 [ 0,  0]],

                [[ 4,  5],
                 [ 0,  0],
                 [ 0,  0],
                 [ 0,  0]],

                [[ 6,  7],
                 [ 8,  9],
                 [10, 11],
                 [ 0,  0]]])
        >>> mask
        tensor([[ True,  True, False, False],
                [ True, False, False, False],
                [ True,  True,  True, False]])

        >>> out, mask = to_dense_batch(torch.empty(0, 2), max_num_nodes=2)
        >>> out.size(), mask.size()
        (torch.Size([1, 2, 2]), torch.Size([1, 2]))
    """
    # Fast path: a single graph padded to exactly its own size.
    if batch is None and max_num_nodes is None:
        mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
        return x.unsqueeze(0), mask

    if batch is None:
        # Without a batch vector every node belongs to one graph. Setting the
        # batch size directly (rather than via ``batch.max()``) also keeps an
        # empty ``x`` from failing on a reduction over zero elements.
        batch = x.new_zeros(x.size(0), dtype=torch.long)
        if batch_size is None:
            batch_size = 1

    if batch_size is None:
        # ``max()`` raises on an empty tensor, so an empty batch means B = 0.
        batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 0

    # Nodes per graph, and their running total used as per-graph start
    # offsets below (relies on ``batch`` being sorted).
    num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,
                        dim_size=batch_size, reduce='sum')
    cum_nodes = cumsum(num_nodes)

    filter_nodes = False
    dynamic_shapes_disabled = is_experimental_mode_enabled(
        'disable_dynamic_shapes')

    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max()) if num_nodes.numel() > 0 else 0
    elif (not dynamic_shapes_disabled and num_nodes.numel() > 0
          and num_nodes.max() > max_num_nodes):
        # Some graph exceeds the requested padding size: drop its surplus
        # nodes. Only done when dynamic shapes are allowed, because the
        # boolean-mask indexing below yields a data-dependent size.
        filter_nodes = True

    # Position of each node within its own graph (0, 1, 2, ... per graph).
    tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
    # Flat index of each node in a [batch_size * max_num_nodes] buffer.
    idx = tmp + (batch * max_num_nodes)
    if filter_nodes:
        keep = tmp < max_num_nodes
        x, idx = x[keep], idx[keep]

    size = [batch_size * max_num_nodes] + list(x.size())[1:]
    out = torch.as_tensor(fill_value, device=x.device)
    out = out.to(x.dtype).repeat(size)
    out[idx] = x
    out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])

    mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
                       device=x.device)
    mask[idx] = True
    mask = mask.view(batch_size, max_num_nodes)

    return out, mask