Source code for torch_geometric.transforms.normalize_scale

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform, Center


@functional_transform('normalize_scale')
class NormalizeScale(BaseTransform):
    r"""Centers and normalizes node positions to the interval :math:`(-1, 1)`
    (functional name: :obj:`normalize_scale`).

    The positions are first passed through the :class:`Center` transform and
    then multiplied by one global factor derived from the largest absolute
    coordinate, so relative distances between nodes are preserved.

    Degenerate inputs (no positions at all, or every centered coordinate
    equal to zero) are returned centered but unscaled, because no finite
    scale factor exists for them.
    """
    def __init__(self) -> None:
        self.center = Center()

    def forward(self, data: Data) -> Data:
        """Center ``data.pos`` and rescale it into :math:`(-1, 1)`.

        Args:
            data: The data object whose ``pos`` attribute is normalized.

        Returns:
            The object returned by the :class:`Center` transform, with its
            ``pos`` attribute rescaled.
        """
        data = self.center(data)

        assert data.pos is not None

        # An empty tensor has no maximum; the full reduction below would
        # raise, so there is nothing to scale.
        if data.pos.numel() == 0:
            return data

        max_abs = data.pos.abs().max()

        # All coordinates are zero after centering (e.g. a single node or
        # coincident nodes). Dividing by zero would yield an infinite scale
        # and turn every position into NaN (0 * inf), so leave them as-is.
        # NaN inputs compare unequal to zero and still propagate as before.
        if max_abs == 0:
            return data

        # The 0.999999 factor keeps the extreme coordinate strictly inside
        # the open interval instead of landing exactly on +/-1.
        scale = (1.0 / max_abs) * 0.999999
        data.pos = data.pos * scale

        return data