Source code for torch_geometric.graphgym.models.layer

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.act import act_dict
from torch_geometric.graphgym.contrib.layer.generalconv import (
    GeneralConvLayer, GeneralEdgeConvLayer)

import torch_geometric.graphgym.register as register

# General classes
class GeneralLayer(nn.Module):
    """
    General wrapper for layers

    Wraps a layer looked up in :obj:`layer_dict` and applies, in this order,
    optional BatchNorm, optional Dropout and an optional activation to its
    output, followed by an optional L2 normalization.

    Args:
        name (string): Name of the layer in registered :obj:`layer_dict`
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        has_act (bool): Whether has activation after the layer
        has_bn (bool): Whether has BatchNorm in the layer
        has_l2norm (bool): Whether has L2 normalization after the layer
        **kwargs (optional): Additional args
    """
    def __init__(self, name, dim_in, dim_out, has_act=True, has_bn=True,
                 has_l2norm=False, **kwargs):
        super().__init__()
        self.has_l2norm = has_l2norm
        # BatchNorm is only used when both the caller and the global config
        # ask for it.
        has_bn = has_bn and cfg.gnn.batchnorm
        # The wrapped layer gets a bias term only when no BatchNorm follows.
        self.layer = layer_dict[name](dim_in, dim_out, bias=not has_bn,
                                      **kwargs)

        layer_wrapper = []
        if has_bn:
            # NOTE(review): the keyword arguments of this call were missing
            # from the extracted text; restored as cfg.bn.eps / cfg.bn.mom --
            # confirm against the upstream source.
            layer_wrapper.append(
                nn.BatchNorm1d(dim_out, eps=cfg.bn.eps, momentum=cfg.bn.mom))
        if cfg.gnn.dropout > 0:
            layer_wrapper.append(
                nn.Dropout(p=cfg.gnn.dropout, inplace=cfg.mem.inplace))
        if has_act:
            layer_wrapper.append(act_dict[cfg.gnn.act])
        self.post_layer = nn.Sequential(*layer_wrapper)

    def forward(self, batch):
        """Apply the wrapped layer and the post-processing stack.

        ``batch`` may be a plain tensor or an object carrying node features
        in ``batch.x``; in the latter case ``batch.x`` is updated in place
        and the same object is returned.
        """
        batch = self.layer(batch)
        if isinstance(batch, torch.Tensor):
            batch = self.post_layer(batch)
            if self.has_l2norm:
                batch = F.normalize(batch, p=2, dim=1)
        else:
            batch.x = self.post_layer(batch.x)
            if self.has_l2norm:
                batch.x = F.normalize(batch.x, p=2, dim=1)
        return batch
class GeneralMultiLayer(nn.Module):
    """
    General wrapper for a stack of multiple layers

    Args:
        name (string): Name of the layer in registered :obj:`layer_dict`
        num_layers (int): Number of layers in the stack
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        dim_inner (int): The dimension for the inner layers
        final_act (bool): Whether has activation after the layer stack
        **kwargs (optional): Additional args
    """
    def __init__(self, name, num_layers, dim_in, dim_out, dim_inner=None,
                 final_act=True, **kwargs):
        super().__init__()
        # Inner width falls back to the input width when not given.
        dim_inner = dim_in if dim_inner is None else dim_inner
        last = num_layers - 1
        for i in range(num_layers):
            d_in = dim_in if i == 0 else dim_inner
            d_out = dim_out if i == last else dim_inner
            # Every layer but the last always has an activation; the last
            # one follows ``final_act``.
            has_act = final_act if i == last else True
            layer = GeneralLayer(name, d_in, d_out, has_act, **kwargs)
            self.add_module('Layer_{}'.format(i), layer)

    def forward(self, batch):
        """Pass ``batch`` through the registered child layers in order."""
        for layer in self.children():
            batch = layer(batch)
        return batch
# ---------- Core basic layers. Input: batch; Output: batch ----------------- #
class Linear(nn.Module):
    """
    Basic Linear layer.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Apply the linear map to a tensor, or to ``batch.x`` in place.

        A plain tensor input yields a new tensor; any other input has its
        ``x`` attribute replaced and is returned as the same object.
        """
        if isinstance(batch, torch.Tensor):
            batch = self.model(batch)
        else:
            batch.x = self.model(batch.x)
        return batch
class BatchNorm1dNode(nn.Module):
    """
    BatchNorm for node feature.

    Args:
        dim_in (int): Input dimension
    """
    def __init__(self, dim_in):
        super().__init__()
        # NOTE(review): the attribute name and the keyword arguments were
        # missing from the extracted text; restored as ``self.bn`` with
        # cfg.bn.eps / cfg.bn.mom -- confirm against the upstream source.
        self.bn = nn.BatchNorm1d(dim_in, eps=cfg.bn.eps, momentum=cfg.bn.mom)

    def forward(self, batch):
        """Normalize ``batch.x`` in place and return ``batch``."""
        batch.x = self.bn(batch.x)
        return batch
class BatchNorm1dEdge(nn.Module):
    """
    BatchNorm for edge feature.

    Args:
        dim_in (int): Input dimension
    """
    def __init__(self, dim_in):
        super().__init__()
        # NOTE(review): the attribute name and the keyword arguments were
        # missing from the extracted text; restored as ``self.bn`` with
        # cfg.bn.eps / cfg.bn.mom -- confirm against the upstream source.
        self.bn = nn.BatchNorm1d(dim_in, eps=cfg.bn.eps, momentum=cfg.bn.mom)

    def forward(self, batch):
        """Normalize ``batch.edge_attr`` in place and return ``batch``."""
        batch.edge_attr = self.bn(batch.edge_attr)
        return batch
class MLP(nn.Module):
    """
    Basic MLP model.
    Here 1-layer MLP is equivalent to a Linear layer.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term (passed to the final
            :class:`Linear` layer only)
        dim_inner (int): The dimension for the inner layers
        num_layers (int): Number of layers in the stack
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=True, dim_inner=None,
                 num_layers=2, **kwargs):
        super().__init__()
        # Inner width falls back to the input width when not given.
        dim_inner = dim_in if dim_inner is None else dim_inner
        layers = []
        if num_layers > 1:
            # num_layers - 1 hidden 'linear' layers, then a plain output
            # projection.
            layers.append(
                GeneralMultiLayer('linear', num_layers - 1, dim_in, dim_inner,
                                  dim_inner, final_act=True))
            layers.append(Linear(dim_inner, dim_out, bias))
        else:
            layers.append(Linear(dim_in, dim_out, bias))
        self.model = nn.Sequential(*layers)

    def forward(self, batch):
        """Apply the MLP to a tensor, or to ``batch.x`` in place."""
        if isinstance(batch, torch.Tensor):
            batch = self.model(batch)
        else:
            batch.x = self.model(batch.x)
        return batch
class GCNConv(nn.Module):
    """
    Graph Convolutional Network (GCN) layer

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = pyg.nn.GCNConv(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Update ``batch.x`` from ``batch.x`` and ``batch.edge_index``."""
        batch.x = self.model(batch.x, batch.edge_index)
        return batch
class SAGEConv(nn.Module):
    """
    GraphSAGE Conv layer

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = pyg.nn.SAGEConv(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Update ``batch.x`` from ``batch.x`` and ``batch.edge_index``."""
        batch.x = self.model(batch.x, batch.edge_index)
        return batch
class GATConv(nn.Module):
    """
    Graph Attention Network (GAT) layer

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = pyg.nn.GATConv(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Update ``batch.x`` from ``batch.x`` and ``batch.edge_index``."""
        batch.x = self.model(batch.x, batch.edge_index)
        return batch
class GINConv(nn.Module):
    """
    Graph Isomorphism Network (GIN) layer

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Accepted for interface compatibility; not used by this
            layer (the inner ``nn.Linear`` modules keep their defaults)
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        # Two-layer perceptron used as the GIN update function.
        gin_nn = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
        )
        self.model = pyg.nn.GINConv(gin_nn)

    def forward(self, batch):
        """Update ``batch.x`` from ``batch.x`` and ``batch.edge_index``."""
        batch.x = self.model(batch.x, batch.edge_index)
        return batch
class SplineConv(nn.Module):
    """
    SplineCNN layer

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        # ``dim`` and ``kernel_size`` are fixed to 1 and 2 respectively.
        self.model = pyg.nn.SplineConv(dim_in, dim_out, dim=1, kernel_size=2,
                                       bias=bias)

    def forward(self, batch):
        """Update ``batch.x`` using ``batch.edge_index`` and
        ``batch.edge_attr``."""
        batch.x = self.model(batch.x, batch.edge_index, batch.edge_attr)
        return batch
class GeneralConv(nn.Module):
    """A general GNN layer

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = GeneralConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Update ``batch.x`` from ``batch.x`` and ``batch.edge_index``."""
        batch.x = self.model(batch.x, batch.edge_index)
        return batch
class GeneralEdgeConv(nn.Module):
    """A general GNN layer that supports edge features as well

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = GeneralEdgeConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Update ``batch.x`` using ``batch.edge_index`` and, as edge
        features, ``batch.edge_attr``."""
        batch.x = self.model(batch.x, batch.edge_index,
                             edge_feature=batch.edge_attr)
        return batch
class GeneralSampleEdgeConv(nn.Module):
    """A general GNN layer that supports edge features and edge sampling

    On every forward pass each edge is kept independently when a uniform
    random draw is below ``cfg.gnn.keep_edge``.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        bias (bool): Whether has bias term
        **kwargs (optional): Additional args (accepted and ignored)
    """
    def __init__(self, dim_in, dim_out, bias=False, **kwargs):
        super().__init__()
        self.model = GeneralEdgeConvLayer(dim_in, dim_out, bias=bias)

    def forward(self, batch):
        """Update ``batch.x`` using a random subset of the edges.

        ``batch.edge_index`` and ``batch.edge_attr`` themselves are left
        untouched; only the sampled copies are passed to the layer.
        """
        # One uniform draw per edge (columns of edge_index).
        # NOTE(review): the mask is created without an explicit device and
        # the sampling is not gated on ``self.training`` -- confirm both are
        # intended.
        edge_mask = torch.rand(batch.edge_index.shape[1]) < cfg.gnn.keep_edge
        edge_index = batch.edge_index[:, edge_mask]
        edge_feature = batch.edge_attr[edge_mask, :]
        batch.x = self.model(batch.x, edge_index, edge_feature=edge_feature)
        return batch
# Built-in layers, keyed by the name used to look them up.
layer_dict = {
    'linear': Linear,
    'mlp': MLP,
    'gcnconv': GCNConv,
    'sageconv': SAGEConv,
    'gatconv': GATConv,
    'splineconv': SplineConv,
    'ginconv': GINConv,
    'generalconv': GeneralConv,
    'generaledgeconv': GeneralEdgeConv,
    'generalsampleedgeconv': GeneralSampleEdgeConv,
}

# register additional convs
# The built-in entries are unpacked last, so they take precedence over a
# registered layer with the same name.
layer_dict = {**register.layer_dict, **layer_dict}