# Source code for torch_geometric.transforms.line_graph

import torch
from torch_sparse import coalesce

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_self_loops

[docs]class LineGraph(BaseTransform):
r"""Converts a graph to its corresponding line-graph:

.. math::
L(\mathcal{G}) &= (\mathcal{V}^{\prime}, \mathcal{E}^{\prime})

\mathcal{V}^{\prime} &= \mathcal{E}

\mathcal{E}^{\prime} &= \{ (e_1, e_2) : e_1 \cap e_2 \neq \emptyset \}

Line-graph node indices are equal to indices in the original graph's
coalesced :obj:edge_index.
For undirected graphs, the maximum line-graph node index is
:obj:(data.edge_index.size(1) // 2) - 1.

New node features are given by old edge attributes.
For undirected graphs, edge attributes for reciprocal edges
:obj:(row, col) and :obj:(col, row) get summed together.

Args:
force_directed (bool, optional): If set to :obj:True, the graph will
be always treated as a directed graph. (default: :obj:False)
"""
def __init__(self, force_directed: bool = False):
self.force_directed = force_directed

def __call__(self, data: Data) -> Data:
N = data.num_nodes
edge_index, edge_attr = data.edge_index, data.edge_attr
(row, col), edge_attr = coalesce(edge_index, edge_attr, N, N)

if self.force_directed or data.is_directed():
i = torch.arange(row.size(0), dtype=torch.long, device=row.device)

dim_size=data.num_nodes)
cumsum = torch.cat([count.new_zeros(1), count.cumsum(0)], dim=0)

cols = [
i[cumsum[col[j]]:cumsum[col[j] + 1]]
for j in range(col.size(0))
]
rows = [row.new_full((c.numel(), ), j) for j, c in enumerate(cols)]

row, col = torch.cat(rows, dim=0), torch.cat(cols, dim=0)

data.edge_index = torch.stack([row, col], dim=0)
data.x = data.edge_attr
data.num_nodes = edge_index.size(1)

else:
# Compute node indices.
i = torch.arange(row.size(0), dtype=torch.long, device=row.device)

(row, col), i = coalesce(
torch.stack([
torch.cat([row, col], dim=0),
torch.cat([col, row], dim=0)
], dim=0), torch.cat([i, i], dim=0), N, N)

# Compute new edge indices according to i.
dim_size=data.num_nodes)
joints = torch.split(i, count.tolist())

def generate_grid(x):
row = x.view(-1, 1).repeat(1, x.numel()).view(-1)
col = x.repeat(x.numel())

joints = [generate_grid(joint) for joint in joints]
joints = torch.cat(joints, dim=1)
joints, _ = remove_self_loops(joints)
N = row.size(0) // 2
joints, _ = coalesce(joints, None, N, N)

if edge_attr is not None:
data.x = scatter_add(edge_attr, i, dim=0, dim_size=N)
data.edge_index = joints
data.num_nodes = edge_index.size(1) // 2

data.edge_attr = None
return data