# Source code for torch_geometric.transforms.virtual_node

import copy

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('virtual_node')
class VirtualNode(BaseTransform):
    r"""Appends a virtual node to the given homogeneous graph that is connected
    to all other nodes, as described in the `"Neural Message Passing for
    Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_ paper
    (functional name: :obj:`virtual_node`).
    The virtual node serves as a global scratch space that each node both reads
    from and writes to in every step of message passing.
    This allows information to travel long distances during the propagation
    phase.

    The virtual node receives index :obj:`num_nodes`. Node and edge features
    of the virtual node and its edges are added as zero-filled input features,
    except that new :obj:`edge_weight` entries are set to one and the new
    :obj:`batch` entry copies the first existing entry (zero if the graph has
    no nodes).

    Furthermore, special edge types are added both for incoming and outgoing
    information to and from the virtual node. If :obj:`t` is the largest
    existing edge type (:obj:`0` if the graph has no edges or no
    :obj:`edge_type` attribute), then edges with :obj:`edge_index[0] = i`,
    :obj:`edge_index[1] = num_nodes` get type :obj:`t + 1`. Edges in the
    reverse direction get type :obj:`t + 2`.
    """
    def forward(self, data: Data) -> Data:
        assert data.edge_index is not None
        row, col = data.edge_index
        # Graphs without an explicit `edge_type` treat every edge as type 0.
        edge_type = data.get('edge_type', torch.zeros_like(row))
        num_nodes = data.num_nodes
        assert num_nodes is not None

        # The virtual node gets index `num_nodes`. First append the edges
        # (i, num_nodes) for every node i, then the reverse edges
        # (num_nodes, i).
        arange = torch.arange(num_nodes, device=row.device)
        full = row.new_full((num_nodes, ), num_nodes)
        row = torch.cat([row, arange, full], dim=0)
        col = torch.cat([col, full, arange], dim=0)
        edge_index = torch.stack([row, col], dim=0)

        # `Tensor.max()` raises on an empty tensor, so graphs without edges
        # fall back to 0 (the default type of untyped edges). This keeps the
        # new edge types identical to those of untyped graphs with edges.
        max_type = int(edge_type.max()) if edge_type.numel() > 0 else 0
        new_type = edge_type.new_full((num_nodes, ), max_type + 1)
        edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0)

        # Iterate over (and classify attributes on) a shallow copy, so that
        # `data` can be updated in place inside the loop.
        old_data = copy.copy(data)
        for key, value in old_data.items():
            if key == 'edge_index' or key == 'edge_type':
                continue
            if not isinstance(value, Tensor):
                continue

            dim = old_data.__cat_dim__(key, value)
            size = list(value.size())

            fill_value = None
            if key == 'edge_weight':
                # One weight per new edge (2 * num_nodes of them), set to 1.
                size[dim] = 2 * num_nodes
                fill_value = 1.
            elif key == 'batch':
                # The virtual node joins the graph of the first node. Use 0
                # when there is no node to copy from (avoids an IndexError).
                size[dim] = 1
                fill_value = int(value[0]) if value.numel() > 0 else 0
            elif old_data.is_edge_attr(key):
                size[dim] = 2 * num_nodes
                fill_value = 0.
            elif old_data.is_node_attr(key):
                size[dim] = 1
                fill_value = 0.

            # Tensors that are neither node- nor edge-level are left as-is.
            if fill_value is not None:
                new_value = value.new_full(size, fill_value)
                data[key] = torch.cat([value, new_value], dim=dim)

        data.edge_index = edge_index
        data.edge_type = edge_type

        # Only an explicitly stored node count needs updating; otherwise it
        # is left to whatever `data` infers from its attributes.
        if 'num_nodes' in data:
            data.num_nodes = num_nodes + 1

        return data