Source code for torch_geometric.utils.convert

from collections import defaultdict
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union

import scipy.sparse
import torch
from torch import Tensor
from torch.utils.dlpack import from_dlpack, to_dlpack

import torch_geometric
from torch_geometric.utils.num_nodes import maybe_num_nodes


def to_scipy_sparse_matrix(
    edge_index: Tensor,
    edge_attr: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> scipy.sparse.coo_matrix:
    r"""Builds a square :class:`scipy.sparse.coo_matrix` from a graph given
    by its edge indices and optional edge attributes.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights. The tensor is flattened
            to one dimension and must then hold exactly one value per edge.
            If :obj:`None`, every edge receives the value :obj:`1`.
            (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    Examples:
        >>> edge_index = torch.tensor([
        ...     [0, 1, 1, 2, 2, 3],
        ...     [1, 0, 2, 1, 3, 2],
        ... ])
        >>> to_scipy_sparse_matrix(edge_index)
        <4x4 sparse matrix of type '<class 'numpy.float32'>'
            with 6 stored elements in COOrdinate format>
    """
    # Unpacking the `[2, E]` tensor yields the source and target rows.
    src, dst = edge_index.cpu()
    num_edges = src.size(0)

    if edge_attr is not None:
        values = edge_attr.view(-1).cpu()
        assert values.size(0) == num_edges
    else:
        values = torch.ones(num_edges)

    size = maybe_num_nodes(edge_index, num_nodes)

    return scipy.sparse.coo_matrix(
        (values.numpy(), (src.numpy(), dst.numpy())),
        (size, size),
    )
def from_scipy_sparse_matrix(
        A: scipy.sparse.spmatrix) -> Tuple[Tensor, Tensor]:
    r"""Extracts edge indices and edge weights from a scipy sparse matrix.

    Args:
        A (scipy.sparse): A sparse matrix.

    Returns:
        A tuple :obj:`(edge_index, edge_weight)`, where :obj:`edge_index` is
        a long tensor stacking the row and column indices of the stored
        entries, and :obj:`edge_weight` holds the stored values.

    Examples:
        >>> edge_index = torch.tensor([
        ...     [0, 1, 1, 2, 2, 3],
        ...     [1, 0, 2, 1, 3, 2],
        ... ])
        >>> adj = to_scipy_sparse_matrix(edge_index)
        >>> # `edge_index` and `edge_weight` are both returned
        >>> from_scipy_sparse_matrix(adj)
        (tensor([[0, 1, 1, 2, 2, 3],
                [1, 0, 2, 1, 3, 2]]),
        tensor([1., 1., 1., 1., 1., 1.]))
    """
    coo = A.tocoo()

    # Row indices first, column indices second, both cast to int64.
    indices = [
        torch.from_numpy(axis).to(torch.long) for axis in (coo.row, coo.col)
    ]

    return torch.stack(indices, dim=0), torch.from_numpy(coo.data)
def to_networkx(
    data: Union[
        'torch_geometric.data.Data',
        'torch_geometric.data.HeteroData',
    ],
    node_attrs: Optional[Iterable[str]] = None,
    edge_attrs: Optional[Iterable[str]] = None,
    graph_attrs: Optional[Iterable[str]] = None,
    to_undirected: Optional[Union[bool, str]] = False,
    to_multi: bool = False,
    remove_self_loops: bool = False,
) -> Any:
    r"""Converts a :class:`torch_geometric.data.Data` instance to a
    :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or
    a directed :obj:`networkx.DiGraph` otherwise.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData): A
            homogeneous or heterogeneous data object.
        node_attrs (iterable of str, optional): The node attributes to be
            copied. (default: :obj:`None`)
        edge_attrs (iterable of str, optional): The edge attributes to be
            copied. (default: :obj:`None`)
        graph_attrs (iterable of str, optional): The graph attributes to be
            copied. (default: :obj:`None`)
        to_undirected (bool or str, optional): If set to :obj:`True`, will
            return a :class:`networkx.Graph` instead of a
            :class:`networkx.DiGraph`.
            By default, will include all edges and make them undirected.
            If set to :obj:`"upper"`, the undirected graph will only correspond
            to the upper triangle of the input adjacency matrix.
            If set to :obj:`"lower"`, the undirected graph will only correspond
            to the lower triangle of the input adjacency matrix.
            Only applicable in case the :obj:`data` object holds a homogeneous
            graph. (default: :obj:`False`)
        to_multi (bool, optional): if set to :obj:`True`, will return a
            :class:`networkx.MultiGraph` or a :class:`networkx:MultiDiGraph`
            (depending on the :obj:`to_undirected` option), which will not drop
            duplicated edges that may exist in :obj:`data`.
            (default: :obj:`False`)
        remove_self_loops (bool, optional): If set to :obj:`True`, will not
            include self-loops in the resulting graph. (default: :obj:`False`)

    Raises:
        ValueError: If an undirected graph is requested for a
            :class:`~torch_geometric.data.HeteroData` object.

    Examples:
        >>> edge_index = torch.tensor([
        ...     [0, 1, 1, 2, 2, 3],
        ...     [1, 0, 2, 1, 3, 2],
        ... ])
        >>> data = Data(edge_index=edge_index, num_nodes=4)
        >>> to_networkx(data)
        <networkx.classes.digraph.DiGraph at 0x2713fdb40d0>

    """
    import networkx as nx

    from torch_geometric.data import HeteroData

    to_undirected_upper: bool = to_undirected == 'upper'
    to_undirected_lower: bool = to_undirected == 'lower'

    # Collapse `True`, "upper" and "lower" into a single boolean flag.
    to_undirected = to_undirected is True
    to_undirected |= to_undirected_upper or to_undirected_lower
    assert isinstance(to_undirected, bool)

    is_hetero = isinstance(data, HeteroData)

    if is_hetero and to_undirected:
        raise ValueError("'to_undirected' is not supported in "
                         "'to_networkx' for heterogeneous graphs")

    if to_undirected:
        G = nx.MultiGraph() if to_multi else nx.Graph()
    else:
        G = nx.MultiDiGraph() if to_multi else nx.DiGraph()

    def to_networkx_value(value: Any) -> Any:
        # Tensors are converted to plain Python objects.
        return value.tolist() if isinstance(value, Tensor) else value

    # Materialize the attribute names once. They are consumed once per node
    # and per edge below, so a one-shot iterable (e.g., a generator) would
    # otherwise be exhausted after the first node/edge.
    node_attr_keys: List[str] = list(node_attrs or [])
    edge_attr_keys: List[str] = list(edge_attrs or [])

    for key in graph_attrs or []:
        G.graph[key] = to_networkx_value(data[key])

    node_offsets = data.node_offsets
    for node_store in data.node_stores:
        start = node_offsets[node_store._key]
        assert node_store.num_nodes is not None
        for i in range(node_store.num_nodes):
            node_kwargs: Dict[str, Any] = {}
            if is_hetero:
                node_kwargs['type'] = node_store._key
            for key in node_attr_keys:
                node_kwargs[key] = to_networkx_value(node_store[key][i])

            G.add_node(start + i, **node_kwargs)

    for edge_store in data.edge_stores:
        for i, (v, w) in enumerate(edge_store.edge_index.t().tolist()):
            if to_undirected_upper and v > w:
                continue
            elif to_undirected_lower and v < w:
                continue
            elif remove_self_loops and v == w and not edge_store.is_bipartite(
            ):
                # Equal indices in a bipartite store refer to nodes of
                # different stores, so they are not treated as self-loops.
                continue

            edge_kwargs: Dict[str, Any] = {}
            if is_hetero:
                # Shift per-type indices by the offset of the source and
                # destination node types.
                v = v + node_offsets[edge_store._key[0]]
                w = w + node_offsets[edge_store._key[-1]]
                edge_kwargs['type'] = edge_store._key
            for key in edge_attr_keys:
                edge_kwargs[key] = to_networkx_value(edge_store[key][i])

            G.add_edge(v, w, **edge_kwargs)

    return G
def from_networkx(
    G: Any,
    group_node_attrs: Optional[Union[List[str], Literal['all']]] = None,
    group_edge_attrs: Optional[Union[List[str], Literal['all']]] = None,
) -> 'torch_geometric.data.Data':
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Undirected graphs are first converted to directed ones. Edge attributes
    whose name collides with a node attribute are stored with an
    :obj:`"edge_"` prefix, and colliding graph attributes with a
    :obj:`"graph_"` prefix.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
        group_node_attrs (List[str] or "all", optional): The node attributes to
            be concatenated and added to :obj:`data.x`. (default: :obj:`None`)
        group_edge_attrs (List[str] or "all", optional): The edge attributes to
            be concatenated and added to :obj:`data.edge_attr`.
            (default: :obj:`None`)

    Raises:
        ValueError: If nodes (or edges) do not all share the same set of
            attribute names.

    .. note::

        All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
        be numeric.

    Examples:
        >>> edge_index = torch.tensor([
        ...     [0, 1, 1, 2, 2, 3],
        ...     [1, 0, 2, 1, 3, 2],
        ... ])
        >>> data = Data(edge_index=edge_index, num_nodes=4)
        >>> g = to_networkx(data)
        >>> # A `Data` object is returned
        >>> from_networkx(g)
        Data(edge_index=[2, 6], num_nodes=4)
    """
    import networkx as nx

    from torch_geometric.data import Data

    G = G.to_directed() if not nx.is_directed(G) else G

    # Relabel nodes to consecutive integers in iteration order.
    mapping = dict(zip(G.nodes(), range(G.number_of_nodes())))
    edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long)
    for i, (src, dst) in enumerate(G.edges()):
        edge_index[0, i] = mapping[src]
        edge_index[1, i] = mapping[dst]

    data_dict: Dict[str, Any] = defaultdict(list)
    data_dict['edge_index'] = edge_index

    # The first node/edge defines the attribute names all others must match.
    node_attrs: List[str] = []
    if G.number_of_nodes() > 0:
        node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())

    edge_attrs: List[str] = []
    if G.number_of_edges() > 0:
        edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())

    # Any non-list value (i.e. "all") selects every available attribute.
    if group_node_attrs is not None and not isinstance(group_node_attrs, list):
        group_node_attrs = node_attrs

    if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):
        group_edge_attrs = edge_attrs

    for _, feat_dict in G.nodes(data=True):
        if set(feat_dict.keys()) != set(node_attrs):
            raise ValueError('Not all nodes contain the same attributes')
        for key, value in feat_dict.items():
            data_dict[str(key)].append(value)

    for _, _, feat_dict in G.edges(data=True):
        if set(feat_dict.keys()) != set(edge_attrs):
            raise ValueError('Not all edges contain the same attributes')
        for key, value in feat_dict.items():
            key = f'edge_{key}' if key in node_attrs else key
            data_dict[str(key)].append(value)

    for key, value in G.graph.items():
        if key == 'node_default' or key == 'edge_default':
            continue  # Do not load default attributes.
        key = f'graph_{key}' if key in node_attrs else key
        data_dict[str(key)] = value

    for key, value in data_dict.items():
        # Guard against empty sequences (e.g., an empty list stored as a
        # graph attribute) before peeking at the first element.
        if (isinstance(value, (tuple, list)) and len(value) > 0
                and isinstance(value[0], Tensor)):
            data_dict[key] = torch.stack(value, dim=0)
        else:
            # Best effort: values that cannot be converted to a tensor
            # (e.g., strings) are kept as they are.
            try:
                data_dict[key] = torch.as_tensor(value)
            except Exception:
                pass

    data = Data.from_dict(data_dict)

    # Skip grouping for an empty attribute list: `torch.cat` rejects an empty
    # list of tensors.
    if group_node_attrs:
        xs = []
        for key in group_node_attrs:
            x = data[key]
            x = x.view(-1, 1) if x.dim() <= 1 else x
            xs.append(x)
            del data[key]
        data.x = torch.cat(xs, dim=-1)

    if group_edge_attrs:
        xs = []
        for key in group_edge_attrs:
            key = f'edge_{key}' if key in node_attrs else key
            x = data[key]
            x = x.view(-1, 1) if x.dim() <= 1 else x
            xs.append(x)
            del data[key]
        data.edge_attr = torch.cat(xs, dim=-1)

    # Without node features or positions, record the node count explicitly.
    if data.x is None and data.pos is None:
        data.num_nodes = G.number_of_nodes()

    return data
def to_networkit(
    edge_index: Tensor,
    edge_weight: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
    directed: bool = True,
) -> Any:
    r"""Converts a :obj:`(edge_index, edge_weight)` tuple to a
    :class:`networkit.Graph`.

    Args:
        edge_index (torch.Tensor): The edge indices of the graph.
        edge_weight (torch.Tensor, optional): The edge weights of the graph.
            (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes in the graph.
            (default: :obj:`None`)
        directed (bool, optional): If set to :obj:`False`, the graph will be
            undirected. In that case, only edges with
            :obj:`source <= target` are added. (default: :obj:`True`)
    """
    import networkit as nk

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    g = nk.graph.Graph(
        num_nodes,
        weighted=edge_weight is not None,
        directed=directed,
    )

    if edge_weight is None:
        # Create the default weights on the device of `edge_index`, so that
        # they can be indexed with the mask computed from it below.
        edge_weight = torch.ones(
            edge_index.size(1),
            device=edge_index.device,
        )

    if not directed:
        # Keep a single representative of each undirected edge.
        mask = edge_index[0] <= edge_index[1]
        edge_index = edge_index[:, mask]
        edge_weight = edge_weight[mask]

    for (u, v), w in zip(edge_index.t().tolist(), edge_weight.tolist()):
        g.addEdge(u, v, w)

    return g
def from_networkit(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Converts a :class:`networkit.Graph` to a
    :obj:`(edge_index, edge_weight)` tuple.
    If the :class:`networkit.Graph` is not weighted, the returned
    :obj:`edge_weight` will be :obj:`None`.

    Every edge of an undirected graph is returned in both directions.
    A graph without edges yields an :obj:`edge_index` of shape
    :obj:`[2, 0]`.

    Args:
        g (networkkit.graph.Graph): A :obj:`networkit` graph object.
    """
    is_directed = g.isDirected()
    is_weighted = g.isWeighted()

    edge_indices: List[List[int]] = []
    edge_weights: List[float] = []
    for u, v, w in g.iterEdgesWeights():
        edge_indices.append([u, v])
        edge_weights.append(w)
        if not is_directed:
            edge_indices.append([v, u])
            edge_weights.append(w)

    if edge_indices:
        edge_index = torch.tensor(edge_indices, dtype=torch.long)
        edge_index = edge_index.t().contiguous()
    else:
        # `torch.tensor([])` would give a float tensor of shape `[0]`, which
        # is not a valid edge index.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    edge_weight = torch.tensor(edge_weights) if is_weighted else None

    return edge_index, edge_weight
def to_trimesh(data: 'torch_geometric.data.Data') -> Any:
    r"""Builds a :obj:`trimesh.Trimesh` from the :obj:`pos` and :obj:`face`
    attributes of a :class:`torch_geometric.data.Data` instance.

    Args:
        data (torch_geometric.data.Data): The data object. Both
            :obj:`data.pos` and :obj:`data.face` must be set.

    Example:
        >>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],
        ...                    dtype=torch.float)
        >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()

        >>> data = Data(pos=pos, face=face)
        >>> to_trimesh(data)
        <trimesh.Trimesh(vertices.shape=(4, 3), faces.shape=(2, 3))>
    """
    import trimesh

    assert data.pos is not None
    assert data.face is not None

    vertices = data.pos.detach().cpu().numpy()
    # `face` is transposed before it is handed to trimesh.
    faces = data.face.detach().t().cpu().numpy()

    return trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
def from_trimesh(mesh: Any) -> 'torch_geometric.data.Data':
    r"""Builds a :class:`torch_geometric.data.Data` instance holding the
    vertices and faces of a :obj:`trimesh.Trimesh`.

    Args:
        mesh (trimesh.Trimesh): A :obj:`trimesh` mesh.

    Example:
        >>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],
        ...                    dtype=torch.float)
        >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()

        >>> data = Data(pos=pos, face=face)
        >>> mesh = to_trimesh(data)
        >>> from_trimesh(mesh)
        Data(pos=[4, 3], face=[3, 2])
    """
    from torch_geometric.data import Data

    # Vertices are cast to float32.
    vertices = torch.from_numpy(mesh.vertices).to(torch.float)
    # Faces are transposed and made contiguous.
    faces = torch.from_numpy(mesh.faces).t().contiguous()

    return Data(pos=vertices, face=faces)
def to_cugraph(
    edge_index: Tensor,
    edge_weight: Optional[Tensor] = None,
    relabel_nodes: bool = True,
    directed: bool = True,
) -> Any:
    r"""Creates a :obj:`cugraph` graph object from :obj:`edge_index` and
    optional :obj:`edge_weight`.

    Args:
        edge_index (torch.Tensor): The edge indices of the graph.
        edge_weight (torch.Tensor, optional): The edge weights of the graph.
            Must be one-dimensional if given. (default: :obj:`None`)
        relabel_nodes (bool, optional): If set to :obj:`True`,
            :obj:`cugraph` will remove any isolated nodes, leading to a
            relabeling of nodes. (default: :obj:`True`)
        directed (bool, optional): If set to :obj:`False`, the graph will be
            undirected. (default: :obj:`True`)
    """
    import cudf
    import cugraph

    graph = cugraph.Graph(directed=directed)
    has_weight = edge_weight is not None

    # Columns `0` and `1` of the edge list hold sources and destinations.
    edge_list = cudf.from_dlpack(to_dlpack(edge_index.t()))

    if has_weight:
        assert edge_weight.dim() == 1
        edge_list['2'] = cudf.from_dlpack(to_dlpack(edge_weight))

    weight_column = '2' if has_weight else None
    graph.from_cudf_edgelist(
        edge_list,
        source=0,
        destination=1,
        edge_attr=weight_column,
        renumber=relabel_nodes,
    )

    return graph
def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Reads :obj:`edge_index` and optional :obj:`edge_weight` tensors from
    a :obj:`cugraph` graph object.

    Args:
        g (cugraph.Graph): A :obj:`cugraph` graph object.

    Returns:
        A tuple :obj:`(edge_index, edge_weight)`. :obj:`edge_weight` is
        :obj:`None` if the edge list of :obj:`g` has no :obj:`'2'` column.
    """
    edge_list = g.view_edge_list()

    # Columns `0` and `1` hold sources and destinations, cast to int64.
    endpoints = [
        from_dlpack(edge_list[column].to_dlpack()).long()
        for column in (0, 1)
    ]
    edge_index = torch.stack(endpoints, dim=0)

    if '2' not in edge_list:
        return edge_index, None

    return edge_index, from_dlpack(edge_list['2'].to_dlpack())
def to_dgl(
    data: Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData']
) -> Any:
    r"""Converts a :class:`torch_geometric.data.Data` or
    :class:`torch_geometric.data.HeteroData` instance to a :obj:`dgl` graph
    object.

    The graph structure is read from :obj:`edge_index` if it is set, and
    from :obj:`adj_t` otherwise. All remaining node and edge attributes are
    copied onto the :obj:`dgl` graph.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
            The data object.

    Raises:
        ValueError: If :obj:`data` is neither a
            :class:`~torch_geometric.data.Data` nor a
            :class:`~torch_geometric.data.HeteroData` instance.

    Example:
        >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]])
        >>> x = torch.randn(5, 3)
        >>> edge_attr = torch.randn(6, 2)
        >>> data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        >>> g = to_dgl(data)
        >>> g
        Graph(num_nodes=5, num_edges=6,
            ndata_schemes={'x': Scheme(shape=(3,))}
            edata_schemes={'edge_attr': Scheme(shape=(2, ))})

        >>> data = HeteroData()
        >>> data['paper'].x = torch.randn(5, 3)
        >>> data['author'].x = torch.ones(5, 3)
        >>> edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])
        >>> data['author', 'cites', 'paper'].edge_index = edge_index
        >>> g = to_dgl(data)
        >>> g
        Graph(num_nodes={'author': 5, 'paper': 5},
            num_edges={('author', 'cites', 'paper'): 5},
            metagraph=[('author', 'paper', 'cites')])
    """
    import dgl

    from torch_geometric.data import Data, HeteroData

    # These keys describe the graph structure rather than edge features.
    structure_keys = ('edge_index', 'adj_t')

    if isinstance(data, Data):
        if data.edge_index is None:
            src, dst, _ = data.adj_t.t().coo()
        else:
            src, dst = data.edge_index

        graph = dgl.graph((src, dst))

        for name in data.node_attrs():
            graph.ndata[name] = data[name]
        for name in data.edge_attrs():
            if name not in structure_keys:
                graph.edata[name] = data[name]

        return graph

    if not isinstance(data, HeteroData):
        raise ValueError(f"Invalid data type (got '{type(data)}')")

    edges_per_type = {}
    for edge_type, edge_store in data.edge_items():
        if edge_store.get('edge_index') is None:
            src, dst, _ = edge_store['adj_t'].t().coo()
        else:
            src, dst = edge_store.edge_index
        edges_per_type[edge_type] = (src, dst)

    graph = dgl.heterograph(edges_per_type)

    for node_type, node_store in data.node_items():
        for name, feature in node_store.items():
            graph.nodes[node_type].data[name] = feature

    for edge_type, edge_store in data.edge_items():
        for name, feature in edge_store.items():
            if name not in structure_keys:
                graph.edges[edge_type].data[name] = feature

    return graph
def from_dgl(
    g: Any,
) -> Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData']:
    r"""Converts a :obj:`dgl` graph object to a
    :class:`torch_geometric.data.Data` or
    :class:`torch_geometric.data.HeteroData` instance.

    Args:
        g (dgl.DGLGraph): The :obj:`dgl` graph object.

    Raises:
        ValueError: If :obj:`g` is not a :class:`dgl.DGLGraph`.

    Example:
        >>> g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0]))
        >>> g.ndata['x'] = torch.randn(g.num_nodes(), 3)
        >>> g.edata['edge_attr'] = torch.randn(g.num_edges(), 2)
        >>> data = from_dgl(g)
        >>> data
        Data(x=[6, 3], edge_attr=[4, 2], edge_index=[2, 4])

        >>> g = dgl.heterograph({
        ...     ('author', 'writes', 'paper'): ([0, 1, 1, 2, 3, 3, 4],
        ...                                     [0, 0, 1, 1, 1, 2, 2])})
        >>> g.nodes['author'].data['x'] = torch.randn(5, 3)
        >>> g.nodes['paper'].data['x'] = torch.randn(3, 3)
        >>> data = from_dgl(g)
        >>> data
        HeteroData(
        author={ x=[5, 3] },
        paper={ x=[3, 3] },
        (author, writes, paper)={ edge_index=[2, 7] }
        )
    """
    import dgl

    from torch_geometric.data import Data, HeteroData

    if not isinstance(g, dgl.DGLGraph):
        raise ValueError(f"Invalid data type (got '{type(g)}')")

    data: Union[Data, HeteroData]

    if g.is_homogeneous:
        data = Data()
        data.edge_index = torch.stack(g.edges(), dim=0)

        for attr, value in g.ndata.items():
            data[attr] = value
        for attr, value in g.edata.items():
            data[attr] = value

        return data

    data = HeteroData()

    for node_type in g.ntypes:
        for attr, value in g.nodes[node_type].data.items():
            data[node_type][attr] = value

    for edge_type in g.canonical_etypes:
        row, col = g.edges(form="uv", etype=edge_type)
        data[edge_type].edge_index = torch.stack([row, col], dim=0)
        # Copy the edge feature tensors themselves. `edge_attr_schemes` only
        # describes the features and does not hold their values.
        for attr, value in g.edges[edge_type].data.items():
            data[edge_type][attr] = value

    return data