# Source code for torch_geometric.transforms.to_undirected

from typing import Union

import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_undirected

class ToUndirected(BaseTransform):
    r"""Converts a homogeneous or heterogeneous graph to an undirected graph
    such that :math:`(j,i) \in \mathcal{E}` for every edge
    :math:`(i,j) \in \mathcal{E}`.
    In heterogeneous graphs, will add "reverse" connections for *all* existing
    edge types.

    Args:
        reduce (string, optional): The reduce operation to use for merging
            edge features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`,
            :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"add"`)
        merge (bool, optional): If set to :obj:`False`, will create reverse
            edge types for connections pointing to the same source and target
            node type.
            If set to :obj:`True`, reverse edges will be merged into the
            original relation.
            This option only has an effect on
            :class:`~torch_geometric.data.HeteroData` graph data.
            (default: :obj:`True`)
    """
    def __init__(self, reduce: str = "add", merge: bool = True):
        self.reduce = reduce
        self.merge = merge

    def __call__(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        """Make every edge store of ``data`` undirected, modifying ``data``.

        Edge stores without an ``edge_index`` are skipped.

        * For a bipartite heterogeneous relation, or for any heterogeneous
          relation when ``merge=False``, a new ``(dst, 'rev_<rel>', src)``
          edge type is written, holding the flipped connectivity.
        * Otherwise the reverse edges are merged into the same store via
          :func:`to_undirected`, and edge attributes are combined with
          ``self.reduce``.

        Returns:
            The same ``data`` object that was passed in.
        """
        # Iterate over a snapshot: the heterogeneous branch below adds new
        # (reverse) edge types to ``data`` while the loop is running.
        for store in list(data.edge_stores):
            if 'edge_index' not in store:
                continue

            nnz = store.edge_index.size(1)

            if isinstance(data, HeteroData) and (store.is_bipartite()
                                                 or not self.merge):
                src, rel, dst = store._key

                # Just reverse the connectivity and add edge attributes:
                row, col = store.edge_index
                rev_edge_index = torch.stack([col, row], dim=0)

                inv_store = data[dst, f'rev_{rel}', src]
                inv_store.edge_index = rev_edge_index
                for key, value in store.items():
                    if key == 'edge_index':
                        continue
                    # Copy tensors whose leading dimension matches the edge
                    # count. The ``dim() > 0`` guard prevents ``size(0)`` from
                    # raising on 0-dim (scalar) tensor attributes.
                    if (isinstance(value, Tensor) and value.dim() > 0
                            and value.size(0) == nnz):
                        inv_store[key] = value

            else:
                # Gather edge-level attributes so to_undirected can duplicate
                # and reduce them consistently with the edge_index.
                keys, values = [], []
                for key, value in store.items():
                    if key == 'edge_index':
                        continue
                    if store.is_edge_attr(key):
                        keys.append(key)
                        values.append(value)

                store.edge_index, values = to_undirected(
                    store.edge_index, values, reduce=self.reduce)

                for key, value in zip(keys, values):
                    store[key] = value

        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'