Source code for torch_geometric.graphgym.models.gnn

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.init import init_weights
from torch_geometric.graphgym.models.layer import (
    BatchNorm1dEdge,
    BatchNorm1dNode,
    GeneralLayer,
    GeneralMultiLayer,
    new_layer_config,
)
from torch_geometric.graphgym.register import register_stage


def GNNLayer(dim_in, dim_out, has_act=True):
    """Build one message-passing layer of type ``cfg.gnn.layer_type``.

    The layer is a :class:`GeneralLayer` whose config describes a single
    layer created with ``has_bias=False``.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension.
        has_act (bool): Whether an activation function follows the layer.

    Returns:
        GeneralLayer: The constructed layer.
    """
    single_layer_config = new_layer_config(
        dim_in,
        dim_out,
        1,  # exactly one layer
        has_act=has_act,
        has_bias=False,
        cfg=cfg,
    )
    return GeneralLayer(cfg.gnn.layer_type, layer_config=single_layer_config)
def GNNPreMP(dim_in, dim_out, num_layers):
    """Build the stack of linear layers applied before message passing.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension.
        num_layers (int): Number of linear layers in the stack.

    Returns:
        GeneralMultiLayer: A ``'linear'`` multi-layer whose config is created
        with ``has_act=False`` and ``has_bias=False``.
    """
    pre_mp_config = new_layer_config(
        dim_in,
        dim_out,
        num_layers,
        has_act=False,
        has_bias=False,
        cfg=cfg,
    )
    return GeneralMultiLayer('linear', layer_config=pre_mp_config)
@register_stage('stack')
@register_stage('skipsum')
@register_stage('skipconcat')
class GNNStackStage(nn.Module):
    """Stage that applies ``num_layers`` GNN layers one after another.

    The skip-connection behaviour is selected by ``cfg.gnn.stage_type``:

    * ``'stack'``: layers are applied sequentially with no skip connection.
    * ``'skipsum'``: each layer's output is added to that layer's input.
    * ``'skipconcat'``: each layer's input is concatenated (``dim=1``) with
      its output, for every layer except the last one.

    If ``cfg.gnn.l2norm`` is set, the final ``batch.x`` is L2-normalized
    along its last dimension.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension of every GNN layer.
        num_layers (int): Number of GNN layers.
    """
    def __init__(self, dim_in, dim_out, num_layers):
        super().__init__()
        self.num_layers = num_layers
        for idx in range(num_layers):
            if idx == 0:
                layer_in = dim_in
            elif cfg.gnn.stage_type == 'skipconcat':
                # Input = original features + outputs of all previous layers.
                layer_in = dim_in + idx * dim_out
            else:
                layer_in = dim_out
            self.add_module(f'layer{idx}', GNNLayer(layer_in, dim_out))

    def forward(self, batch):
        """Run the layers in registration order, applying skip connections."""
        last_idx = self.num_layers - 1
        for idx, layer in enumerate(self.children()):
            residual = batch.x
            batch = layer(batch)
            stage_type = cfg.gnn.stage_type
            if stage_type == 'skipsum':
                batch.x = residual + batch.x
            elif stage_type == 'skipconcat' and idx < last_idx:
                batch.x = torch.cat([residual, batch.x], dim=1)
        if cfg.gnn.l2norm:
            batch.x = F.normalize(batch.x, p=2, dim=-1)
        return batch
class FeatureEncoder(nn.Module):
    """Encode node and edge features using the encoders selected in ``cfg``.

    Sub-modules are created according to ``cfg.dataset``:

    * ``node_encoder``: ``register.node_encoder_dict[node_encoder_name]``
      built with ``cfg.gnn.dim_inner``, optionally followed by a node-level
      batch norm when ``node_encoder_bn`` is set.
    * ``edge_encoder``: ``register.edge_encoder_dict[edge_encoder_name]``
      built with ``cfg.gnn.dim_inner``, optionally followed by an edge-level
      batch norm when ``edge_encoder_bn`` is set.

    ``forward`` applies the created sub-modules in registration order.

    Args:
        dim_in (int): Input feature dimension.

    Attributes:
        dim_in (int): Node feature dimension after encoding. It is
            ``cfg.gnn.dim_inner`` when a node encoder is used, otherwise the
            ``dim_in`` constructor argument.
    """
    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in
        if cfg.dataset.node_encoder:
            # Build the registered node encoder with output size dim_inner.
            NodeEncoder = register.node_encoder_dict[
                cfg.dataset.node_encoder_name]
            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))
            # Update dim_in to reflect the new dimension for the node features
            self.dim_in = cfg.gnn.dim_inner
        if cfg.dataset.edge_encoder:
            # Build the registered edge encoder with output size dim_inner.
            EdgeEncoder = register.edge_encoder_dict[
                cfg.dataset.edge_encoder_name]
            self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.edge_encoder_bn:
                # Use the edge-level batch norm so the encoded edge features
                # are what get normalized; the node-level variant used
                # previously targets node features instead.
                self.edge_encoder_bn = BatchNorm1dEdge(
                    new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))

    def forward(self, batch):
        """Apply each configured encoder / batch norm to ``batch`` in turn."""
        for module in self.children():
            batch = module(batch)
        return batch
class GNN(nn.Module):
    """General GNN model: encoder + optional pre-MP layers + stage + head.

    The stage class is ``register.stage_dict[cfg.gnn.stage_type]`` and the
    head class is ``register.head_dict[cfg.gnn.head]``. The pre-message-
    passing layers are created only if ``cfg.gnn.layers_pre_mp > 0``, and the
    message-passing stage only if ``cfg.gnn.layers_mp > 0``. All weights are
    initialized with ``init_weights``.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension.
        **kwargs (optional): Accepted for interface compatibility; not used.
    """
    def __init__(self, dim_in, dim_out, **kwargs):
        super().__init__()
        GNNStage = register.stage_dict[cfg.gnn.stage_type]
        GNNHead = register.head_dict[cfg.gnn.head]

        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,
                                   cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner
        if cfg.gnn.layers_mp > 0:
            self.mp = GNNStage(dim_in=dim_in, dim_out=cfg.gnn.dim_inner,
                               num_layers=cfg.gnn.layers_mp)
            dim_in = cfg.gnn.dim_inner
        # Size the head from the tracked feature dimension instead of always
        # assuming dim_inner. With no pre-MP and no MP layers, the head sees
        # the encoder output directly, which may have a different width.
        self.post_mp = GNNHead(dim_in=dim_in, dim_out=dim_out)

        self.apply(init_weights)

    def forward(self, batch):
        """Run encoder, pre-MP, stage and head (those that exist) in order."""
        for module in self.children():
            batch = module(batch)
        return batch