Source code for torch_geometric.utils.to_dense_adj

from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter

from torch_geometric.typing import OptTensor


[docs]def to_dense_adj( edge_index: Tensor, batch: OptTensor = None, edge_attr: OptTensor = None, max_num_nodes: Optional[int] = None, ) -> Tensor: r"""Converts batched sparse adjacency matrices given by edge indices and edge attributes to a single dense batched adjacency matrix. Args: edge_index (LongTensor): The edge indices. batch (LongTensor, optional): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) edge_attr (Tensor, optional): Edge weights or multi-dimensional edge features. (default: :obj:`None`) max_num_nodes (int, optional): The size of the output node dimension. (default: :obj:`None`) :rtype: :class:`Tensor` Examples: >>> edge_index = torch.tensor([[0, 0, 1, 2, 3], ... [0, 1, 0, 3, 0]]) >>> batch = torch.tensor([0, 0, 1, 1]) >>> to_dense_adj(edge_index, batch) tensor([[[1., 1.], [1., 0.]], [[0., 1.], [1., 0.]]]) >>> to_dense_adj(edge_index, batch, max_num_nodes=4) tensor([[[1., 1., 0., 0.], [1., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 1., 0., 0.], [1., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]]) >>> edge_attr = torch.Tensor([1, 2, 3, 4, 5]) >>> to_dense_adj(edge_index, batch, edge_attr) tensor([[[1., 2.], [3., 0.]], [[0., 4.], [5., 0.]]]) """ if batch is None: num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 batch = edge_index.new_zeros(num_nodes) batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1 one = batch.new_ones(batch.size(0)) num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='add') cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) idx0 = batch[edge_index[0]] idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]] idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]] if max_num_nodes is None: max_num_nodes = num_nodes.max().item() elif ((idx1.numel() > 0 and idx1.max() >= max_num_nodes) or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)): mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes) idx0 = idx0[mask] idx1 = idx1[mask] idx2 = idx2[mask] edge_attr = None if edge_attr is None else edge_attr[mask] if edge_attr is None: edge_attr = torch.ones(idx0.numel(), device=edge_index.device) size = [batch_size, max_num_nodes, max_num_nodes] size += list(edge_attr.size())[1:] adj = torch.zeros(size, dtype=edge_attr.dtype, device=edge_index.device) flattened_size = batch_size * max_num_nodes * max_num_nodes adj = adj.view([flattened_size] + list(adj.size())[3:]) idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2 scatter(edge_attr, idx, dim=0, out=adj, reduce='add') adj = adj.view(size) return adj