"""GNN heads are the last layer of a GNN, right before loss computation.
They are constructed in the ``__init__`` method of ``gnn.GNN``.
"""
import torch
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import new_layer_config, MLP
import torch_geometric.graphgym.models.pooling # noqa, register module
import torch_geometric.graphgym.register as register
class GNNNodeHead(nn.Module):
    """
    Prediction head for node-level tasks.

    The batch is passed through a post-message-passing MLP, after which
    predictions and labels are restricted to the entries selected by the
    mask of the batch's current split.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        post_mp_config = new_layer_config(
            dim_in,
            dim_out,
            cfg.gnn.layers_post_mp,
            has_act=False,
            has_bias=True,
            cfg=cfg,
        )
        self.layer_post_mp = MLP(post_mp_config)

    def _apply_index(self, batch):
        """Return ``(batch.x, batch.y)`` indexed by ``batch['<split>_mask']``."""
        # The mask attribute name is derived from the active split stored
        # on the batch, e.g. a split named "train" reads "train_mask".
        selector = batch[f'{batch.split}_mask']
        return batch.x[selector], batch.y[selector]

    def forward(self, batch):
        """Apply the MLP to ``batch`` and return ``(pred, label)``."""
        transformed = self.layer_post_mp(batch)
        return self._apply_index(transformed)
class GNNEdgeHead(nn.Module):
    """
    GNN prediction head for edge/link prediction tasks.

    The way a pair of node embeddings is turned into an edge prediction is
    selected by ``cfg.model.edge_decoding``:

    * ``'concat'``: the two embeddings are concatenated along the last
      dimension and fed to an MLP configured with
      ``new_layer_config(dim_in * 2, dim_out, ...)``.
    * ``'dot'``: an MLP configured with ``new_layer_config(dim_in, dim_in,
      ...)`` is applied to the batch first; each pair is then scored by the
      sum of the element-wise product over the last dimension.
    * ``'cosine_similarity'``: same MLP as ``'dot'``, pairs are scored with
      ``nn.CosineSimilarity(dim=-1)``.

    The decoders are stored as methods / modules rather than lambdas so that
    the head can be pickled and so that ``copy.deepcopy`` rebinds the
    ``'concat'`` decoder to the copied MLP instead of the original one.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.

    Raises:
        ValueError: If ``dim_out > 1`` while a non-``'concat'`` decoding is
            configured, or if ``cfg.model.edge_decoding`` is not one of the
            supported values.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # module to decode edges from node embeddings
        if cfg.model.edge_decoding == 'concat':
            self.layer_post_mp = MLP(
                new_layer_config(dim_in * 2, dim_out, cfg.gnn.layers_post_mp,
                                 has_act=False, has_bias=True, cfg=cfg))
            # requires parameter (the decoder itself runs the MLP)
            self.decode_module = self._decode_concat
        else:
            if dim_out > 1:
                raise ValueError(
                    'Binary edge decoding ({}) is used for multi-class '
                    'edge/link prediction.'.format(cfg.model.edge_decoding))
            self.layer_post_mp = MLP(
                new_layer_config(dim_in, dim_in, cfg.gnn.layers_post_mp,
                                 has_act=False, has_bias=True, cfg=cfg))
            if cfg.model.edge_decoding == 'dot':
                self.decode_module = self._decode_dot
            elif cfg.model.edge_decoding == 'cosine_similarity':
                self.decode_module = nn.CosineSimilarity(dim=-1)
            else:
                raise ValueError('Unknown edge decoding {}.'.format(
                    cfg.model.edge_decoding))

    def _decode_concat(self, v1, v2):
        """Concatenate both embeddings on the last dim and run the MLP."""
        return self.layer_post_mp(torch.cat((v1, v2), dim=-1))

    @staticmethod
    def _decode_dot(v1, v2):
        """Score each pair by the dot product over the last dimension."""
        return torch.sum(v1 * v2, dim=-1)

    def _apply_index(self, batch):
        """Return ``batch.x`` indexed by ``batch['<split>_edge_index']``
        together with ``batch['<split>_edge_label']``."""
        index = '{}_edge_index'.format(batch.split)
        label = '{}_edge_label'.format(batch.split)
        return batch.x[batch[index]], \
            batch[label]

    def forward(self, batch):
        """Return ``(pred, label)`` for the edges of the batch's split."""
        # For 'concat' the MLP is applied inside the decoder, on the
        # concatenated pair; otherwise it transforms the batch up front.
        if cfg.model.edge_decoding != 'concat':
            batch = self.layer_post_mp(batch)
        pred, label = self._apply_index(batch)
        # NOTE(review): presumably the edge index has two rows (source and
        # target nodes), so pred[0] / pred[1] are the endpoint embeddings --
        # confirm against the data loader.
        nodes_first = pred[0]
        nodes_second = pred[1]
        pred = self.decode_module(nodes_first, nodes_second)
        return pred, label
class GNNGraphHead(nn.Module):
    """
    Prediction head for graph-level tasks.

    Node features are first pooled with the function registered under
    ``cfg.model.graph_pooling``; the pooled embedding is then transformed by
    a post-message-passing MLP configured with ``cfg.gnn.layers_post_mp``.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        post_mp_config = new_layer_config(
            dim_in,
            dim_out,
            cfg.gnn.layers_post_mp,
            has_act=False,
            has_bias=True,
            cfg=cfg,
        )
        self.layer_post_mp = MLP(post_mp_config)
        # Pooling function is looked up by name in the GraphGym registry.
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]

    def _apply_index(self, batch):
        """Return ``(batch.graph_feature, batch.y)``."""
        return batch.graph_feature, batch.y

    def forward(self, batch):
        """Pool, transform, and return ``(pred, label)``.

        Side effect: the transformed embedding is stored on the batch as
        ``batch.graph_feature``.
        """
        pooled = self.pooling_fun(batch.x, batch.batch)
        batch.graph_feature = self.layer_post_mp(pooled)
        return self._apply_index(batch)
# Built-in prediction heads exposed to the external interface, keyed by name.
# 'edge' and 'link_pred' both map to the same edge head.
head_dict = dict(
    node=GNNNodeHead,
    edge=GNNEdgeHead,
    link_pred=GNNEdgeHead,
    graph=GNNGraphHead,
)
# Rebind the registry to a merged copy; on a name clash the built-in heads
# listed above take precedence over entries already in the registry.
register.head_dict = {**register.head_dict, **head_dict}