# Source code for torch_geometric.transforms.center

from typing import Union

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('center')
class Center(BaseTransform):
    r"""Centers node positions :obj:`pos` around the origin
    (functional name: :obj:`center`).

    For every node store of the input that has a :obj:`pos` attribute, the
    mean of :obj:`pos` along its second-to-last dimension is subtracted
    from :obj:`pos`. Node stores without a :obj:`pos` attribute are left
    untouched.
    """
    def __call__(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        r"""Centers the :obj:`pos` attribute of each node store in
        :obj:`data`.

        Args:
            data (Data or HeteroData): The graph object whose node stores
                are inspected for a :obj:`pos` attribute.

        Returns:
            The same :obj:`data` object that was passed in; the centered
            positions are assigned back onto its node stores.
        """
        for store in data.node_stores:
            # Skip stores that carry no positional information.
            if not hasattr(store, 'pos'):
                continue

            pos = store.pos
            # NOTE(review): presumably `pos` has shape
            # [num_nodes, num_dimensions], so `dim=-2` averages over nodes
            # -- confirm against callers. `keepdim=True` keeps the reduced
            # axis so the subtraction broadcasts against `pos`.
            centroid = pos.mean(dim=-2, keepdim=True)
            # Assign a new tensor instead of subtracting in place, so the
            # original `pos` tensor object is not modified.
            store.pos = pos - centroid

        return data