import torch
def dense_to_sparse(adj):
    r"""Converts a dense adjacency matrix to a sparse adjacency matrix defined
    by edge indices and edge attributes.

    Args:
        adj (Tensor): The dense adjacency matrix of shape ``[N, N]``, or a
            batch of dense adjacency matrices of shape ``[B, N, N]``.

    Returns:
        A tuple ``(edge_index, edge_attr)``. ``edge_index`` is a
        :class:`LongTensor` of shape ``[2, E]`` holding the (row, col)
        indices of the non-zero entries of ``adj``. ``edge_attr`` holds the
        values of ``adj`` at those entries, in the same order. For a batched
        input, the node indices of graph ``b`` are offset by ``b * N``, so
        all graphs share one disjoint node-index space.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)

    Raises:
        ValueError: If ``adj`` is not 2- or 3-dimensional, or if its last
            two dimensions differ in size.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let malformed input through silently.
    if adj.dim() not in (2, 3):
        raise ValueError(
            f"'adj' must be 2- or 3-dimensional (got {adj.dim()} dimensions)")
    if adj.size(-1) != adj.size(-2):
        raise ValueError(
            f"'adj' must be square in its last two dimensions "
            f"(got shape {tuple(adj.shape)})")

    # One index tensor per dimension of ``adj`` (2 for single, 3 for batched).
    index = adj.nonzero(as_tuple=True)
    edge_attr = adj[index]

    if len(index) == 3:
        # Batched input: shift graph b's row/col indices by b * N so that
        # edges of different graphs never collide in the flattened index.
        batch = index[0] * adj.size(-1)
        index = (batch + index[1], batch + index[2])

    return torch.stack(index, dim=0), edge_attr