import torch
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.init import init_weights
from torch_geometric.graphgym.models.layer import (
BatchNorm1dNode,
GeneralLayer,
GeneralMultiLayer,
new_layer_config,
)
from torch_geometric.graphgym.register import register_stage
[docs]def GNNLayer(dim_in: int, dim_out: int, has_act: bool = True) -> GeneralLayer:
r"""Creates a GNN layer, given the specified input and output dimensions
and the underlying configuration in :obj:`cfg`.
Args:
dim_in (int): The input dimension
dim_out (int): The output dimension.
has_act (bool, optional): Whether to apply an activation function
after the layer. (default: :obj:`True`)
"""
return GeneralLayer(
cfg.gnn.layer_type,
layer_config=new_layer_config(
dim_in,
dim_out,
1,
has_act=has_act,
has_bias=False,
cfg=cfg,
),
)
[docs]def GNNPreMP(dim_in: int, dim_out: int, num_layers: int) -> GeneralMultiLayer:
r"""Creates a NN layer used before message passing, given the specified
input and output dimensions and the underlying configuration in :obj:`cfg`.
Args:
dim_in (int): The input dimension
dim_out (int): The output dimension.
num_layers (int): The number of layers.
"""
return GeneralMultiLayer(
'linear',
layer_config=new_layer_config(
dim_in,
dim_out,
num_layers,
has_act=False,
has_bias=False,
cfg=cfg,
),
)
[docs]@register_stage('stack')
@register_stage('skipsum')
@register_stage('skipconcat')
class GNNStackStage(torch.nn.Module):
r"""Stacks a number of GNN layers.
Args:
dim_in (int): The input dimension
dim_out (int): The output dimension.
num_layers (int): The number of layers.
"""
def __init__(self, dim_in, dim_out, num_layers):
super().__init__()
self.num_layers = num_layers
for i in range(num_layers):
if cfg.gnn.stage_type == 'skipconcat':
d_in = dim_in if i == 0 else dim_in + i * dim_out
else:
d_in = dim_in if i == 0 else dim_out
layer = GNNLayer(d_in, dim_out)
self.add_module(f'layer{i}', layer)
def forward(self, batch):
for i, layer in enumerate(self.children()):
x = batch.x
batch = layer(batch)
if cfg.gnn.stage_type == 'skipsum':
batch.x = x + batch.x
elif (cfg.gnn.stage_type == 'skipconcat'
and i < self.num_layers - 1):
batch.x = torch.cat([x, batch.x], dim=1)
if cfg.gnn.l2norm:
batch.x = F.normalize(batch.x, p=2, dim=-1)
return batch
[docs]class FeatureEncoder(torch.nn.Module):
r"""Encodes node and edge features, given the specified input dimension and
the underlying configuration in :obj:`cfg`.
Args:
dim_in (int): The input feature dimension.
"""
def __init__(self, dim_in: int):
super().__init__()
self.dim_in = dim_in
if cfg.dataset.node_encoder:
# Encode integer node features via `torch.nn.Embedding`:
NodeEncoder = register.node_encoder_dict[
cfg.dataset.node_encoder_name]
self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
if cfg.dataset.node_encoder_bn:
self.node_encoder_bn = BatchNorm1dNode(
new_layer_config(
cfg.gnn.dim_inner,
-1,
-1,
has_act=False,
has_bias=False,
cfg=cfg,
))
# Update `dim_in` to reflect the new dimension fo the node features
self.dim_in = cfg.gnn.dim_inner
if cfg.dataset.edge_encoder:
# Encode integer edge features via `torch.nn.Embedding`:
EdgeEncoder = register.edge_encoder_dict[
cfg.dataset.edge_encoder_name]
self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner)
if cfg.dataset.edge_encoder_bn:
self.edge_encoder_bn = BatchNorm1dNode(
new_layer_config(
cfg.gnn.dim_inner,
-1,
-1,
has_act=False,
has_bias=False,
cfg=cfg,
))
def forward(self, batch):
for module in self.children():
batch = module(batch)
return batch
[docs]class GNN(torch.nn.Module):
r"""A general Graph Neural Network (GNN) model.
The GNN model consists of three main components:
1. An encoder to transform input features into a fixed-size embedding
space.
2. A processing or message passing stage for information exchange between
nodes.
3. A head to produce the final output features/predictions.
The configuration of each component is determined by the underlying
configuration in :obj:`cfg`.
Args:
dim_in (int): The input feature dimension.
dim_out (int): The output feature dimension.
**kwargs (optional): Additional keyword arguments.
"""
def __init__(self, dim_in: int, dim_out: int, **kwargs):
super().__init__()
GNNStage = register.stage_dict[cfg.gnn.stage_type]
GNNHead = register.head_dict[cfg.gnn.head]
self.encoder = FeatureEncoder(dim_in)
dim_in = self.encoder.dim_in
if cfg.gnn.layers_pre_mp > 0:
self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,
cfg.gnn.layers_pre_mp)
dim_in = cfg.gnn.dim_inner
if cfg.gnn.layers_mp > 0:
self.mp = GNNStage(dim_in=dim_in, dim_out=cfg.gnn.dim_inner,
num_layers=cfg.gnn.layers_mp)
self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)
self.apply(init_weights)
def forward(self, batch):
for module in self.children():
batch = module(batch)
return batch