Source code for torch_geometric.graphgym.models.gnn

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.head import head_dict
from torch_geometric.graphgym.models.layer import (GeneralLayer,
                                                   GeneralMultiLayer,
                                                   BatchNorm1dNode,
                                                   BatchNorm1dEdge)
from torch_geometric.graphgym.init import init_weights
from torch_geometric.graphgym.models.encoder import node_encoder_dict, \
    edge_encoder_dict

import torch_geometric.graphgym.register as register


def GNNLayer(dim_in, dim_out, has_act=True):
    """
    Wrapper for a GNN layer

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        has_act (bool): Whether to apply an activation function after the
            layer

    """
    return GeneralLayer(cfg.gnn.layer_type, dim_in, dim_out, has_act)
def GNNPreMP(dim_in, dim_out):
    """
    Wrapper for NN layers before GNN message passing

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension

    """
    return GeneralMultiLayer('linear', cfg.gnn.layers_pre_mp,
                             dim_in, dim_out, dim_inner=dim_out,
                             final_act=True)
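
# Usage sketch (not part of the original module): both wrappers above read
# the global ``cfg``, so a config must already be populated (GraphGym
# normally does this via ``set_cfg`` / ``load_cfg``). The dimensions below
# are arbitrary placeholders.
#
#     layer = GNNLayer(64, 64)    # one layer of type cfg.gnn.layer_type
#     pre_mp = GNNPreMP(32, 64)   # cfg.gnn.layers_pre_mp stacked linear layers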
class GNNStackStage(nn.Module):
    """
    Simple stage that stacks GNN layers

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        num_layers (int): Number of GNN layers

    """
    def __init__(self, dim_in, dim_out, num_layers):
        super(GNNStackStage, self).__init__()
        self.num_layers = num_layers
        for i in range(num_layers):
            if cfg.gnn.stage_type == 'skipconcat':
                d_in = dim_in if i == 0 else dim_in + i * dim_out
            else:
                d_in = dim_in if i == 0 else dim_out
            layer = GNNLayer(d_in, dim_out)
            self.add_module('layer{}'.format(i), layer)

    def forward(self, batch):
        """"""
        for i, layer in enumerate(self.children()):
            x = batch.x
            batch = layer(batch)
            if cfg.gnn.stage_type == 'skipsum':
                batch.x = x + batch.x
            elif cfg.gnn.stage_type == 'skipconcat' and \
                    i < self.num_layers - 1:
                batch.x = torch.cat([x, batch.x], dim=1)
        if cfg.gnn.l2norm:
            batch.x = F.normalize(batch.x, p=2, dim=-1)
        return batch
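
# Dimension sketch (not part of the original module): for dim_in=16,
# dim_out=32 and num_layers=3, the per-layer input dimension d_in is
#
#     'stack' / 'skipsum':  16, 32, 32
#     'skipconcat':         16, 16 + 32, 16 + 2 * 32
#
# The 'skipsum' residual (x + batch.x) additionally requires
# dim_in == dim_out, which holds in the GNN model below whenever features
# are first mapped to cfg.gnn.dim_inner by the encoder or the
# pre-message-passing layers.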
stage_dict = {
    'stack': GNNStackStage,
    'skipsum': GNNStackStage,
    'skipconcat': GNNStackStage,
}

stage_dict = {**register.stage_dict, **stage_dict}
class FeatureEncoder(nn.Module):
    """
    Encoding node and edge features

    Args:
        dim_in (int): Input feature dimension

    """
    def __init__(self, dim_in):
        super(FeatureEncoder, self).__init__()
        self.dim_in = dim_in
        if cfg.dataset.node_encoder:
            # Encode integer node features via nn.Embeddings
            NodeEncoder = node_encoder_dict[cfg.dataset.node_encoder_name]
            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(cfg.gnn.dim_inner)
            # Update dim_in to reflect the new dimension of the node features
            self.dim_in = cfg.gnn.dim_inner
        if cfg.dataset.edge_encoder:
            # Encode integer edge features via nn.Embeddings
            EdgeEncoder = edge_encoder_dict[cfg.dataset.edge_encoder_name]
            self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.edge_encoder_bn:
                self.edge_encoder_bn = BatchNorm1dEdge(cfg.gnn.dim_inner)

    def forward(self, batch):
        """"""
        for module in self.children():
            batch = module(batch)
        return batch
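
# Config sketch (not part of the original module): which encoders are built
# is decided entirely by the global config; 'Atom' below stands for any key
# registered in node_encoder_dict and is only illustrative.
#
#     cfg.dataset.node_encoder = True
#     cfg.dataset.node_encoder_name = 'Atom'
#     cfg.dataset.node_encoder_bn = True
#     cfg.gnn.dim_inner = 128
#     encoder = FeatureEncoder(dim_in=9)   # encoder.dim_in is now 128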
class GNN(nn.Module):
    """
    General GNN model: encoder + stage + head

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        **kwargs (optional): Optional additional args

    """
    def __init__(self, dim_in, dim_out, **kwargs):
        super(GNN, self).__init__()
        GNNStage = stage_dict[cfg.gnn.stage_type]
        GNNHead = head_dict[cfg.gnn.head]

        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in
        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner)
            dim_in = cfg.gnn.dim_inner
        if cfg.gnn.layers_mp > 0:
            self.mp = GNNStage(dim_in=dim_in,
                               dim_out=cfg.gnn.dim_inner,
                               num_layers=cfg.gnn.layers_mp)
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

        self.apply(init_weights)

    def forward(self, batch):
        """"""
        for module in self.children():
            batch = module(batch)
        return batch

    if cfg.benchmark:
        # register to measure the forward pass
        from torch_geometric.graphgym.benchmark import global_line_profiler
        global_line_profiler.add_function(forward)
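
# End-to-end sketch (not part of the original module): building and running
# the model from a loaded config. ``loader`` and the dimensions taken from
# ``cfg.share`` are assumptions about the surrounding GraphGym pipeline.
#
#     model = GNN(dim_in=cfg.share.dim_in, dim_out=cfg.share.dim_out)
#     for batch in loader:
#         out = model(batch)   # GraphGym heads typically return (pred, label)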