from torch_sparse import SparseTensor
[docs]class ToSparseTensor(object):
r"""Converts the :obj:`edge_index` attribute of a data object into a
(transposed) :class:`torch_sparse.SparseTensor` type with key
:obj:`adj_.t`.
.. note::
In case of composing multiple transforms, it is best to convert the
:obj:`data` object to a :obj:`SparseTensor` as late as possible, since
there exist some transforms that are only able to operate on
:obj:`data.edge_index` for now.
Args:
remove_faces (bool, optional): If set to :obj:`False`, the
:obj:`edge_index` tensor will not be removed.
(default: :obj:`True`)
fill_cache (bool, optional): If set to :obj:`False`, will not
fill the underlying :obj:`SparseTensor` cache.
(default: :obj:`True`)
"""
def __init__(self, remove_edge_index: bool = True,
fill_cache: bool = True):
self.remove_edge_index = remove_edge_index
self.fill_cache = fill_cache
def __call__(self, data):
assert data.edge_index is not None
(row, col), N, E = data.edge_index, data.num_nodes, data.num_edges
perm = (col * N + row).argsort()
row, col = row[perm], col[perm]
if self.remove_edge_index:
data.edge_index = None
value = None
for key in ['edge_weight', 'edge_attr', 'edge_type']:
if data[key] is not None:
value = data[key][perm]
if self.remove_edge_index:
data[key] = None
break
for key, item in data:
if item.size(0) == E:
data[key] = item[perm]
data.adj_t = SparseTensor(row=col, col=row, value=value,
sparse_sizes=(N, N), is_sorted=True)
if self.fill_cache: # Pre-process some important attributes.
data.adj_t.storage.rowptr()
data.adj_t.storage.csr2csc()
return data
def __repr__(self):
return f'{self.__class__.__name__}()'