Source code for torch_geometric.utils.to_dense_batch

import torch
from torch_scatter import scatter_add


def to_dense_batch(x, batch=None, fill_value=0, max_num_nodes=None):
    r"""Given a sparse batch of node features
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
    dense node feature tensor
    :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with
    :math:`N_{\max} = \max_i^B N_i`).
    In addition, a boolean mask
    :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times N_{\max}}` is returned, which
    is :obj:`True` at every position that holds a real node and :obj:`False`
    at every padded position.

    Nodes belonging to the same graph must be stored contiguously and in
    ascending graph order, since the position of a node inside its graph is
    derived from the cumulative node counts.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. If :obj:`None`, all nodes are treated
            as a single graph. (default: :obj:`None`)
        fill_value (float, optional): The value for invalid entries in the
            resulting dense output tensor. (default: :obj:`0`)
        max_num_nodes (int, optional): The size of the output node dimension.
            Graphs with fewer nodes are padded with :obj:`fill_value`; graphs
            with more nodes are truncated to their first
            :obj:`max_num_nodes` nodes. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`BoolTensor`)
    """
    # Fast path: a single graph without padding needs no scattering at all.
    if batch is None and max_num_nodes is None:
        mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
        return x.unsqueeze(0), mask

    if batch is None:
        # No batch vector means one graph, even if it has zero nodes.
        batch = x.new_zeros(x.size(0), dtype=torch.long)
        batch_size = 1
    elif batch.numel() > 0:
        batch_size = int(batch.max()) + 1
    else:
        # Empty input: indexing or reducing the batch vector would raise.
        batch_size = 0

    # Number of nodes per graph, and the offset of each graph's first node.
    num_nodes = batch.new_zeros(batch_size).scatter_add_(
        0, batch, torch.ones_like(batch))
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    # Position of every node inside its own graph.
    pos = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    pos = pos - cum_nodes[batch]

    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max()) if batch_size > 0 else 0
    else:
        # Drop nodes that do not fit; otherwise their flat index would land
        # in the next graph's rows (or past the end of the output).
        keep = pos < max_num_nodes
        x, batch, pos = x[keep], batch[keep], pos[keep]

    # Flat index into the [B * N_max, ...] output buffer.
    idx = pos + batch * max_num_nodes
    feature_dims = list(x.size())[1:]

    out = x.new_full([batch_size * max_num_nodes] + feature_dims, fill_value)
    out[idx] = x
    out = out.view([batch_size, max_num_nodes] + feature_dims)

    mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
                       device=x.device)
    mask[idx] = True
    mask = mask.view(batch_size, max_num_nodes)

    return out, mask