Source code for torch_geometric.graphgym.model_builder

import torch

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNN

import torch_geometric.graphgym.register as register

network_dict = {
    'gnn': GNN,
register.network_dict = {**register.network_dict, **network_dict}

[docs]def create_model(to_device=True, dim_in=None, dim_out=None): r""" Create model for graph machine learning Args: to_device (string): The devide that the model will be transferred to dim_in (int, optional): Input dimension to the model dim_out (int, optional): Output dimension to the model """ dim_in = cfg.share.dim_in if dim_in is None else dim_in dim_out = cfg.share.dim_out if dim_out is None else dim_out # binary classification, output dim = 1 if 'classification' in cfg.dataset.task_type and dim_out == 2: dim_out = 1 model = register.network_dict[cfg.model.type](dim_in=dim_in, dim_out=dim_out) if to_device: return model