Source code for torch_geometric.transforms.center
from torch_geometric.transforms import BaseTransform
[docs]class Center(BaseTransform):
r"""Centers node positions around the origin."""
def __call__(self, data):
data.pos = data.pos - data.pos.mean(dim=-2, keepdim=True)
return data
def __repr__(self):
return '{}()'.format(self.__class__.__name__)