Source code for torch_geometric.transforms.to_undirected

from torch_geometric.utils import to_undirected


class ToUndirected(object):
    r"""Converts the graph to an undirected graph, so that
    :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in \mathcal{E}`.

    Args:
        reduce (string, optional): The reduce operation to use for merging
            edge features. (default: :obj:`"add"`)
    """
    def __init__(self, reduce: str = "add"):
        self.reduce = reduce

    def __call__(self, data):
        r"""Symmetrize the connectivity stored in :obj:`data`.

        * If :obj:`data` has an ``edge_index``, it is replaced by the output
          of :func:`to_undirected`. When ``edge_attr`` is also present, it is
          passed along and replaced as well, with features merged using
          :obj:`self.reduce`.
        * If :obj:`data` has an ``adj_t``, it is replaced by
          ``adj_t.to_symmetric(self.reduce)`` (presumably a sparse adjacency
          object such as a ``torch_sparse.SparseTensor`` — confirm with
          callers).

        The attributes are reassigned on :obj:`data` itself, and the same
        object is returned.
        """
        if 'edge_index' in data:
            if 'edge_attr' in data:
                # Edge features must be rewritten together with the edge list;
                # features of merged edges are combined using `self.reduce`.
                data.edge_index, data.edge_attr = to_undirected(
                    data.edge_index, data.edge_attr,
                    num_nodes=data.num_nodes, reduce=self.reduce)
            else:
                data.edge_index = to_undirected(data.edge_index,
                                                num_nodes=data.num_nodes)
        if 'adj_t' in data:
            data.adj_t = data.adj_t.to_symmetric(self.reduce)
        return data

    def __repr__(self):
        # Include the configured reduce operation so differently configured
        # instances are distinguishable when printed (e.g. inside pipelines).
        return f'{self.__class__.__name__}(reduce={self.reduce})'