Source code for torch_geometric.datasets.graph_generator.tree_graph

from typing import List, Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.datasets.graph_generator import GraphGenerator
from torch_geometric.utils import to_undirected


def tree(
    depth: int,
    branch: int = 2,
    undirected: bool = False,
    device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
    """Generates a tree graph with the given depth and branch size, along with
    node-level depth indicators.

    Node :obj:`0` is the root. Nodes are numbered in depth-first (pre-order)
    fashion: when a node is expanded, its :obj:`branch` children receive
    consecutive indices, after which the subtree of each child is expanded in
    turn (first child first).

    Args:
        depth (int): The depth of the tree. A depth of :obj:`0` (or less)
            yields a single root node without edges.
        branch (int, optional): The branch size of the tree.
            (default: :obj:`2`)
        undirected (bool, optional): If set to :obj:`True`, the tree graph will
            be undirected. (default: :obj:`False`)
        device (torch.device, optional): The desired device of the returned
            tensors. (default: :obj:`None`)

    Returns:
        (torch.Tensor, torch.Tensor): The :obj:`edge_index` as a
        :obj:`torch.long` tensor of shape :obj:`[2, num_edges]` (pointing from
        parent to child unless :obj:`undirected=True`), and a tensor of shape
        :obj:`[num_nodes]` holding the depth of every node.
    """
    edges: List[Tuple[int, int]] = []
    depths: List[int] = [0]

    # Use an explicit stack instead of recursion so that deep trees (e.g.
    # `branch=1` with a large `depth`) do not hit Python's recursion limit.
    # Children are pushed in reverse so that they are popped (and expanded)
    # in ascending order, which reproduces the pre-order numbering.
    stack: List[Tuple[int, int]] = [(0, 0)]
    while stack:
        node, current_depth = stack.pop()
        if current_depth >= depth:
            continue
        first_child = len(depths)
        children = range(first_child, first_child + branch)
        for child in children:
            edges.append((node, child))
            depths.append(current_depth + 1)
        stack.extend((child, current_depth + 1) for child in reversed(children))

    # Explicit `dtype` + `view(-1, 2)` guarantee a `[2, 0]` long tensor when
    # there are no edges (depth <= 0 or branch <= 0); a bare
    # `torch.tensor([])` would be a 1-D float tensor instead.
    edge_index = torch.tensor(edges, dtype=torch.long, device=device)
    edge_index = edge_index.view(-1, 2).t().contiguous()
    if undirected:
        edge_index = to_undirected(edge_index, num_nodes=len(depths))

    return edge_index, torch.tensor(depths, dtype=torch.long, device=device)


class TreeGraph(GraphGenerator):
    r"""Generates tree graphs.

    Each call produces a :class:`~torch_geometric.data.Data` object with an
    :obj:`edge_index`, a per-node :obj:`depth` tensor, and :obj:`num_nodes`
    set to the number of nodes in the tree.

    Args:
        depth (int): The depth of the tree.
        branch (int, optional): The branch size of the tree.
            (default: :obj:`2`)
        undirected (bool, optional): If set to :obj:`True`, the tree graph will
            be undirected. (default: :obj:`False`)
    """
    def __init__(
        self,
        depth: int,
        branch: int = 2,
        undirected: bool = False,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.branch = branch
        self.undirected = undirected

    def __call__(self) -> Data:
        edge_index, node_depth = tree(
            depth=self.depth,
            branch=self.branch,
            undirected=self.undirected,
        )
        # One depth entry exists per node, so its length is the node count.
        return Data(
            edge_index=edge_index,
            depth=node_depth,
            num_nodes=node_depth.numel(),
        )

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (f'{cls_name}(depth={self.depth}, branch={self.branch}, '
                f'undirected={self.undirected})')