Source code for torch_geometric.transforms.normalize_features

from typing import List, Union

from import Data, HeteroData
from import functional_transform
from torch_geometric.transforms import BaseTransform

@functional_transform('normalize_features')
class NormalizeFeatures(BaseTransform):
    r"""Row-normalizes the attributes given in :obj:`attrs` to sum-up to one
    (functional name: :obj:`normalize_features`).

    Each selected attribute is first shifted by its global minimum, so that
    its smallest entry becomes zero, and is then divided by its sum along the
    last dimension. The sums are clamped to a minimum of one before the
    division, so rows whose shifted sum is below one are left unscaled
    instead of being divided by a tiny (or zero) value. Attributes without
    any elements are left untouched.

    Args:
        attrs (List[str]): The names of attributes to normalize.
            (default: :obj:`["x"]`)
    """
    def __init__(self, attrs: List[str] = ["x"]):
        # Keep a private copy: the default list object is shared between all
        # calls, so storing it by reference would let a mutation of one
        # instance's `attrs` leak into every other default-built instance.
        self.attrs = list(attrs)

    def __call__(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        """Normalize the configured attributes in every store of :obj:`data`.

        Args:
            data: The graph object whose stores are processed.

        Returns:
            The same :obj:`data` object, with the normalized attributes
            re-assigned on its stores.
        """
        for store in data.stores:
            # NOTE(review): `store.items(*names)` presumably yields only the
            # entries named in `attrs` that exist on this store -- confirm
            # against the storage implementation.
            for key, value in store.items(*self.attrs):
                if value.numel() == 0:
                    # A full `min()` reduction is undefined on a tensor
                    # without elements and raises; nothing to normalize.
                    continue
                # Out-of-place subtraction: the tensor originally held by the
                # store is not modified, a shifted copy replaces it below.
                value = value - value.min()
                # Clamp the row sums to >= 1 to avoid division by zero.
                value.div_(value.sum(dim=-1, keepdim=True).clamp_(min=1.))
                store[key] = value
        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'