# Source code for torch_geometric.transforms.two_hop

import torch
from torch_sparse import coalesce, spspmm

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import remove_self_loops


@functional_transform('two_hop')
class TwoHop(BaseTransform):
    r"""Adds the two hop edges to the edge indices
    (functional name: :obj:`two_hop`).

    Two-hop edges are the non-zero positions of :math:`\mathbf{A}^2`, where
    :math:`\mathbf{A}` is the unit-weight adjacency matrix described by
    :obj:`data.edge_index`. Self-loops produced by the product are removed,
    the remaining two-hop edges are appended to the original edges, and the
    combined edge set is coalesced.

    If :obj:`data.edge_attr` is set, every appended two-hop edge receives an
    all-zero feature with the same trailing shape, dtype and device as the
    existing edge features, so the dtype of :obj:`data.edge_attr` is
    preserved.
    """
    def __call__(self, data: Data) -> Data:
        edge_index, edge_attr = data.edge_index, data.edge_attr
        N = data.num_nodes

        # Compute A @ A with unit weights. Only the resulting sparsity
        # pattern (the indices) is used; the path counts are discarded.
        value = edge_index.new_ones((edge_index.size(1), ), dtype=torch.float)
        index, _ = spspmm(edge_index, value, edge_index, value, N, N, N)
        index, _ = remove_self_loops(index)

        edge_index = torch.cat([edge_index, index], dim=1)

        if edge_attr is None:
            data.edge_index, _ = coalesce(edge_index, None, N, N)
        else:
            # Two-hop edges get "zero features". Building them from
            # `edge_attr` keeps its dtype/device; concatenating a float
            # tensor instead could silently promote e.g. integer features.
            new_attr = edge_attr.new_zeros(index.size(1),
                                           *edge_attr.size()[1:])
            edge_attr = torch.cat([edge_attr, new_attr], dim=0)
            # NOTE(review): duplicate (row, col) pairs are reduced with
            # torch_sparse.coalesce's default op (presumably "add"), so an
            # edge that is both original and two-hop keeps its original
            # feature (original + 0) — confirm against the torch_sparse
            # version in use.
            data.edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
            data.edge_attr = edge_attr

        return data