Source code for torch_geometric.transforms.random_shear

import torch

from torch_geometric.transforms import BaseTransform, LinearTransformation


class RandomShear(BaseTransform):
    r"""Shears node positions by randomly sampled factors :math:`s` within a
    given interval, *e.g.*, resulting in the transformation matrix

    .. math::
        \begin{bmatrix}
            1      & s_{xy} & s_{xz} \\
            s_{yx} & 1      & s_{yz} \\
            s_{zx} & s_{zy} & 1      \\
        \end{bmatrix}

    for three-dimensional positions.

    Args:
        shear (float or int): maximum shearing factor defining the range
            :math:`(-\mathrm{shear}, +\mathrm{shear})` to sample from.
            The sign is ignored; only the magnitude is stored.
    """

    def __init__(self, shear):
        # Store the magnitude so the sampling interval is always
        # well-formed, even if a negative factor is passed.
        self.shear = abs(shear)

    def __call__(self, data):
        """Sample a random shear matrix and apply it to ``data``.

        Args:
            data: object exposing a ``pos`` tensor whose last dimension is
                the spatial dimensionality.

        Returns:
            The result of calling ``LinearTransformation(matrix)`` on
            ``data``.
        """
        dim = data.pos.size(-1)

        # ``new_empty`` inherits dtype and device from ``data.pos``; every
        # entry is then drawn uniformly from the shear interval.
        matrix = data.pos.new_empty(dim, dim).uniform_(-self.shear, self.shear)

        # A shear matrix has a unit diagonal; only the off-diagonal
        # entries carry the random shear factors.
        matrix.fill_diagonal_(1)

        return LinearTransformation(matrix)(data)

    def __repr__(self):
        """Return ``ClassName(shear)`` for debugging and logging."""
        return '{}({})'.format(self.__class__.__name__, self.shear)