# Source code for torch_geometric.utils.convert

from typing import Optional, Union, Tuple, List

from collections import defaultdict

import torch
import scipy.sparse
from torch import Tensor
from torch.utils.dlpack import to_dlpack, from_dlpack

import torch_geometric.data

from .num_nodes import maybe_num_nodes


def to_scipy_sparse_matrix(edge_index, edge_attr=None, num_nodes=None):
    r"""Converts a graph given by edge indices and edge attributes to a scipy
    sparse matrix.

    Args:
        edge_index (LongTensor): The edge indices, as a :obj:`[2, num_edges]`
            tensor of (row, col) pairs.
        edge_attr (Tensor, optional): Edge weights holding exactly one scalar
            per edge (e.g. of shape :obj:`[num_edges]` or
            :obj:`[num_edges, 1]`). If :obj:`None`, every edge gets weight
            :obj:`1`. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    Returns:
        scipy.sparse.coo_matrix: A sparse matrix of shape
        :obj:`(num_nodes, num_nodes)`.

    Raises:
        ValueError: If :attr:`edge_attr` does not hold one value per edge.
    """
    row, col = edge_index.cpu()

    if edge_attr is None:
        edge_attr = torch.ones(row.size(0))
    else:
        # `detach` lets `.numpy()` below work on tensors that require
        # gradients; `reshape` (unlike `view`) also accepts non-contiguous
        # inputs such as slices or transposed tensors.
        edge_attr = edge_attr.detach().reshape(-1).cpu()
        if edge_attr.size(0) != row.size(0):
            raise ValueError(
                f'Expected one edge attribute per edge, but got '
                f'{edge_attr.size(0)} values for {row.size(0)} edges')

    N = maybe_num_nodes(edge_index, num_nodes)
    out = scipy.sparse.coo_matrix(
        (edge_attr.numpy(), (row.numpy(), col.numpy())), (N, N))
    return out
def from_scipy_sparse_matrix(A):
    r"""Converts a scipy sparse matrix to edge indices and edge attributes.

    Args:
        A (scipy.sparse): A sparse matrix. It is converted to COO format
            before its coordinates are read.

    Returns:
        (LongTensor, Tensor): A :obj:`[2, nnz]` tensor of (row, col) indices
        and the corresponding stored values.
    """
    coo = A.tocoo()
    # Row and column coordinates become the two rows of `edge_index`.
    coords = [torch.from_numpy(axis).to(torch.long) for axis in (coo.row, coo.col)]
    edge_index = torch.stack(coords, dim=0)
    return edge_index, torch.from_numpy(coo.data)
def to_networkx(data, node_attrs=None, edge_attrs=None, to_undirected=False,
                remove_self_loops=False):
    r"""Converts a :class:`torch_geometric.data.Data` instance to a
    :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or
    a directed :obj:`networkx.DiGraph` otherwise.

    Args:
        data (torch_geometric.data.Data): The data object.
        node_attrs (iterable of str, optional): The node attributes to be
            copied. (default: :obj:`None`)
        edge_attrs (iterable of str, optional): The edge attributes to be
            copied. (default: :obj:`None`)
        to_undirected (bool, optional): If set to :obj:`True`, will return a
            :obj:`networkx.Graph` instead of a :obj:`networkx.DiGraph`.
            Only edges :obj:`(u, v)` with :obj:`u >= v` are added, *i.e.* the
            undirected graph corresponds to the lower triangle (including the
            diagonal) of the adjacency matrix indexed as
            :obj:`A[source, target]`. (default: :obj:`False`)
        remove_self_loops (bool, optional): If set to :obj:`True`, will not
            include self loops in the resulting graph. (default: :obj:`False`)
    """
    import networkx as nx

    G = nx.Graph() if to_undirected else nx.DiGraph()
    G.add_nodes_from(range(data.num_nodes))

    # Normalize to lists so that any iterable of names (tuple, set, ...) can
    # be concatenated below.
    node_attrs = list(node_attrs or [])
    edge_attrs = list(edge_attrs or [])

    values = {}
    for key, item in data(*(node_attrs + edge_attrs)):
        if torch.is_tensor(item):
            # Only drop a trailing singleton feature dimension ([N, 1] -> [N]).
            # A full `squeeze()` would also collapse the node/edge dimension
            # of single-node or single-edge graphs, which breaks the
            # per-element `values[key][i]` lookups below.
            if item.dim() > 1:
                item = item.squeeze(-1)
            values[key] = item.tolist()
        else:
            values[key] = item

    for i, (u, v) in enumerate(data.edge_index.t().tolist()):
        if to_undirected and v > u:
            continue
        if remove_self_loops and u == v:
            continue
        G.add_edge(u, v)
        for key in edge_attrs:
            G[u][v][key] = values[key][i]

    for key in node_attrs:
        for i, feat_dict in G.nodes(data=True):
            feat_dict[key] = values[key][i]

    return G
def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None,
                  group_edge_attrs: Optional[Union[List[str], all]] = None):
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
        group_node_attrs (List[str] or all, optional): The node attributes to
            be concatenated and added to :obj:`data.x`.
            (default: :obj:`None`)
        group_edge_attrs (List[str] or all, optional): The edge attributes to
            be concatenated and added to :obj:`data.edge_attr`.
            (default: :obj:`None`)

    Raises:
        ValueError: If not all nodes (or not all edges) carry the same set of
            attribute keys.

    .. note::

        All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
        be numeric.
    """
    import networkx as nx

    G = nx.convert_node_labels_to_integers(G)
    G = G.to_directed() if not nx.is_directed(G) else G

    if isinstance(G, (nx.MultiGraph, nx.MultiDiGraph)):
        edges = list(G.edges(keys=False))
    else:
        edges = list(G.edges)

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    data = defaultdict(list)

    # The attribute schema is taken from the first node / first edge; every
    # other node / edge must carry exactly the same keys.
    if G.number_of_nodes() > 0:
        node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())
    else:
        node_attrs = []

    if G.number_of_edges() > 0:
        edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())
    else:
        edge_attrs = []

    for _, feat_dict in G.nodes(data=True):
        if set(feat_dict.keys()) != set(node_attrs):
            raise ValueError('Not all nodes contain the same attributes')
        for key, value in feat_dict.items():
            data[str(key)].append(value)

    for _, _, feat_dict in G.edges(data=True):
        if set(feat_dict.keys()) != set(edge_attrs):
            raise ValueError('Not all edges contain the same attributes')
        for key, value in feat_dict.items():
            # Edge attributes whose name clashes with a node attribute are
            # stored under an `edge_` prefix.
            key = f'edge_{key}' if key in node_attrs else key
            data[str(key)].append(value)

    for key, value in data.items():
        try:
            data[key] = torch.tensor(value)
        except (ValueError, TypeError, RuntimeError):
            # Attributes that cannot be turned into a tensor (strings, `None`,
            # arbitrary objects, mixed types) are kept as plain lists.
            pass

    data['edge_index'] = edge_index.view(2, -1)
    data = torch_geometric.data.Data.from_dict(data)

    if group_node_attrs is all:
        group_node_attrs = list(node_attrs)
    # An empty selection (e.g. `all` on a graph without node attributes) is
    # skipped, since `torch.cat` cannot concatenate an empty list.
    if group_node_attrs:
        xs = []
        for key in group_node_attrs:
            key = str(key)  # attributes were stored under `str(key)` above
            x = data[key]
            x = x.view(-1, 1) if x.dim() <= 1 else x
            xs.append(x)
            del data[key]
        data.x = torch.cat(xs, dim=-1)

    if group_edge_attrs is all:
        group_edge_attrs = list(edge_attrs)
    if group_edge_attrs:
        xs = []
        for key in group_edge_attrs:
            key = str(f'edge_{key}' if key in node_attrs else key)
            x = data[key]
            x = x.view(-1, 1) if x.dim() <= 1 else x
            xs.append(x)
            del data[key]
        data.edge_attr = torch.cat(xs, dim=-1)

    # Without node features or positions, the node count cannot be inferred
    # from the stored tensors, so it is set explicitly.
    if data.x is None and data.pos is None:
        data.num_nodes = G.number_of_nodes()

    return data
def to_trimesh(data):
    r"""Builds a :obj:`trimesh.Trimesh` from a
    :class:`torch_geometric.data.Data` instance.

    The vertex positions come from :obj:`data.pos`, and the faces come from
    the transpose of :obj:`data.face`. Both are detached and moved to the CPU
    first. Trimesh's automatic processing is disabled (:obj:`process=False`).

    Args:
        data (torch_geometric.data.Data): The data object.
    """
    import trimesh

    vertices = data.pos.detach().cpu().numpy()
    faces = data.face.detach().t().cpu().numpy()
    return trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
def from_trimesh(mesh):
    r"""Builds a :class:`torch_geometric.data.Data` instance from a
    :obj:`trimesh.Trimesh`.

    The mesh vertices become :obj:`pos`, cast to :obj:`torch.float`. The mesh
    faces become :obj:`face`, transposed and made contiguous.

    Args:
        mesh (trimesh.Trimesh): A :obj:`trimesh` mesh.
    """
    vertices = torch.from_numpy(mesh.vertices)
    faces = torch.from_numpy(mesh.faces)
    return torch_geometric.data.Data(
        pos=vertices.to(torch.float),
        face=faces.t().contiguous(),
    )
def to_cugraph(edge_index: Tensor, edge_weight: Optional[Tensor] = None,
               relabel_nodes: bool = True):
    r"""Converts a graph given by :obj:`edge_index` and optional
    :obj:`edge_weight` into a :obj:`cugraph` graph object.

    Args:
        edge_index (Tensor): The edge indices. Its transpose is handed to
            :obj:`cudf`, so that column :obj:`0` holds sources and column
            :obj:`1` holds destinations.
        edge_weight (Tensor, optional): One-dimensional edge weights, stored
            as column :obj:`2`. (default: :obj:`None`)
        relabel_nodes (bool, optional): If set to :obj:`True`,
            :obj:`cugraph` will remove any isolated nodes, leading to a
            relabeling of nodes. (default: :obj:`True`)
    """
    import cudf
    import cugraph

    # NOTE(review): the transposed view is exported without `.contiguous()`;
    # this looks intentional (the DLPack buffer is then column-major) —
    # confirm before changing.
    frame = cudf.from_dlpack(to_dlpack(edge_index.t()))

    weight_column = None
    if edge_weight is not None:
        assert edge_weight.dim() == 1
        frame[2] = cudf.from_dlpack(to_dlpack(edge_weight))
        weight_column = 2

    return cugraph.from_cudf_edgelist(
        frame,
        source=0,
        destination=1,
        edge_attr=weight_column,
        renumber=relabel_nodes,
    )
def from_cugraph(G) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Converts a :obj:`cugraph` graph object into :obj:`edge_index` and
    optional :obj:`edge_weight` tensors.

    The :obj:`'src'` and :obj:`'dst'` columns of the graph's edge list are
    read via DLPack and stacked into a :obj:`[2, num_edges]` long tensor.
    :obj:`edge_weight` is the :obj:`'weights'` column if it is present, and
    :obj:`None` otherwise.
    """
    edgelist = G.edgelist.edgelist_df

    def _column(name):
        # Zero-copy hand-off of one edge-list column to PyTorch.
        return from_dlpack(edgelist[name].to_dlpack())

    edge_index = torch.stack(
        [_column('src').long(), _column('dst').long()], dim=0)
    edge_weight = _column('weights') if 'weights' in edgelist else None
    return edge_index, edge_weight