# Source code for torch_geometric.transforms.to_dense

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

@functional_transform('to_dense')
class ToDense(BaseTransform):
    r"""Converts a sparse adjacency matrix to a dense adjacency matrix with
    shape :obj:`[num_nodes, num_nodes, *]` (functional name:
    :obj:`to_dense`).

    The connectivity in :obj:`data.edge_index` is scattered into a dense
    tensor stored as :obj:`data.adj`, using :obj:`data.edge_attr` as entry
    values if present and ones otherwise; trailing edge feature dimensions
    become the trailing :obj:`*` dimensions of :obj:`data.adj`. Afterwards
    :obj:`data.edge_index` and :obj:`data.edge_attr` are set to :obj:`None`.

    If :obj:`num_nodes` is larger than the graph's node count, :obj:`x`,
    :obj:`pos` and :obj:`y` (the latter only when its first dimension equals
    the original node count) are zero-padded along their first dimension to
    :obj:`num_nodes` rows. A boolean :obj:`data.mask` of length
    :obj:`num_nodes` marks the original (non-padded) nodes with :obj:`True`.

    The input :obj:`data` object is modified in place and returned.

    Args:
        num_nodes (int, optional): The number of nodes. If set to
            :obj:`None`, the number of nodes will get automatically inferred.
            (default: :obj:`None`)
    """
    def __init__(self, num_nodes: Optional[int] = None) -> None:
        self.num_nodes = num_nodes

    def forward(self, data: Data) -> Data:
        assert data.edge_index is not None, (
            "'ToDense' requires 'edge_index' to be present")

        orig_num_nodes = data.num_nodes
        assert orig_num_nodes is not None

        num_nodes: int
        if self.num_nodes is None:
            num_nodes = orig_num_nodes
        else:
            assert orig_num_nodes <= self.num_nodes, (
                f"'ToDense' received a graph with {orig_num_nodes} nodes, "
                f"which exceeds the requested 'num_nodes={self.num_nodes}'")
            num_nodes = self.num_nodes

        # Allocate new tensors on the device of the connectivity: otherwise
        # default edge weights created on CPU cannot be combined with a GPU
        # 'edge_index' inside 'torch.sparse_coo_tensor'.
        device = data.edge_index.device

        if data.edge_attr is None:
            edge_attr = torch.ones(data.edge_index.size(1), dtype=torch.float,
                                   device=device)
        else:
            edge_attr = data.edge_attr

        # Trailing edge feature dimensions carry over to the dense matrix:
        size = torch.Size([num_nodes, num_nodes] + list(edge_attr.size())[1:])
        adj = torch.sparse_coo_tensor(data.edge_index, edge_attr, size)
        data.adj = adj.to_dense()
        data.edge_index = None
        data.edge_attr = None

        data.mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        data.mask[:orig_num_nodes] = True

        if data.x is not None:
            data.x = self._pad_nodes(data.x, num_nodes)
        if data.pos is not None:
            data.pos = self._pad_nodes(data.pos, num_nodes)

        # Only pad 'y' when it holds one row per original node. A 0-dim
        # (scalar) target has no first dimension and is left untouched.
        y = data.y
        if (isinstance(y, Tensor) and y.dim() > 0
                and y.size(0) == orig_num_nodes):
            data.y = self._pad_nodes(y, num_nodes)

        return data

    @staticmethod
    def _pad_nodes(src: Tensor, num_nodes: int) -> Tensor:
        r"""Zero-pads :obj:`src` along its first dimension to
        :obj:`num_nodes` rows, returning :obj:`src` itself if no padding is
        needed."""
        num_missing = num_nodes - src.size(0)
        if num_missing == 0:
            return src
        pad = src.new_zeros([num_missing] + list(src.size())[1:])
        return torch.cat([src, pad], dim=0)

    def __repr__(self) -> str:
        if self.num_nodes is None:
            return f'{self.__class__.__name__}()'
        return f'{self.__class__.__name__}(num_nodes={self.num_nodes})'