Source code for torch_geometric.transforms.random_jitter

from itertools import repeat
from typing import Sequence, Union

from import Data
from import functional_transform
from torch_geometric.transforms import BaseTransform

[docs]@functional_transform('random_jitter') class RandomJitter(BaseTransform): r"""Translates node positions by randomly sampled translation values within a given interval (functional name: :obj:`random_jitter`). In contrast to other random transformations, translation is applied separately at each position. Args: translate (sequence or float or int): Maximum translation in each dimension, defining the range :math:`(-\mathrm{translate}, +\mathrm{translate})` to sample from. If :obj:`translate` is a number instead of a sequence, the same range is used for each dimension. """ def __init__( self, translate: Union[float, int, Sequence[Union[float, int]]], ) -> None: self.translate = translate def forward(self, data: Data) -> Data: assert data.pos is not None num_nodes, dim = data.pos.size() translate: Sequence[Union[float, int]] if isinstance(self.translate, (int, float)): translate = list(repeat(self.translate, times=dim)) else: assert len(self.translate) == dim translate = self.translate jitter = data.pos.new_empty(num_nodes, dim) for d in range(dim): jitter[:, d].uniform_(-abs(translate[d]), abs(translate[d])) data.pos = data.pos + jitter return data def __repr__(self) -> str: return f'{self.__class__.__name__}({self.translate})'