import copy
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, ModuleList
from torch_geometric.nn.conv import (GATConv, GCNConv, GINConv, MessagePassing,
PNAConv, SAGEConv)
from torch_geometric.nn.models import MLP
from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge
from torch_geometric.typing import Adj


class BasicGNN(torch.nn.Module):
r"""An abstract class for implementing basic GNN models.
Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
norm (torch.nn.Module, optional): The normalization operator to use.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of the underlying
:class:`torch_geometric.nn.conv.MessagePassing` layers.
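
    Example:
        A minimal subclass sketch; :obj:`MyGNN` and the channel sizes below
        are illustrative placeholders, not part of the library::

            from torch_geometric.nn import GCNConv

            class MyGNN(BasicGNN):
                def init_conv(self, in_channels, out_channels, **kwargs):
                    return GCNConv(in_channels, out_channels, **kwargs)

            model = MyGNN(in_channels=16, hidden_channels=32, num_layers=2)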
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
num_layers: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
act: Union[str, Callable, None] = "relu",
norm: Optional[torch.nn.Module] = None,
jk: Optional[str] = None,
act_first: bool = False,
act_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__()
from class_resolver.contrib.torch import activation_resolver
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.num_layers = num_layers
self.dropout = dropout
        # Resolve `act` (e.g., the string "relu") into an activation module.
        self.act = activation_resolver.make(act, act_kwargs)
self.jk_mode = jk
self.act_first = act_first
if out_channels is not None:
self.out_channels = out_channels
else:
self.out_channels = hidden_channels
        self.convs = ModuleList()
        # Build exactly `num_layers` convolutions, handling the single-layer
        # case (where the first and last layer coincide) explicitly.
        if num_layers > 1:
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
            in_channels = hidden_channels
        for _ in range(num_layers - 2):
            self.convs.append(
                self.init_conv(hidden_channels, hidden_channels, **kwargs))
        if out_channels is not None and jk is None:
            self.convs.append(
                self.init_conv(in_channels, out_channels, **kwargs))
        else:
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
self.norms = None
if norm is not None:
self.norms = ModuleList()
for _ in range(num_layers - 1):
self.norms.append(copy.deepcopy(norm))
if jk is not None:
self.norms.append(copy.deepcopy(norm))
if jk is not None and jk != 'last':
self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)
if jk is not None:
if jk == 'cat':
in_channels = num_layers * hidden_channels
else:
in_channels = hidden_channels
self.lin = Linear(in_channels, self.out_channels)
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
raise NotImplementedError
def reset_parameters(self):
for conv in self.convs:
conv.reset_parameters()
for norm in self.norms or []:
norm.reset_parameters()
if hasattr(self, 'jk'):
self.jk.reset_parameters()
if hasattr(self, 'lin'):
self.lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj, *args, **kwargs) -> Tensor:
        """"""
        xs: List[Tensor] = []
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, *args, **kwargs)
            # Without Jumping Knowledge, the final layer output is returned
            # as-is (no activation, normalization or dropout).
            if i == self.num_layers - 1 and self.jk_mode is None:
                break
            if self.act_first:
                x = self.act(x)
            if self.norms is not None:
                x = self.norms[i](x)
            if not self.act_first:
                x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if hasattr(self, 'jk'):
                xs.append(x)
        # Aggregate the per-layer representations (if JK is enabled) and
        # apply the final linear projection (if one was created).
        x = self.jk(xs) if hasattr(self, 'jk') else x
        x = self.lin(x) if hasattr(self, 'lin') else x
        return x
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, num_layers={self.num_layers})')


class GCN(BasicGNN):
r"""The Graph Neural Network from the `"Semi-supervised
Classification with Graph Convolutional Networks"
<https://arxiv.org/abs/1609.02907>`_ paper, using the
:class:`~torch_geometric.nn.conv.GCNConv` operator for message passing.
Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
norm (torch.nn.Module, optional): The normalization operator to use.
(default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`None`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GCNConv`.
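
    Example:
        A minimal usage sketch; the feature sizes and the toy graph below
        are illustrative placeholders::

            import torch

            model = GCN(in_channels=16, hidden_channels=32, num_layers=2,
                        out_channels=7)
            x = torch.randn(4, 16)            # 4 nodes, 16 features each
            edge_index = torch.tensor([[0, 1, 2, 3],
                                       [1, 0, 3, 2]])
            out = model(x, edge_index)        # shape: [4, 7]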
"""
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
return GCNConv(in_channels, out_channels, **kwargs)


class GraphSAGE(BasicGNN):
r"""The Graph Neural Network from the `"Inductive Representation Learning
on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, using the
:class:`~torch_geometric.nn.SAGEConv` operator for message passing.
Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
norm (torch.nn.Module, optional): The normalization operator to use.
(default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`None`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.SAGEConv`.
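
    Example:
        A minimal usage sketch; sizes and the toy graph are illustrative.
        With :obj:`out_channels` unset, embeddings keep the hidden size::

            import torch

            model = GraphSAGE(in_channels=16, hidden_channels=32,
                              num_layers=2)
            x = torch.randn(4, 16)
            edge_index = torch.tensor([[0, 1, 2, 3],
                                       [1, 0, 3, 2]])
            out = model(x, edge_index)        # shape: [4, 32]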
"""
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
return SAGEConv(in_channels, out_channels, **kwargs)


class GIN(BasicGNN):
r"""The Graph Neural Network from the `"How Powerful are Graph Neural
Networks?" <https://arxiv.org/abs/1810.00826>`_ paper, using the
:class:`~torch_geometric.nn.GINConv` operator for message passing.
Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
norm (torch.nn.Module, optional): The normalization operator to use.
(default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`None`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GINConv`.
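
    Example:
        A minimal usage sketch with illustrative sizes; each layer wraps a
        two-layer MLP as constructed in :meth:`init_conv`::

            import torch

            model = GIN(in_channels=16, hidden_channels=32, num_layers=2)
            x = torch.randn(4, 16)
            edge_index = torch.tensor([[0, 1, 2, 3],
                                       [1, 0, 3, 2]])
            out = model(x, edge_index)        # shape: [4, 32]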
"""
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
        # A two-layer MLP; GIN applies it after sum-aggregating neighbors.
        mlp = MLP([in_channels, out_channels, out_channels], batch_norm=True)
return GINConv(mlp, **kwargs)


class GAT(BasicGNN):
r"""The Graph Neural Network from the `"Graph Attention Networks"
<https://arxiv.org/abs/1710.10903>`_ paper, using the
:class:`~torch_geometric.nn.GATConv` operator for message passing.
Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
norm (torch.nn.Module, optional): The normalization operator to use.
(default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`None`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GATConv`.
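
    Example:
        A minimal usage sketch; sizes are illustrative. With the default
        :obj:`concat=True`, each layer splits :obj:`hidden_channels` across
        the attention heads (here, 4 heads of size 8)::

            import torch

            model = GAT(in_channels=16, hidden_channels=32, num_layers=2,
                        heads=4)
            x = torch.randn(4, 16)
            edge_index = torch.tensor([[0, 1, 2, 3],
                                       [1, 0, 3, 2]])
            out = model(x, edge_index)        # shape: [4, 32]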
"""
    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        kwargs = copy.copy(kwargs)
        # Fall back to a single head if `out_channels` cannot be split
        # evenly across the requested number of attention heads.
        if 'heads' in kwargs and out_channels % kwargs['heads'] != 0:
            kwargs['heads'] = 1
        # When head outputs are concatenated (the default), shrink the
        # per-head dimensionality so the total output size is preserved.
        if 'concat' not in kwargs or kwargs['concat']:
            out_channels = out_channels // kwargs.get('heads', 1)
        return GATConv(in_channels, out_channels, dropout=self.dropout,
                       **kwargs)


class PNA(BasicGNN):
r"""The Graph Neural Network from the `"Principal Neighbourhood Aggregation
for Graph Nets" <https://arxiv.org/abs/2004.05718>`_ paper, using the
:class:`~torch_geometric.nn.conv.PNAConv` operator for message passing.
Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
norm (torch.nn.Module, optional): The normalization operator to use.
(default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`None`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.PNAConv`.
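
    Example:
        A minimal usage sketch; the aggregators, scalers and the in-degree
        histogram :obj:`deg` below are illustrative values for the required
        :class:`~torch_geometric.nn.conv.PNAConv` arguments::

            import torch

            deg = torch.tensor([0, 4])  # 0 isolated nodes, 4 nodes of degree 1
            model = PNA(in_channels=16, hidden_channels=32, num_layers=2,
                        aggregators=['mean', 'min', 'max', 'std'],
                        scalers=['identity', 'amplification', 'attenuation'],
                        deg=deg)
            x = torch.randn(4, 16)
            edge_index = torch.tensor([[0, 1, 2, 3],
                                       [1, 0, 3, 2]])
            out = model(x, edge_index)        # shape: [4, 32]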
"""
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
return PNAConv(in_channels, out_channels, **kwargs)


__all__ = ['GCN', 'GraphSAGE', 'GIN', 'GAT', 'PNA']