Source code for torch_geometric.transforms.random_scale

import random


class RandomScale(object):
    r"""Multiplies node positions by one randomly drawn scalar :math:`s`.

    Every coordinate is scaled by the same factor, which for
    three-dimensional positions corresponds to the transformation matrix

    .. math::
        \begin{bmatrix}
            s & 0 & 0 \\
            0 & s & 0 \\
            0 & 0 & s \\
        \end{bmatrix}

    Args:
        scales (tuple or list): Two-element interval :obj:`(a, b)` from
            which the factor is drawn on every call via
            :func:`random.uniform`, *i.e.* :math:`a \leq s \leq b`.
    """

    def __init__(self, scales):
        """Store the sampling interval after checking its type and length."""
        # Two separate asserts fail under exactly the same conditions as a
        # single ``and``-joined one: the type check short-circuits first.
        assert isinstance(scales, (tuple, list))
        assert len(scales) == 2
        self.scales = scales

    def __call__(self, data):
        """Scale ``data.pos`` by a freshly sampled factor.

        ``data.pos`` is rebound to the product (not multiplied in place) and
        the same ``data`` object is returned.
        """
        factor = random.uniform(*self.scales)
        scaled_pos = data.pos * factor
        data.pos = scaled_pos
        return data

    def __repr__(self):
        """Return ``ClassName(scales)``."""
        cls_name = self.__class__.__name__
        return '{}({})'.format(cls_name, self.scales)