Source code for torch_geometric.utils.subgraph

from typing import List, Optional, Union

import torch
from torch import Tensor

from .num_nodes import maybe_num_nodes


def get_num_hops(model: torch.nn.Module) -> int:
    r"""Counts the message passing layers contained in :obj:`model`.

    Every module yielded by :meth:`torch.nn.Module.modules` that is an
    instance of :class:`~torch_geometric.nn.conv.MessagePassing` accounts for
    one hop.

    Args:
        model (torch.nn.Module): The model to inspect.

    :rtype: int
    """
    # Resolved at call time rather than at module import time
    # (presumably to avoid a circular import -- confirm before moving it).
    from torch_geometric.nn.conv import MessagePassing

    return sum(
        1 for layer in model.modules() if isinstance(layer, MessagePassing))
def subgraph(subset: Union[Tensor, List[int]], edge_index: Tensor,
             edge_attr: Optional[Tensor] = None, relabel_nodes: bool = False,
             num_nodes: Optional[int] = None, return_edge_mask: bool = False):
    r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)`
    containing the nodes in :obj:`subset`.

    An edge is kept only if both of its endpoints are part of :obj:`subset`.

    Args:
        subset (LongTensor, BoolTensor or [int]): The nodes to keep, given
            either as node indices or as a boolean/uint8 node mask.
            A tensor is moved to the device of :obj:`edge_index` if needed.
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. Only consulted when
            :obj:`subset` holds node indices; a node mask defines the number
            of nodes through its own length. (default: :obj:`None`)
        return_edge_mask (bool, optional): If set to :obj:`True`, will return
            the edge mask to filter out additional edge features.
            (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`), followed by the edge
        mask if :obj:`return_edge_mask` is set
    """
    device = edge_index.device

    if isinstance(subset, (list, tuple)):
        subset = torch.tensor(subset, dtype=torch.long, device=device)
    else:
        # A tensor living on another device than `edge_index` (e.g. a CPU
        # mask combined with CUDA edges) cannot be indexed by `edge_index`
        # below. `.to` is a no-op when the devices already match.
        subset = subset.to(device)

    if subset.dtype == torch.bool or subset.dtype == torch.uint8:
        # `subset` already is a per-node mask.
        node_mask = subset

        if relabel_nodes:
            # Old node id -> new consecutive id, following mask order.
            node_idx = torch.zeros(node_mask.size(0), dtype=torch.long,
                                   device=device)
            node_idx[subset] = torch.arange(subset.sum().item(),
                                            device=device)
    else:
        # `subset` holds node indices: scatter them into a boolean mask.
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        node_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        node_mask[subset] = 1

        if relabel_nodes:
            # New ids follow the order in which nodes appear in `subset`.
            node_idx = torch.zeros(num_nodes, dtype=torch.long,
                                   device=device)
            node_idx[subset] = torch.arange(subset.size(0), device=device)

    # Keep an edge only if source and target are both selected.
    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    edge_index = edge_index[:, edge_mask]
    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

    if relabel_nodes:
        edge_index = node_idx[edge_index]

    if return_edge_mask:
        return edge_index, edge_attr, edge_mask
    else:
        return edge_index, edge_attr
def k_hop_subgraph(node_idx, num_hops, edge_index, relabel_nodes=False,
                   num_nodes=None, flow='source_to_target'):
    r"""Computes the :math:`k`-hop subgraph of :obj:`edge_index` around node
    :attr:`node_idx`.
    It returns (1) the nodes involved in the subgraph, (2) the filtered
    :obj:`edge_index` connectivity, (3) the mapping from node indices in
    :obj:`node_idx` to their new location, and (4) the edge mask indicating
    which edges were preserved.

    Args:
        node_idx (int, list, tuple or :obj:`torch.Tensor`): The central
            node(s). Tensors of any shape (including zero-dimensional
            scalars) and nested or empty lists are flattened to a
            one-dimensional index tensor.
        num_hops (int): The number of hops :math:`k`.
        edge_index (LongTensor): The edge indices.
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        flow (string, optional): The flow direction of :math:`k`-hop
            aggregation (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)

    :rtype: (:class:`LongTensor`, :class:`LongTensor`, :class:`LongTensor`,
             :class:`BoolTensor`)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    assert flow in ['source_to_target', 'target_to_source']
    # `row` is the endpoint matched against the current frontier and `col`
    # is the endpoint that gets collected into the next frontier.
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index

    # Scratch buffers that are overwritten in every hop.
    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    if isinstance(node_idx, (int, list, tuple)):
        # An explicit integer dtype keeps an empty list/tuple usable as an
        # index (it would otherwise become a float tensor).
        node_idx = torch.tensor([node_idx], dtype=torch.long,
                                device=row.device).flatten()
    else:
        # Flatten so that zero-dimensional (scalar) or multi-dimensional
        # tensors can be concatenated with the 1-D frontiers below.
        node_idx = node_idx.to(row.device).flatten()

    subsets = [node_idx]

    for _ in range(num_hops):
        # Mark the latest frontier, select every edge whose `row` endpoint
        # lies in it, and collect the `col` endpoints of those edges.
        node_mask.fill_(False)
        node_mask[subsets[-1]] = True
        torch.index_select(node_mask, 0, row, out=edge_mask)
        subsets.append(col[edge_mask])

    subset, inv = torch.cat(subsets).unique(return_inverse=True)
    # The first `node_idx.numel()` entries of `inv` belong to the central
    # nodes, since `node_idx` is the first element of `subsets`.
    inv = inv[:node_idx.numel()]

    # Keep every edge whose two endpoints both lie in the collected subset.
    node_mask.fill_(False)
    node_mask[subset] = True
    edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        # Old node id -> position inside `subset` (-1 for dropped nodes).
        node_idx = row.new_full((num_nodes, ), -1)
        node_idx[subset] = torch.arange(subset.size(0), device=row.device)
        edge_index = node_idx[edge_index]

    return subset, edge_index, inv, edge_mask