Source code for torch_geometric.transforms.add_self_loops

from typing import Union

from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import add_self_loops

@functional_transform('add_self_loops')
class AddSelfLoops(BaseTransform):
    r"""Adds self-loops to the given homogeneous or heterogeneous graph
    (functional name: :obj:`add_self_loops`).

    Every edge store that is not bipartite and that holds an
    :obj:`edge_index` is updated in place. Bipartite edge stores and edge
    stores without an :obj:`edge_index` are left untouched.

    Args:
        attr (str, optional): The name of the attribute of edge weights
            or multi-dimensional edge features to pass to
            :meth:`torch_geometric.utils.add_self_loops`.
            (default: :obj:`"edge_weight"`)
        fill_value (float or Tensor or str, optional): The way to generate
            edge features of self-loops (in case :obj:`attr != None`).
            If given as :obj:`float` or :class:`torch.Tensor`, edge features
            of self-loops will be directly given by :obj:`fill_value`.
            If given as :obj:`str`, edge features of self-loops are computed
            by aggregating all features of edges that point to the specific
            node, according to a reduce operation. (:obj:`"add"`,
            :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`).
            (default: :obj:`1.`)
    """
    def __init__(
        self,
        attr: str = 'edge_weight',
        fill_value: Union[float, Tensor, str] = 1.0,
    ) -> None:
        self.attr = attr
        self.fill_value = fill_value

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        r"""Adds self-loops to every eligible edge store of :obj:`data`.

        Args:
            data (Data or HeteroData): The graph to modify. Its edge stores
                are mutated in place.

        Returns:
            The same :obj:`data` object that was passed in.
        """
        for store in data.edge_stores:
            # A self-loop (i, i) needs source and destination nodes to come
            # from the same node set, so bipartite relations are skipped.
            # Stores without connectivity have nothing to extend.
            if store.is_bipartite() or 'edge_index' not in store:
                continue

            # If `self.attr` is absent from the store, `edge_attr=None` is
            # passed through. NOTE(review): the value `add_self_loops` returns
            # for the attribute in that case is decided by
            # `torch_geometric.utils.add_self_loops` — confirm there.
            # `store.size(0)` is presumably the source-node count, which
            # equals the destination-node count for non-bipartite stores.
            store.edge_index, store[self.attr] = add_self_loops(
                store.edge_index,
                edge_attr=store.get(self.attr, None),
                fill_value=self.fill_value,
                num_nodes=store.size(0),
            )

        return data