Source code for torch_geometric.transforms.linear_transformation
import torch
from torch_geometric.transforms import BaseTransform
class LinearTransformation(BaseTransform):
    r"""Transforms node positions with a square transformation matrix computed
    offline.

    Args:
        matrix (Tensor): tensor with shape :obj:`[D, D]` where :obj:`D`
            corresponds to the dimensionality of node positions.
            Non-tensor inputs (e.g. nested lists) are converted with
            :func:`torch.as_tensor`.
    """
    def __init__(self, matrix):
        # Accept array-likes as well as tensors so that the shape checks
        # below report a meaningful error instead of an `AttributeError`.
        if not isinstance(matrix, torch.Tensor):
            matrix = torch.as_tensor(matrix)

        assert matrix.dim() == 2, (
            'Transformation matrix should be two-dimensional.')
        assert matrix.size(0) == matrix.size(1), (
            'Transformation matrix should be square. Got [{} x {}] '
            'rectangular matrix.'.format(*matrix.size()))

        # Store the matrix as its transpose.
        # We do this to enable post-multiplication in `__call__`.
        self.matrix = matrix.t()

    def __call__(self, data):
        r"""Applies the transformation to :obj:`data.pos` and returns
        :obj:`data`.

        One-dimensional positions are viewed as shape :obj:`[N, 1]` before
        the multiplication. :obj:`data.pos` is overwritten with the result.
        """
        pos = data.pos.view(-1, 1) if data.pos.dim() == 1 else data.pos

        assert pos.size(-1) == self.matrix.size(-2), (
            'Node position matrix and transformation matrix have incompatible '
            'shape.')

        # We post-multiply the points by the transformation matrix instead of
        # pre-multiplying, because `data.pos` has shape `[N, D]`, and we want
        # to preserve this shape.
        # The matrix is moved to the dtype/device of the positions in a
        # single conversion.
        matrix = self.matrix.to(device=pos.device, dtype=pos.dtype)
        data.pos = torch.matmul(pos, matrix)

        return data

    def __repr__(self):
        r"""Returns the class name followed by the stored matrix as a nested
        list. Note that the stored matrix is the transpose of the matrix
        passed to the constructor."""
        return '{}({})'.format(self.__class__.__name__, self.matrix.tolist())