torch_geometric.transforms.LinearTransformation

class LinearTransformation(matrix: Tensor)

Bases: BaseTransform

Transforms node positions data.pos with a square transformation matrix computed offline (functional name: linear_transformation).
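The following is a framework-free sketch of the arithmetic. It assumes that each row p of data.pos (shape [N, D]) is mapped to M p, which is equivalent to computing pos @ matrix.T and keeps the [N, D] shape. The function name linear_transform_positions is hypothetical, and plain Python lists stand in for tensors. This is an illustration, not the library's implementation.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear_transform_positions(pos: Sequence[Sequence[float]],
                               matrix: Sequence[Sequence[float]]) -> Matrix:
    """Hypothetical helper: map every position p (a row of `pos`) to M @ p.

    Assumed to be equivalent to pos @ matrix.T, so an [N, D] input
    produces an [N, D] output.
    """
    d = len(matrix)
    # The transformation matrix must be square: [D, D].
    if any(len(row) != d for row in matrix):
        raise ValueError("matrix must be square with shape [D, D]")
    out: Matrix = []
    for p in pos:
        # Each position needs exactly D coordinates to match the matrix.
        if len(p) != d:
            raise ValueError("each position must have D coordinates")
        # (M p)_i = sum_j M[i][j] * p[j]
        out.append([sum(matrix[i][j] * p[j] for j in range(d))
                    for i in range(d)])
    return out


# Rotate two 2D positions by 90 degrees counterclockwise.
rotated = linear_transform_positions([[1, 0], [0, 2]], [[0, -1], [1, 0]])
print(rotated)  # [[0, 1], [-2, 0]]
```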

Parameters:

matrix (Tensor) – Tensor with shape [D, D] where D corresponds to the dimensionality of node positions.
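Because the matrix is computed offline, it is built once before the transform is constructed. The sketch below builds a [3, 3] matrix for 3D positions (D = 3) using plain Python lists. The matrix rotates points 90 degrees about the z-axis and then doubles their x and y coordinates. The helper names rotation_z, scaling and matmul are hypothetical and are not part of torch_geometric.

```python
import math
from typing import List

Matrix = List[List[float]]


def rotation_z(angle_rad: float) -> Matrix:
    """Hypothetical helper: 3 x 3 matrix rotating 3D points about the z-axis."""
    c, s = math.cos(angle_rad), math.sin(angle_rad)
    return [[c, -s, 0.0],
            [s, c, 0.0],
            [0.0, 0.0, 1.0]]


def scaling(factors: List[float]) -> Matrix:
    """Hypothetical helper: D x D diagonal matrix scaling each axis."""
    d = len(factors)
    return [[factors[i] if i == j else 0.0 for j in range(d)]
            for i in range(d)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Compose square matrices: applying a @ b to p equals applying b, then a."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


# Rotate first, then scale x and y by 2; the result is a single [3, 3] matrix.
matrix = matmul(scaling([2.0, 2.0, 1.0]), rotation_z(math.pi / 2))
```

The resulting nested list could be converted with torch.tensor(matrix) and passed as the matrix argument. Composing several linear steps offline into one [D, D] matrix means each sample needs only a single matrix product.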