from typing import Optional, Union, Tuple, List
from collections import defaultdict
import torch
import scipy.sparse
from torch import Tensor
from torch.utils.dlpack import to_dlpack, from_dlpack
import torch_geometric.data
from .num_nodes import maybe_num_nodes
def to_scipy_sparse_matrix(edge_index, edge_attr=None, num_nodes=None):
    r"""Builds a square :obj:`scipy.sparse.coo_matrix` from edge indices and
    (optional) per-edge values.

    Args:
        edge_index (LongTensor): The edge indices; its two rows are used as
            the row and column coordinates of the matrix entries.
        edge_attr (Tensor, optional): One value per edge. The tensor is
            flattened with :obj:`view(-1)` and must then contain exactly as
            many entries as there are edges. If :obj:`None`, every edge gets
            the value :obj:`1`. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. It is resolved through
            :func:`maybe_num_nodes`. (default: :obj:`None`)

    Returns:
        A :obj:`scipy.sparse.coo_matrix` of shape :obj:`(N, N)`.
    """
    # SciPy works on NumPy arrays, so everything is moved to the CPU first.
    src, dst = edge_index.cpu()
    num_edges = src.size(0)

    if edge_attr is not None:
        weights = edge_attr.view(-1).cpu()
        assert weights.size(0) == num_edges
    else:
        weights = torch.ones(num_edges)

    size = maybe_num_nodes(edge_index, num_nodes)
    coords = (src.numpy(), dst.numpy())
    return scipy.sparse.coo_matrix((weights.numpy(), coords), (size, size))
def from_scipy_sparse_matrix(A):
    r"""Extracts edge indices and edge values from a scipy sparse matrix.

    Args:
        A (scipy.sparse): A sparse matrix. It is converted to COO format
            before its coordinates and data are read.

    Returns:
        A tuple :obj:`(edge_index, edge_weight)`: a :obj:`torch.long` tensor
        stacking the row and column coordinates along dimension 0, and a
        tensor wrapping the stored values of the matrix.
    """
    coo = A.tocoo()
    # Coordinates may be stored as 32-bit integers; indices are always long.
    coords = [
        torch.from_numpy(axis).to(torch.long) for axis in (coo.row, coo.col)
    ]
    return torch.stack(coords, dim=0), torch.from_numpy(coo.data)
def to_networkx(data, node_attrs=None, edge_attrs=None, to_undirected=False,
                remove_self_loops=False):
    r"""Converts a :class:`torch_geometric.data.Data` instance to a
    :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or
    a directed :obj:`networkx.DiGraph` otherwise.

    Args:
        data (torch_geometric.data.Data): The data object.
        node_attrs (iterable of str, optional): The node attributes to be
            copied. (default: :obj:`None`)
        edge_attrs (iterable of str, optional): The edge attributes to be
            copied. (default: :obj:`None`)
        to_undirected (bool, optional): If set to :obj:`True`, will return a
            a :obj:`networkx.Graph` instead of a :obj:`networkx.DiGraph`. The
            undirected graph will correspond to the upper triangle of the
            corresponding adjacency matrix, *i.e.* only edges :obj:`(u, v)`
            with :obj:`u <= v` are added. (default: :obj:`False`)
        remove_self_loops (bool, optional): If set to :obj:`True`, will not
            include self loops in the resulting graph. (default: :obj:`False`)

    Returns:
        The resulting :obj:`networkx.Graph` or :obj:`networkx.DiGraph` with
        one node per index in :obj:`range(data.num_nodes)`.
    """
    import networkx as nx

    G = nx.Graph() if to_undirected else nx.DiGraph()
    G.add_nodes_from(range(data.num_nodes))

    # Materialize as lists so that any iterable (tuple, set, generator) can be
    # given and the two collections can be concatenated below.
    node_attrs = list(node_attrs or [])
    edge_attrs = list(edge_attrs or [])

    values = {}
    for key, item in data(*(node_attrs + edge_attrs)):
        if torch.is_tensor(item):
            value = item.squeeze().tolist()
            # A tensor that squeezes down to zero dimensions (e.g. a single
            # node or edge) yields a bare Python scalar; wrap it so that it
            # can still be indexed per node/edge below.
            if not isinstance(value, list):
                value = [value]
        else:
            value = item
            if isinstance(value, (list, tuple)) and len(value) == 1:
                value = item[0]
        values[key] = value

    for i, (u, v) in enumerate(data.edge_index.t().tolist()):
        # Keep the upper triangle (row <= column) of the adjacency matrix.
        if to_undirected and u > v:
            continue
        if remove_self_loops and u == v:
            continue

        G.add_edge(u, v)
        for key in edge_attrs:
            # `i` is the edge's position in `data.edge_index`, so skipped
            # edges do not shift the attribute lookup.
            G[u][v][key] = values[key][i]

    for key in node_attrs:
        for i, feat_dict in G.nodes(data=True):
            feat_dict.update({key: values[key][i]})

    return G
def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None,
                  group_edge_attrs: Optional[Union[List[str], all]] = None):
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
        group_node_attrs (List[str] or all, optional): The node attributes to
            be concatenated and added to :obj:`data.x`. (default: :obj:`None`)
        group_edge_attrs (List[str] or all, optional): The edge attributes to
            be concatenated and added to :obj:`data.edge_attr`.
            (default: :obj:`None`)

    Returns:
        The resulting :class:`torch_geometric.data.Data` object. Every node
        and edge attribute is stored under its (stringified) name; an edge
        attribute whose name is also used by a node attribute is stored with
        an :obj:`"edge_"` prefix so that the two do not get mixed up.

    Raises:
        ValueError: If not all nodes (or not all edges) carry the same set
            of attributes.

    .. note::

        All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
        be numeric.
    """
    import networkx as nx

    G = nx.convert_node_labels_to_integers(G)
    G = G.to_directed() if not nx.is_directed(G) else G

    # Multigraphs iterate edges as (u, v, key) triples by default; drop the
    # key so that every edge contributes exactly one (u, v) pair.
    if G.is_multigraph():
        edges = list(G.edges(keys=False))
    else:
        edges = list(G.edges)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    data = defaultdict(list)

    # The attribute names of the first node/edge serve as the reference set
    # every other node/edge is validated against.
    if G.number_of_nodes() > 0:
        node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())
    else:
        node_attrs = []

    if G.number_of_edges() > 0:
        edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())
    else:
        edge_attrs = []

    node_keys = {str(key) for key in node_attrs}

    def edge_key(key):
        # Storage name of an edge attribute: prefixed if it would otherwise
        # collide with a node attribute of the same name.
        key = str(key)
        return 'edge_' + key if key in node_keys else key

    for _, feat_dict in G.nodes(data=True):
        if set(feat_dict.keys()) != set(node_attrs):
            raise ValueError('Not all nodes contain the same attributes')
        for key, value in feat_dict.items():
            data[str(key)].append(value)

    for _, _, feat_dict in G.edges(data=True):
        if set(feat_dict.keys()) != set(edge_attrs):
            raise ValueError('Not all edges contain the same attributes')
        for key, value in feat_dict.items():
            data[edge_key(key)].append(value)

    for key, value in data.items():
        try:
            data[key] = torch.tensor(value)
        except (ValueError, TypeError, RuntimeError):
            # Best effort: attributes that cannot be represented as a tensor
            # (e.g. strings, ragged or arbitrary objects) are kept as lists.
            pass

    # `view(2, -1)` also gives an empty graph the expected `[2, 0]` shape.
    data['edge_index'] = edge_index.view(2, -1)
    data = torch_geometric.data.Data.from_dict(data)
    if data.x is None:
        data.num_nodes = G.number_of_nodes()

    if group_node_attrs is all:
        group_node_attrs = list(node_attrs)
    if group_node_attrs is not None:
        xs = [data[str(key)] for key in group_node_attrs]
        xs = [x.view(-1, 1) if x.dim() <= 1 else x for x in xs]
        data.x = torch.cat(xs, dim=-1)

    if group_edge_attrs is all:
        group_edge_attrs = list(edge_attrs)
    if group_edge_attrs is not None:
        es = [data[edge_key(key)] for key in group_edge_attrs]
        es = [e.view(-1, 1) if e.dim() <= 1 else e for e in es]
        data.edge_attr = torch.cat(es, dim=-1)

    return data
def to_trimesh(data):
    r"""Builds a :obj:`trimesh.Trimesh` from a
    :class:`torch_geometric.data.Data` instance.

    Args:
        data (torch_geometric.data.Data): The data object. Its :obj:`pos`
            attribute supplies the vertices and its (transposed) :obj:`face`
            attribute supplies the faces.

    Returns:
        A :obj:`trimesh.Trimesh` created with :obj:`process=False`.
    """
    import trimesh

    # Detach and move to the CPU so the tensors can be exposed as NumPy arrays.
    vertices = data.pos.detach().cpu().numpy()
    faces = data.face.detach().t().cpu().numpy()
    return trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
def from_trimesh(mesh):
    r"""Builds a :class:`torch_geometric.data.Data` instance from a
    :obj:`trimesh.Trimesh`.

    Args:
        mesh (trimesh.Trimesh): A :obj:`trimesh` mesh.

    Returns:
        A :class:`torch_geometric.data.Data` object whose :obj:`pos` holds
        the mesh vertices as :obj:`torch.float` and whose :obj:`face` holds
        the transposed, contiguous mesh faces.
    """
    vertices = torch.from_numpy(mesh.vertices)
    faces = torch.from_numpy(mesh.faces)
    return torch_geometric.data.Data(
        pos=vertices.to(torch.float),
        face=faces.t().contiguous(),
    )
def to_cugraph(edge_index: Tensor, edge_weight: Optional[Tensor] = None,
               relabel_nodes: bool = True):
    r"""Converts a graph given by :obj:`edge_index` and optional
    :obj:`edge_weight` into a :obj:`cugraph` graph object.

    Args:
        edge_index (Tensor): The edge indices. The transposed tensor is
            handed to :obj:`cudf` via DLPack, giving the source and
            destination columns :obj:`0` and :obj:`1`.
        edge_weight (Tensor, optional): One-dimensional edge weights, stored
            as column :obj:`2` of the edge list. (default: :obj:`None`)
        relabel_nodes (bool, optional): If set to :obj:`True`,
            :obj:`cugraph` will remove any isolated nodes, leading to a
            relabeling of nodes. (default: :obj:`True`)
    """
    import cudf
    import cugraph

    weighted = edge_weight is not None

    edges = cudf.from_dlpack(to_dlpack(edge_index.t()))
    if weighted:
        assert edge_weight.dim() == 1
        edges[2] = cudf.from_dlpack(to_dlpack(edge_weight))

    weight_column = 2 if weighted else None
    return cugraph.from_cudf_edgelist(
        edges,
        source=0,
        destination=1,
        edge_attr=weight_column,
        renumber=relabel_nodes,
    )
def from_cugraph(G) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Converts a :obj:`cugraph` graph object into :obj:`edge_index` and
    optional :obj:`edge_weight` tensors.

    Args:
        G: A :obj:`cugraph` graph object exposing its edge list as
            :obj:`G.edgelist.edgelist_df`.

    Returns:
        A tuple :obj:`(edge_index, edge_weight)`. :obj:`edge_index` stacks
        the :obj:`"src"` and :obj:`"dst"` columns as long tensors along
        dimension 0; :obj:`edge_weight` is taken from the :obj:`"weights"`
        column, or is :obj:`None` if that column does not exist.
    """
    frame = G.edgelist.edgelist_df

    # Columns are exchanged through DLPack and cast to long for indexing.
    endpoints = [
        from_dlpack(frame[column].to_dlpack()).long()
        for column in ('src', 'dst')
    ]
    edge_index = torch.stack(endpoints, dim=0)

    if 'weights' in frame:
        edge_weight = from_dlpack(frame['weights'].to_dlpack())
    else:
        edge_weight = None

    return edge_index, edge_weight