Source code for torch_geometric.utils._to_dense_batch

from typing import Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.experimental import (
    disable_dynamic_shapes,
    is_experimental_mode_enabled,
)
from torch_geometric.utils import cumsum, scatter


@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])
def to_dense_batch(
    x: Tensor,
    batch: Optional[Tensor] = None,
    fill_value: float = 0.0,
    max_num_nodes: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Given a sparse batch of node features
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
    dense node feature tensor
    :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with
    :math:`N_{\max} = \max_i^B N_i`).
    In addition, a mask of shape
    :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times N_{\max}}` is returned, holding
    information about the existence of fake-nodes in the dense representation.
    The mask is :obj:`True` at positions that hold a real node and
    :obj:`False` at padded positions.

    If :obj:`max_num_nodes` is smaller than the largest graph and dynamic
    shapes are not disabled via the experimental mode, the surplus nodes of
    each graph are dropped from the output.

    Empty inputs are supported: an empty :obj:`batch` vector without an
    explicit :obj:`batch_size` yields a batch dimension of size :obj:`0`, and
    an unspecified :obj:`max_num_nodes` resolves to :obj:`0` when there are no
    graphs.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered.
            (default: :obj:`None`)
        fill_value (float, optional): The value for invalid entries in the
            resulting dense output tensor. (default: :obj:`0`)
        max_num_nodes (int, optional): The size of the output node dimension.
            (default: :obj:`None`)
        batch_size (int, optional): The batch size. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`BoolTensor`)

    Examples:
        >>> x = torch.arange(12).view(6, 2)
        >>> x
        tensor([[ 0,  1],
                [ 2,  3],
                [ 4,  5],
                [ 6,  7],
                [ 8,  9],
                [10, 11]])

        >>> out, mask = to_dense_batch(x)
        >>> mask
        tensor([[True, True, True, True, True, True]])

        >>> batch = torch.tensor([0, 0, 1, 2, 2, 2])
        >>> out, mask = to_dense_batch(x, batch)
        >>> out
        tensor([[[ 0,  1],
                [ 2,  3],
                [ 0,  0]],
                [[ 4,  5],
                [ 0,  0],
                [ 0,  0]],
                [[ 6,  7],
                [ 8,  9],
                [10, 11]]])
        >>> mask
        tensor([[ True,  True, False],
                [ True, False, False],
                [ True,  True,  True]])

        >>> out, mask = to_dense_batch(x, batch, max_num_nodes=4)
        >>> out
        tensor([[[ 0,  1],
                [ 2,  3],
                [ 0,  0],
                [ 0,  0]],
                [[ 4,  5],
                [ 0,  0],
                [ 0,  0],
                [ 0,  0]],
                [[ 6,  7],
                [ 8,  9],
                [10, 11],
                [ 0,  0]]])
        >>> mask
        tensor([[ True,  True, False, False],
                [ True, False, False, False],
                [ True,  True,  True, False]])
    """
    # Fast path: a single graph with no requested padding is just `x` with a
    # leading batch dimension and an all-True mask.
    if batch is None and max_num_nodes is None:
        mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
        return x.unsqueeze(0), mask

    if batch_size is None:
        if batch is None:
            # Without a batch vector every node belongs to graph 0, so there
            # is exactly one graph (this also holds when `x` is empty).
            batch_size = 1
        elif batch.numel() > 0:
            batch_size = int(batch.max()) + 1
        else:
            # `Tensor.max()` raises on an empty tensor; an empty batch vector
            # describes zero graphs.
            batch_size = 0

    if batch is None:
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    # Number of nodes per graph and the offset of each graph's first node.
    num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,
                        dim_size=batch_size, reduce='sum')
    cum_nodes = cumsum(num_nodes)

    filter_nodes = False
    dynamic_shapes_disabled = is_experimental_mode_enabled(
        'disable_dynamic_shapes')

    # Guard the reductions below: `num_nodes` is empty when there are no
    # graphs, and `max()` without `dim` is undefined on empty tensors.
    has_graphs = num_nodes.numel() > 0

    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max()) if has_graphs else 0
    elif (has_graphs and not dynamic_shapes_disabled
          and num_nodes.max() > max_num_nodes):
        # Some graph exceeds the requested node dimension; its surplus nodes
        # are dropped below. This is skipped when dynamic shapes are disabled.
        filter_nodes = True

    # Position of every node inside its own graph ...
    tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
    # ... and its slot in the flattened [batch_size * max_num_nodes] layout.
    idx = tmp + (batch * max_num_nodes)
    if filter_nodes:
        keep = tmp < max_num_nodes
        x, idx = x[keep], idx[keep]

    # Build the flat output pre-filled with `fill_value`, scatter the real
    # nodes into their slots, then reshape to [batch_size, max_num_nodes, *].
    size = [batch_size * max_num_nodes] + list(x.size())[1:]
    out = torch.as_tensor(fill_value, device=x.device)
    out = out.to(x.dtype).repeat(size)
    out[idx] = x
    out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])

    # Mark the slots that hold real nodes.
    mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
                       device=x.device)
    mask[idx] = 1
    mask = mask.view(batch_size, max_num_nodes)

    return out, mask