# Source code for torch_geometric.transforms.center
from typing import Union
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform
class Center(BaseTransform):
    r"""Centers node positions :obj:`pos` around the origin.

    Each node-level storage of the input that has a :obj:`pos` attribute
    gets its mean position (taken over dimension :obj:`-2`) subtracted
    from it. Storages without :obj:`pos` are skipped. The new positions
    are written back onto the stores of the object passed in, and that
    same object is returned.
    """
    def __call__(self, data: Union[Data, HeteroData]):
        # Lazily filter so the attribute check runs per store, in order.
        positioned = (s for s in data.node_stores if hasattr(s, 'pos'))
        for node_store in positioned:
            coords = node_store.pos
            # dim=-2 is presumably the node axis for pos shaped
            # [num_nodes, num_dims] -- TODO confirm; keepdim lets the
            # centroid broadcast back over every row.
            centroid = coords.mean(dim=-2, keepdim=True)
            node_store.pos = coords - centroid
        return data