import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.act import act_dict
from torch_geometric.graphgym.contrib.layer.generalconv import (
GeneralConvLayer, GeneralEdgeConvLayer)
import torch_geometric.graphgym.register as register


# General classes
class GeneralLayer(nn.Module):
"""
General wrapper for layers
Args:
        name (string): Name of the layer registered in :obj:`layer_dict`
dim_in (int): Input dimension
dim_out (int): Output dimension
        has_act (bool): Whether to apply an activation after the layer
        has_bn (bool): Whether to apply BatchNorm in the layer
        has_l2norm (bool): Whether to apply L2 normalization after the layer
**kwargs (optional): Additional args
"""
def __init__(self, name, dim_in, dim_out, has_act=True, has_bn=True,
has_l2norm=False, **kwargs):
super(GeneralLayer, self).__init__()
self.has_l2norm = has_l2norm
has_bn = has_bn and cfg.gnn.batchnorm
self.layer = layer_dict[name](dim_in, dim_out, bias=not has_bn,
**kwargs)
layer_wrapper = []
if has_bn:
layer_wrapper.append(
nn.BatchNorm1d(dim_out, eps=cfg.bn.eps, momentum=cfg.bn.mom))
if cfg.gnn.dropout > 0:
layer_wrapper.append(
nn.Dropout(p=cfg.gnn.dropout, inplace=cfg.mem.inplace))
if has_act:
layer_wrapper.append(act_dict[cfg.gnn.act])
self.post_layer = nn.Sequential(*layer_wrapper)

    def forward(self, batch):
batch = self.layer(batch)
if isinstance(batch, torch.Tensor):
batch = self.post_layer(batch)
if self.has_l2norm:
batch = F.normalize(batch, p=2, dim=1)
else:
batch.x = self.post_layer(batch.x)
if self.has_l2norm:
batch.x = F.normalize(batch.x, p=2, dim=1)
return batch
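

# A minimal usage sketch (not part of the upstream module): assuming the
# GraphGym defaults in `cfg` are loaded, a GeneralLayer wraps any name
# registered in `layer_dict` and accepts either a plain tensor or a PyG
# batch. The layer name and dimensions below are arbitrary assumptions.
def _example_general_layer():
    layer = GeneralLayer('linear', dim_in=16, dim_out=32)
    x = torch.randn(8, 16)  # 8 nodes with 16 features each
    out = layer(x)  # BatchNorm/dropout/activation applied by `post_layer`
    assert out.shape == (8, 32)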


class GeneralMultiLayer(nn.Module):
"""
General wrapper for a stack of multiple layers
Args:
        name (string): Name of the layer registered in :obj:`layer_dict`
num_layers (int): Number of layers in the stack
dim_in (int): Input dimension
dim_out (int): Output dimension
dim_inner (int): The dimension for the inner layers
        final_act (bool): Whether to apply an activation after the last
            layer of the stack
**kwargs (optional): Additional args
"""
def __init__(self, name, num_layers, dim_in, dim_out, dim_inner=None,
final_act=True, **kwargs):
super(GeneralMultiLayer, self).__init__()
dim_inner = dim_in if dim_inner is None else dim_inner
for i in range(num_layers):
d_in = dim_in if i == 0 else dim_inner
d_out = dim_out if i == num_layers - 1 else dim_inner
has_act = final_act if i == num_layers - 1 else True
layer = GeneralLayer(name, d_in, d_out, has_act, **kwargs)
self.add_module('Layer_{}'.format(i), layer)

    def forward(self, batch):
for layer in self.children():
batch = layer(batch)
return batch
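

# Illustrative sketch (dimensions are assumptions, not from the source): a
# 3-layer stack maps dim_in -> dim_inner -> dim_inner -> dim_out, and
# `final_act=False` drops the activation on the last layer only.
def _example_general_multi_layer():
    stack = GeneralMultiLayer('linear', num_layers=3, dim_in=16, dim_out=8,
                              dim_inner=64, final_act=False)
    out = stack(torch.randn(4, 16))
    assert out.shape == (4, 8)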


# ---------- Core basic layers. Input: batch; Output: batch ----------------- #
class Linear(nn.Module):
"""
Basic Linear layer.
Args:
dim_in (int): Input dimension
dim_out (int): Output dimension
        bias (bool): Whether the layer has a bias term
**kwargs (optional): Additional args
"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(Linear, self).__init__()
self.model = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = self.model(batch)
else:
batch.x = self.model(batch.x)
return batch


class BatchNorm1dNode(nn.Module):
    """
    BatchNorm for node features.
Args:
dim_in (int): Input dimension
"""
def __init__(self, dim_in):
super(BatchNorm1dNode, self).__init__()
self.bn = nn.BatchNorm1d(dim_in, eps=cfg.bn.eps, momentum=cfg.bn.mom)

    def forward(self, batch):
batch.x = self.bn(batch.x)
return batch


class BatchNorm1dEdge(nn.Module):
    """
    BatchNorm for edge features.
Args:
dim_in (int): Input dimension
"""
def __init__(self, dim_in):
super(BatchNorm1dEdge, self).__init__()
self.bn = nn.BatchNorm1d(dim_in, eps=cfg.bn.eps, momentum=cfg.bn.mom)

    def forward(self, batch):
batch.edge_attr = self.bn(batch.edge_attr)
return batch


class MLP(nn.Module):
    """
    Basic MLP model.
    Note that a 1-layer MLP is equivalent to a Linear layer.
Args:
dim_in (int): Input dimension
dim_out (int): Output dimension
        bias (bool): Whether the layers have a bias term
dim_inner (int): The dimension for the inner layers
num_layers (int): Number of layers in the stack
**kwargs (optional): Additional args
"""
def __init__(self, dim_in, dim_out, bias=True, dim_inner=None,
num_layers=2, **kwargs):
super(MLP, self).__init__()
dim_inner = dim_in if dim_inner is None else dim_inner
layers = []
if num_layers > 1:
layers.append(
GeneralMultiLayer('linear', num_layers - 1, dim_in, dim_inner,
dim_inner, final_act=True))
layers.append(Linear(dim_inner, dim_out, bias))
else:
layers.append(Linear(dim_in, dim_out, bias))
self.model = nn.Sequential(*layers)

    def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = self.model(batch)
else:
batch.x = self.model(batch.x)
return batch
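

# Hypothetical usage (dimensions assumed): with num_layers=2 the MLP below is
# a Linear(16 -> 16) block with BatchNorm/activation, followed by a plain
# Linear(16 -> 4); with num_layers=1 it reduces to a single Linear layer.
def _example_mlp():
    mlp = MLP(dim_in=16, dim_out=4, num_layers=2)
    out = mlp(torch.randn(10, 16))
    assert out.shape == (10, 4)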


class GCNConv(nn.Module):
"""
Graph Convolutional Network (GCN) layer
"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(GCNConv, self).__init__()
self.model = pyg.nn.GCNConv(dim_in, dim_out, bias=bias)

    def forward(self, batch):
batch.x = self.model(batch.x, batch.edge_index)
return batch
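

# Sketch of the shared wrapper pattern on an assumed toy graph: each conv
# wrapper reads `batch.x` and `batch.edge_index`, writes the new node
# features back to `batch.x`, and returns the batch. The same calling
# convention applies to the SAGEConv, GATConv, and GINConv wrappers below
# (SplineConv additionally consumes `batch.edge_attr`).
def _example_gcn_conv():
    from torch_geometric.data import Data
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    batch = Data(x=torch.randn(3, 16), edge_index=edge_index)
    batch = GCNConv(16, 32)(batch)
    assert batch.x.shape == (3, 32)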


class SAGEConv(nn.Module):
"""
GraphSAGE Conv layer
"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(SAGEConv, self).__init__()
self.model = pyg.nn.SAGEConv(dim_in, dim_out, bias=bias)

    def forward(self, batch):
batch.x = self.model(batch.x, batch.edge_index)
return batch


class GATConv(nn.Module):
"""
Graph Attention Network (GAT) layer
"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(GATConv, self).__init__()
self.model = pyg.nn.GATConv(dim_in, dim_out, bias=bias)

    def forward(self, batch):
batch.x = self.model(batch.x, batch.edge_index)
return batch


class GINConv(nn.Module):
"""
Graph Isomorphism Network (GIN) layer
"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(GINConv, self).__init__()
gin_nn = nn.Sequential(nn.Linear(dim_in, dim_out), nn.ReLU(),
nn.Linear(dim_out, dim_out))
self.model = pyg.nn.GINConv(gin_nn)

    def forward(self, batch):
batch.x = self.model(batch.x, batch.edge_index)
return batch


class SplineConv(nn.Module):
"""
SplineCNN layer
"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(SplineConv, self).__init__()
self.model = pyg.nn.SplineConv(dim_in, dim_out, dim=1, kernel_size=2,
bias=bias)

    def forward(self, batch):
batch.x = self.model(batch.x, batch.edge_index, batch.edge_attr)
return batch


class GeneralConv(nn.Module):
"""A general GNN layer"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(GeneralConv, self).__init__()
self.model = GeneralConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
batch.x = self.model(batch.x, batch.edge_index)
return batch


class GeneralEdgeConv(nn.Module):
"""A general GNN layer that supports edge features as well"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(GeneralEdgeConv, self).__init__()
self.model = GeneralEdgeConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
batch.x = self.model(batch.x, batch.edge_index,
edge_feature=batch.edge_attr)
return batch


class GeneralSampleEdgeConv(nn.Module):
"""A general GNN layer that supports edge features and edge sampling"""
def __init__(self, dim_in, dim_out, bias=False, **kwargs):
super(GeneralSampleEdgeConv, self).__init__()
self.model = GeneralEdgeConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        # Sample each edge independently, on the same device as the graph so
        # that the boolean mask can index `edge_index` directly:
        edge_mask = torch.rand(
            batch.edge_index.shape[1],
            device=batch.edge_index.device) < cfg.gnn.keep_edge
edge_index = batch.edge_index[:, edge_mask]
edge_feature = batch.edge_attr[edge_mask, :]
batch.x = self.model(batch.x, edge_index, edge_feature=edge_feature)
return batch
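

# Hedged illustration of the sampling step above: each edge is kept
# independently with probability cfg.gnn.keep_edge, so in expectation
# keep_edge * num_edges edges (and their features) take part in message
# passing on a given forward pass. The standalone helper and probability
# below are assumptions for demonstration.
def _example_edge_sampling(edge_index, keep_edge=0.5):
    mask = torch.rand(edge_index.shape[1], device=edge_index.device) < keep_edge
    return edge_index[:, mask]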


layer_dict = {
'linear': Linear,
'mlp': MLP,
'gcnconv': GCNConv,
'sageconv': SAGEConv,
'gatconv': GATConv,
'splineconv': SplineConv,
'ginconv': GINConv,
'generalconv': GeneralConv,
'generaledgeconv': GeneralEdgeConv,
'generalsampleedgeconv': GeneralSampleEdgeConv,
}
# Merge in layers registered via `register.register_layer`; on a name clash,
# the built-in layers above take precedence
layer_dict = {**register.layer_dict, **layer_dict}
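

# Usage sketch (the layer name and class are hypothetical): custom layers
# become available to GeneralLayer by registering them through
# torch_geometric.graphgym.register. Since the merge above runs at import
# time, registration must happen before this module is imported.
def _example_register_custom_layer():
    class MyConv(nn.Module):
        def __init__(self, dim_in, dim_out, bias=False, **kwargs):
            super(MyConv, self).__init__()
            self.model = nn.Linear(dim_in, dim_out, bias=bias)

        def forward(self, batch):
            batch.x = self.model(batch.x)
            return batch

    register.register_layer('myconv', MyConv)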