import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.init import init_weights
from torch_geometric.graphgym.models.layer import (
    BatchNorm1dNode,
    GeneralLayer,
    GeneralMultiLayer,
    new_layer_config,
)
from torch_geometric.graphgym.register import register_stage


def GNNLayer(dim_in, dim_out, has_act=True):
    """
    Wrapper for a GNN layer.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        has_act (bool): Whether to apply an activation function after the
            layer
    """
    return GeneralLayer(
        cfg.gnn.layer_type,
        layer_config=new_layer_config(dim_in, dim_out, 1, has_act=has_act,
                                      has_bias=False, cfg=cfg))
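
# Usage sketch (hypothetical dimensions; assumes ``cfg`` has already been
# populated, e.g. via GraphGym's ``set_cfg``/``load_cfg``):
#
#     layer = GNNLayer(dim_in=64, dim_out=64)  # type picked by cfg.gnn.layer_type
#     batch = layer(batch)                     # updates batch.x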


def GNNPreMP(dim_in, dim_out, num_layers):
    """
    Wrapper for NN layers before GNN message passing.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        num_layers (int): Number of layers
    """
    return GeneralMultiLayer(
        'linear',
        layer_config=new_layer_config(dim_in, dim_out, num_layers,
                                      has_act=False, has_bias=False, cfg=cfg))
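
# Usage sketch (hypothetical dimensions): a stack of 'linear' layers applied
# to node features before any message passing:
#
#     pre_mp = GNNPreMP(dim_in=64, dim_out=128, num_layers=2)
#     batch = pre_mp(batch)  # batch.x: [num_nodes, 64] -> [num_nodes, 128]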


@register_stage('stack')
@register_stage('skipsum')
@register_stage('skipconcat')
class GNNStackStage(nn.Module):
    """
    Simple stage that stacks GNN layers.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        num_layers (int): Number of GNN layers
    """
    def __init__(self, dim_in, dim_out, num_layers):
        super().__init__()
        self.num_layers = num_layers
        for i in range(num_layers):
            if cfg.gnn.stage_type == 'skipconcat':
                # Each 'skipconcat' layer consumes the raw input features
                # concatenated with the outputs of all previous layers.
                d_in = dim_in if i == 0 else dim_in + i * dim_out
            else:
                d_in = dim_in if i == 0 else dim_out
            layer = GNNLayer(d_in, dim_out)
            self.add_module(f'layer{i}', layer)

    def forward(self, batch):
        """Applies the stacked layers with the configured skip connections."""
        for i, layer in enumerate(self.children()):
            x = batch.x
            batch = layer(batch)
            if cfg.gnn.stage_type == 'skipsum':
                batch.x = x + batch.x
            elif cfg.gnn.stage_type == 'skipconcat' and \
                    i < self.num_layers - 1:
                batch.x = torch.cat([x, batch.x], dim=1)
        if cfg.gnn.l2norm:
            batch.x = F.normalize(batch.x, p=2, dim=-1)
        return batch
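
# Worked 'skipconcat' dimension check (hypothetical dim_in=16, dim_out=32,
# num_layers=3): layer0 maps 16 -> 32; its output is concatenated with the
# 16-dim input, so layer1 maps 16 + 32 = 48 -> 32; after another concat,
# layer2 maps 16 + 2 * 32 = 80 -> 32. The concat is skipped after the last
# layer, so the stage output width stays at dim_out.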


class FeatureEncoder(nn.Module):
    """
    Encodes node and edge features.

    Args:
        dim_in (int): Input feature dimension
    """
    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in
        if cfg.dataset.node_encoder:
            # Encode integer node features via nn.Embedding
            NodeEncoder = register.node_encoder_dict[
                cfg.dataset.node_encoder_name]
            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))
            # Update dim_in to reflect the new dimension of the node features
            self.dim_in = cfg.gnn.dim_inner
        if cfg.dataset.edge_encoder:
            # Encode integer edge features via nn.Embedding
            EdgeEncoder = register.edge_encoder_dict[
                cfg.dataset.edge_encoder_name]
            self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.edge_encoder_bn:
                self.edge_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))

    def forward(self, batch):
        """Applies the registered node and edge encoders to the batch."""
        for module in self.children():
            batch = module(batch)
        return batch
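
# Usage sketch: with ``cfg.dataset.node_encoder`` enabled, integer-coded
# ``batch.x`` is replaced by learned ``cfg.gnn.dim_inner``-dimensional
# embeddings, so downstream code should read ``encoder.dim_in`` instead of
# the raw feature dimension:
#
#     encoder = FeatureEncoder(dim_in=9)  # hypothetical raw feature count
#     batch = encoder(batch)              # batch.x: [num_nodes, cfg.gnn.dim_inner]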


class GNN(nn.Module):
    """
    General GNN model: encoder + stage + head.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        **kwargs (optional): Optional additional args
    """
    def __init__(self, dim_in, dim_out, **kwargs):
        super().__init__()
        GNNStage = register.stage_dict[cfg.gnn.stage_type]
        GNNHead = register.head_dict[cfg.gnn.head]

        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in
        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,
                                   cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner
        if cfg.gnn.layers_mp > 0:
            self.mp = GNNStage(dim_in=dim_in, dim_out=cfg.gnn.dim_inner,
                               num_layers=cfg.gnn.layers_mp)
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

        self.apply(init_weights)

    def forward(self, batch):
        """Runs the encoder, pre-MP layers, message passing stage, and head
        in order."""
        for module in self.children():
            batch = module(batch)
        return batch
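
# End-to-end usage sketch (hypothetical names; assumes a fully populated
# GraphGym ``cfg`` with registered encoders, a stage, and a head):
#
#     model = GNN(dim_in=dataset.num_features, dim_out=dataset.num_classes)
#     pred, true = model(batch)  # GraphGym heads typically return (pred, label)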