from typing import List, Literal, Optional, Tuple, Union, overload
import torch
from torch import Tensor
from torch_geometric.typing import OptTensor, PairTensor
from torch_geometric.utils import scatter
from torch_geometric.utils.map import map_index
from torch_geometric.utils.mask import index_to_mask
from torch_geometric.utils.num_nodes import maybe_num_nodes
def get_num_hops(model: torch.nn.Module) -> int:
    r"""Estimates how many hops of neighborhood information :obj:`model`
    aggregates.

    The estimate is the number of
    :class:`~torch_geometric.nn.conv.MessagePassing` modules found when
    walking :meth:`torch.nn.Module.modules` of :obj:`model`.

    .. note::

        Every message passing layer is counted as exactly one hop.
        The result may therefore be wrong for layers that aggregate over
        several hops at once, *e.g.*,
        :class:`~torch_geometric.nn.conv.ChebConv`.

    Args:
        model (torch.nn.Module): The model to inspect.

    :rtype: int

    Example:
        >>> class GNN(torch.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.conv1 = GCNConv(3, 16)
        ...         self.conv2 = GCNConv(16, 16)
        ...         self.lin = Linear(16, 2)
        ...
        ...     def forward(self, x, edge_index):
        ...         x = self.conv1(x, edge_index).relu()
        ...         x = self.conv2(x, edge_index).relu()
        ...         return self.lin(x)
        >>> get_num_hops(GNN())
        2
    """
    # Imported inside the function, exactly as in the original block.
    from torch_geometric.nn.conv import MessagePassing

    return sum(1 for layer in model.modules()
               if isinstance(layer, MessagePassing))
@overload
def subgraph(
    subset: Union[Tensor, List[int]],
    edge_index: Tensor,
    edge_attr: OptTensor = ...,
    relabel_nodes: bool = ...,
    num_nodes: Optional[int] = ...,
) -> Tuple[Tensor, OptTensor]:
    pass


@overload
def subgraph(
    subset: Union[Tensor, List[int]],
    edge_index: Tensor,
    edge_attr: OptTensor = ...,
    relabel_nodes: bool = ...,
    num_nodes: Optional[int] = ...,
    *,
    return_edge_mask: Literal[False],
) -> Tuple[Tensor, OptTensor]:
    pass


@overload
def subgraph(
    subset: Union[Tensor, List[int]],
    edge_index: Tensor,
    edge_attr: OptTensor = ...,
    relabel_nodes: bool = ...,
    num_nodes: Optional[int] = ...,
    *,
    return_edge_mask: Literal[True],
) -> Tuple[Tensor, OptTensor, Tensor]:
    pass


def subgraph(
    subset: Union[Tensor, List[int]],
    edge_index: Tensor,
    edge_attr: OptTensor = None,
    relabel_nodes: bool = False,
    num_nodes: Optional[int] = None,
    *,
    return_edge_mask: bool = False,
) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]:
    r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)`
    containing the nodes in :obj:`subset`.

    An edge is kept only if both of its endpoints are part of
    :obj:`subset`.

    Args:
        subset (LongTensor, BoolTensor or [int]): The nodes to keep, given
            either as node indices or as a boolean node mask.
            A tensor is moved to the device of :obj:`edge_index`.
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max(edge_index) + 1`. Ignored if :obj:`subset` is a
            boolean mask, in which case the mask length is used.
            (default: :obj:`None`)
        return_edge_mask (bool, optional): If set to :obj:`True`, will return
            the edge mask to filter out additional edge features.
            (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`) or
        (:class:`LongTensor`, :class:`Tensor`, :class:`BoolTensor`) if
        :obj:`return_edge_mask` is :obj:`True`

    Examples:
        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6],
        ...                            [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5]])
        >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
        >>> subset = torch.tensor([3, 4, 5])
        >>> subgraph(subset, edge_index, edge_attr)
        (tensor([[3, 4, 4, 5],
                 [4, 3, 5, 4]]),
         tensor([ 7,  8,  9, 10]))

        >>> subgraph(subset, edge_index, edge_attr, return_edge_mask=True)
        (tensor([[3, 4, 4, 5],
                 [4, 3, 5, 4]]),
         tensor([ 7,  8,  9, 10]),
         tensor([False, False, False, False, False, False,  True,
                 True,  True,  True, False, False]))
    """
    device = edge_index.device

    if isinstance(subset, (list, tuple)):
        subset = torch.tensor(subset, dtype=torch.long, device=device)
    else:
        # Same convention as `k_hop_subgraph`: the node selection follows
        # the device of the edges, so mixed-device inputs do not fail when
        # the mask is indexed below. No-op if already on `device`.
        subset = subset.to(device)

    if subset.dtype != torch.bool:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        node_mask = index_to_mask(subset, size=num_nodes)
    else:
        # A boolean mask defines the number of nodes by its own length.
        num_nodes = subset.size(0)
        node_mask = subset
        subset = node_mask.nonzero().view(-1)

    # Keep an edge only if source *and* target are selected.
    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    edge_index = edge_index[:, edge_mask]
    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

    if relabel_nodes:
        # Map every remaining node id onto its position within `subset`.
        edge_index, _ = map_index(
            edge_index.view(-1),
            subset,
            max_index=num_nodes,
            inclusive=True,
        )
        edge_index = edge_index.view(2, -1)

    if return_edge_mask:
        return edge_index, edge_attr, edge_mask
    else:
        return edge_index, edge_attr
def _select_bipartite_side(
    subset: Union[Tensor, List[int]],
    index: Tensor,
    size: Optional[int],
) -> Tuple[Tensor, Tensor, int]:
    r"""Normalizes the node selection of one side of a bipartite graph into
    :obj:`(node indices, boolean node mask, number of nodes)`.
    """
    if not isinstance(subset, Tensor):
        subset = torch.tensor(subset, dtype=torch.long, device=index.device)
    else:
        # Follow the device of the edges (no-op if already there).
        subset = subset.to(index.device)

    if subset.dtype == torch.bool:
        # A boolean mask defines the number of nodes by its own length.
        return subset.nonzero().view(-1), subset, subset.size(0)

    if size is None:
        # `max()` raises on empty tensors, so guard both lookups. The
        # selected nodes are taken into account as well, since they may not
        # be incident to any edge.
        size = int(index.max()) + 1 if index.numel() > 0 else 0
        if subset.numel() > 0:
            size = max(size, int(subset.max()) + 1)

    return subset, index_to_mask(subset, size=size), size


def bipartite_subgraph(
    subset: Union[PairTensor, Tuple[List[int], List[int]]],
    edge_index: Tensor,
    edge_attr: OptTensor = None,
    relabel_nodes: bool = False,
    size: Optional[Tuple[int, int]] = None,
    return_edge_mask: bool = False,
) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, OptTensor]]:
    r"""Returns the induced subgraph of the bipartite graph
    :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`.

    An edge is kept only if its source node is part of :obj:`subset[0]` and
    its destination node is part of :obj:`subset[1]`.

    Args:
        subset (Tuple[Tensor, Tensor] or tuple([int],[int])): The source and
            destination nodes to keep, each given either as node indices or
            as a boolean node mask.
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        size (tuple, optional): The number of source and destination nodes.
            If :obj:`None`, each is inferred from :obj:`edge_index` and the
            corresponding node indices in :obj:`subset`. Ignored for a side
            that is given as a boolean mask. (default: :obj:`None`)
        return_edge_mask (bool, optional): If set to :obj:`True`, will return
            the edge mask to filter out additional edge features.
            (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`) or
        (:class:`LongTensor`, :class:`Tensor`, :class:`BoolTensor`) if
        :obj:`return_edge_mask` is :obj:`True`

    Examples:
        >>> edge_index = torch.tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6],
        ...                            [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]])
        >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
        >>> subset = (torch.tensor([2, 3, 5]), torch.tensor([2, 3]))
        >>> bipartite_subgraph(subset, edge_index, edge_attr)
        (tensor([[2, 3, 5, 5],
                 [3, 2, 2, 3]]),
         tensor([ 3,  4,  9, 10]))

        >>> bipartite_subgraph(subset, edge_index, edge_attr,
        ...                    return_edge_mask=True)
        (tensor([[2, 3, 5, 5],
                 [3, 2, 2, 3]]),
         tensor([ 3,  4,  9, 10]),
         tensor([False, False,  True,  True, False, False, False, False,
                 True,  True, False]))
    """
    src_subset, dst_subset = subset

    src_subset, src_node_mask, src_size = _select_bipartite_side(
        src_subset, edge_index[0], None if size is None else size[0])
    dst_subset, dst_node_mask, dst_size = _select_bipartite_side(
        dst_subset, edge_index[1], None if size is None else size[1])

    # Keep an edge only if both of its endpoints are selected.
    edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]]
    edge_index = edge_index[:, edge_mask]
    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

    if relabel_nodes:
        # Source and destination nodes live in separate index spaces and are
        # therefore relabeled independently.
        src_index, _ = map_index(edge_index[0], src_subset,
                                 max_index=src_size, inclusive=True)
        dst_index, _ = map_index(edge_index[1], dst_subset,
                                 max_index=dst_size, inclusive=True)
        edge_index = torch.stack([src_index, dst_index], dim=0)

    if return_edge_mask:
        return edge_index, edge_attr, edge_mask
    else:
        return edge_index, edge_attr
def k_hop_subgraph(
    node_idx: Union[int, List[int], Tensor],
    num_hops: int,
    edge_index: Tensor,
    relabel_nodes: bool = False,
    num_nodes: Optional[int] = None,
    flow: str = 'source_to_target',
    directed: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    r"""Computes the induced subgraph of :obj:`edge_index` around all nodes in
    :attr:`node_idx` reachable within :math:`k` hops.

    The :attr:`flow` argument denotes the direction of edges for finding
    :math:`k`-hop neighbors. If set to :obj:`"source_to_target"`, then the
    method will find all neighbors that point to the initial set of seed nodes
    in :attr:`node_idx`.
    This mimics the natural flow of message passing in Graph Neural Networks.

    The method returns (1) the nodes involved in the subgraph, (2) the filtered
    :obj:`edge_index` connectivity, (3) the mapping from node indices in
    :obj:`node_idx` to their new location, and (4) the edge mask indicating
    which edges were preserved.

    Args:
        node_idx (int, list, tuple or :obj:`torch.Tensor`): The central seed
            node(s).
        num_hops (int): The number of hops :math:`k`.
        edge_index (LongTensor): The edge indices.
        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting
            :obj:`edge_index` will be relabeled to hold consecutive indices
            starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        flow (str, optional): The flow direction of :math:`k`-hop aggregation
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        directed (bool, optional): If set to :obj:`True`, will only include
            the directed edges that were traversed, in any of the
            :obj:`num_hops` hops, while expanding from the seed nodes
            :obj:`node_idx`. If set to :obj:`False`, every edge between two
            nodes of the subgraph is included. (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`LongTensor`, :class:`LongTensor`,
        :class:`BoolTensor`)

    Examples:
        >>> edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
        ...                            [2, 2, 4, 4, 6, 6]])

        >>> # Center node 6, 2-hops
        >>> subset, edge_index, mapping, edge_mask = k_hop_subgraph(
        ...     6, 2, edge_index, relabel_nodes=True)
        >>> subset
        tensor([2, 3, 4, 5, 6])
        >>> edge_index
        tensor([[0, 1, 2, 3],
                [2, 2, 4, 4]])
        >>> mapping
        tensor([4])
        >>> edge_mask
        tensor([False, False,  True,  True,  True,  True])
        >>> subset[mapping]
        tensor([6])

        >>> edge_index = torch.tensor([[1, 2, 4, 5],
        ...                            [0, 1, 5, 6]])
        >>> (subset, edge_index,
        ...  mapping, edge_mask) = k_hop_subgraph([0, 6], 2,
        ...                                       edge_index,
        ...                                       relabel_nodes=True)
        >>> subset
        tensor([0, 1, 2, 4, 5, 6])
        >>> edge_index
        tensor([[1, 2, 3, 4],
                [0, 1, 4, 5]])
        >>> mapping
        tensor([0, 5])
        >>> edge_mask
        tensor([True, True, True, True])
        >>> subset[mapping]
        tensor([0, 6])
    """
    # Only infer the number of nodes if the caller did not provide it.
    if num_nodes is None:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

    assert flow in ['source_to_target', 'target_to_source']
    # `row` holds the end of an edge that has already been reached and `col`
    # the end that gets added in the next hop.
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index

    # Scratch buffers that are overwritten in every hop.
    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    # For `directed=True`, the result must contain the edges traversed in
    # *every* hop. The per-hop `edge_mask` alone only describes the last hop
    # (and is uninitialized memory if `num_hops == 0`), so accumulate.
    traversed_mask = torch.zeros_like(edge_mask) if directed else None

    if isinstance(node_idx, int):
        node_idx = torch.tensor([node_idx], device=row.device)
    elif isinstance(node_idx, (list, tuple)):
        node_idx = torch.tensor(node_idx, device=row.device)
    else:
        node_idx = node_idx.to(row.device)

    subsets = [node_idx]

    for _ in range(num_hops):
        # Mark the current frontier and select all edges attached to it.
        node_mask.fill_(False)
        node_mask[subsets[-1]] = True
        torch.index_select(node_mask, 0, row, out=edge_mask)
        subsets.append(col[edge_mask])
        if traversed_mask is not None:
            traversed_mask |= edge_mask

    # The seed nodes come first in the concatenation, so the leading entries
    # of the inverse index give their position within `subset`.
    subset, inv = torch.cat(subsets).unique(return_inverse=True)
    inv = inv[:node_idx.numel()]

    node_mask.fill_(False)
    node_mask[subset] = True

    if traversed_mask is not None:
        edge_mask = traversed_mask
    else:
        edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        mapping = row.new_full((num_nodes, ), -1)
        mapping[subset] = torch.arange(subset.size(0), device=row.device)
        edge_index = mapping[edge_index]

    return subset, edge_index, inv, edge_mask
@overload
def hyper_subgraph(
    subset: Union[Tensor, List[int]],
    edge_index: Tensor,
    edge_attr: OptTensor = ...,
    relabel_nodes: bool = ...,
    num_nodes: Optional[int] = ...,
) -> Tuple[Tensor, OptTensor]:
    pass


@overload
def hyper_subgraph(
    subset: Union[Tensor, List[int]],
    edge_index: Tensor,
    edge_attr: OptTensor = ...,
    relabel_nodes: bool = ...,
    num_nodes: Optional[int] = ...,
    *,
    return_edge_mask: Literal[False],
) -> Tuple[Tensor, OptTensor]:
    pass


@overload
def hyper_subgraph(
    subset: Union[Tensor, List[int]],
    edge_index: Tensor,
    edge_attr: OptTensor = ...,
    relabel_nodes: bool = ...,
    num_nodes: Optional[int] = ...,
    *,
    return_edge_mask: Literal[True],
) -> Tuple[Tensor, OptTensor, Tensor]:
    pass


def hyper_subgraph(
    subset: Union[Tensor, List[int]],
    edge_index: Tensor,
    edge_attr: OptTensor = None,
    relabel_nodes: bool = False,
    num_nodes: Optional[int] = None,
    return_edge_mask: bool = False,
) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]:
    r"""Returns the induced subgraph of the hyper graph of
    :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`.

    A hyperedge is kept only if it connects at least two nodes of
    :obj:`subset`. The indices of the remaining hyperedges are always
    relabeled to be consecutive, starting from zero.

    Args:
        subset (torch.Tensor or [int]): The nodes to keep, given either as
            node indices or as a boolean node mask.
        edge_index (LongTensor): Hyperedge tensor
            with shape :obj:`[2, num_edges*num_nodes_per_edge]`, where
            :obj:`edge_index[1]` denotes the hyperedge index and
            :obj:`edge_index[0]` denotes the node indices that are connected
            by the hyperedge.
        edge_attr (torch.Tensor, optional): Edge weights or multi-dimensional
            edge features of shape :obj:`[num_edges, *]`.
            (default: :obj:`None`)
        relabel_nodes (bool, optional): If set to :obj:`True`, the
            resulting :obj:`edge_index` will be relabeled to hold
            consecutive indices starting from zero. (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max(edge_index[0]) + 1`. (default: :obj:`None`)
        return_edge_mask (bool, optional): If set to :obj:`True`, will return
            the edge mask to filter out additional edge features.
            (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`) or
        (:class:`LongTensor`, :class:`Tensor`, :class:`BoolTensor`) if
        :obj:`return_edge_mask` is :obj:`True`

    Examples:
        >>> edge_index = torch.tensor([[0, 1, 2, 1, 2, 3, 0, 2, 3],
        ...                            [0, 0, 0, 1, 1, 1, 2, 2, 2]])
        >>> edge_attr = torch.tensor([3, 2, 6])
        >>> subset = torch.tensor([0, 3])
        >>> hyper_subgraph(subset, edge_index, edge_attr)
        (tensor([[0, 3],
                 [0, 0]]),
         tensor([6]))

        >>> hyper_subgraph(subset, edge_index, edge_attr,
        ...                return_edge_mask=True)
        (tensor([[0, 3],
                 [0, 0]]),
         tensor([6]),
         tensor([False, False,  True]))
    """
    device = edge_index.device

    if isinstance(subset, (list, tuple)):
        subset = torch.tensor(subset, dtype=torch.long, device=device)
    else:
        # Follow the device of the edges (no-op if already there).
        subset = subset.to(device)

    if subset.dtype != torch.bool:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        node_mask = index_to_mask(subset, size=num_nodes)
    else:
        # A boolean mask defines the number of nodes by its own length.
        num_nodes = subset.size(0)
        node_mask = subset

    hyperedge_index = edge_index[1]

    # Connections (node, hyperedge) whose node is part of the subset.
    # Shape: [num_edges * num_nodes_per_edge]
    hyper_edge_connection_mask = node_mask[edge_index[0]]

    # Number of hyperedges. `edge_attr` may describe trailing hyperedges
    # without any connection, so it can extend the count; otherwise the mask
    # below would be too short to index `edge_attr`.
    num_edges = 0
    if hyperedge_index.numel() > 0:
        num_edges = int(hyperedge_index.max()) + 1
    if edge_attr is not None:
        num_edges = max(num_edges, edge_attr.size(0))

    # Keep hyperedges that contain at least two nodes of the subset.
    nodes_per_edge = torch.bincount(
        hyperedge_index[hyper_edge_connection_mask], minlength=num_edges)
    edge_mask = nodes_per_edge > 1

    # Drop connections whose node is not in the subset or whose hyperedge
    # has been removed.
    hyper_edge_connection_mask = (hyper_edge_connection_mask
                                  & edge_mask[hyperedge_index])
    edge_index = edge_index[:, hyper_edge_connection_mask]
    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None

    # Relabel hyperedges to consecutive indices.
    edge_idx = torch.zeros(edge_mask.size(0), dtype=torch.long, device=device)
    edge_idx[edge_mask] = torch.arange(int(edge_mask.sum()), device=device)
    edge_index = torch.stack([edge_index[0], edge_idx[edge_index[1]]], dim=0)

    if relabel_nodes:
        node_idx = torch.zeros(node_mask.size(0), dtype=torch.long,
                               device=device)
        node_idx[subset] = torch.arange(int(node_mask.sum()), device=device)
        edge_index = torch.stack([node_idx[edge_index[0]], edge_index[1]],
                                 dim=0)

    if return_edge_mask:
        return edge_index, edge_attr, edge_mask
    else:
        return edge_index, edge_attr