# Source code for torch_geometric.transforms.add_gpse

from torch.nn import Module

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform, VirtualNode


@functional_transform('add_gpse')
class AddGPSE(BaseTransform):
    r"""Adds the GPSE encoding from the `"Graph Positional and Structural
    Encoder" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph
    (functional name: :obj:`add_gpse`).
    To be used with a :class:`~torch_geometric.nn.GPSE` model, which
    generates the actual encodings.

    The encodings are written to :obj:`data.pestat_GPSE` (as a CPU tensor)
    on the input :obj:`data` object itself, which is then returned.

    Args:
        model (Module): The pre-trained GPSE model.
        use_vn (bool, optional): Whether to use virtual nodes.
            (default: :obj:`True`)
        rand_type (str, optional): Type of random features to use. Options
            are :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)
    """
    def __init__(
        self,
        model: Module,
        use_vn: bool = True,
        rand_type: str = 'NormalSE',
    ):
        self.model = model
        self.use_vn = use_vn
        self.vn = VirtualNode()
        self.rand_type = rand_type

    def __call__(self, data: Data) -> Data:
        # Local import kept from the original; presumably avoids an import
        # cycle between `transforms` and `nn` — TODO confirm.
        from torch_geometric.nn.models.gpse import gpse_process

        # Encode a clone so the virtual node (if any) never leaks into the
        # caller's graph; only `pestat_GPSE` is written back below.
        data_vn = self.vn(data.clone()) if self.use_vn else data.clone()

        # Forward the configured random-feature distribution. This used to be
        # the literal 'NormalSE', which silently ignored `rand_type`.
        batch_out = gpse_process(self.model, data_vn, self.rand_type,
                                 self.use_vn)

        # Blocking copy on purpose: a non-blocking device-to-host copy may
        # return before the data has arrived, and the result is stored on
        # `data` where callers can read it without any synchronization.
        batch_out = batch_out.to('cpu')

        # With `use_vn`, the last row is presumably the node appended by
        # `VirtualNode`; drop it so rows line up with the original nodes.
        data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out
        return data