Source code for torch_geometric.transforms.constant

from typing import List, Optional, Union

import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('constant')
class Constant(BaseTransform):
    r"""Appends a constant value to each node feature :obj:`x`
    (functional name: :obj:`constant`).

    For every selected node store, a single column of shape
    :obj:`[num_nodes, 1]` filled with :obj:`value` is built. If the store
    already has :obj:`x` and :obj:`cat` is :obj:`True`, the column is
    concatenated along the last dimension; otherwise the column becomes the
    new :obj:`x`.

    Args:
        value (float, optional): The constant used to fill the new column.
            (default: :obj:`1.0`)
        cat (bool, optional): If set to :obj:`False`, existing node features
            will be replaced instead of extended. (default: :obj:`True`)
        node_types (str or List[str], optional): The node type(s) to append
            constant values for if used on heterogeneous graphs. A single
            string is wrapped into a one-element list. If set to
            :obj:`None`, every node store is processed.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        value: float = 1.0,
        cat: bool = True,
        node_types: Optional[Union[str, List[str]]] = None,
    ):
        self.value = value
        self.cat = cat
        # Normalise a lone type name to a list so `forward` can always use a
        # plain membership test.
        self.node_types = ([node_types]
                           if isinstance(node_types, str) else node_types)

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        """Add the constant column to each selected node store of ``data``.

        The stores are modified in place and the same ``data`` object is
        returned.
        """
        for store in data.node_stores:
            selected = (self.node_types is None
                        or store._key in self.node_types)
            if not selected:
                continue

            count = store.num_nodes
            assert count is not None
            column = torch.full((count, 1), self.value, dtype=torch.float)

            if not (hasattr(store, 'x') and self.cat):
                # No features to extend, or replacement was requested.
                store.x = column
                continue

            feats = store.x
            if feats.dim() == 1:
                # Treat 1-D features as a single feature column.
                feats = feats.view(-1, 1)

            # Align device and dtype with the existing features before
            # concatenating.
            column = column.to(feats.device, feats.dtype)
            store.x = torch.cat([feats, column], dim=-1)

        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(value={self.value})'