import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, ModuleList
from tqdm import tqdm
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.conv import (
EdgeConv,
GATConv,
GATv2Conv,
GCNConv,
GINConv,
MessagePassing,
PNAConv,
SAGEConv,
)
from torch_geometric.nn.models import MLP
from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge
from torch_geometric.nn.resolver import (
activation_resolver,
normalization_resolver,
)
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils.trim_to_layer import TrimToLayer
class BasicGNN(torch.nn.Module):
r"""An abstract class for implementing basic GNN models.
Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
norm (str or Callable, optional): The normalization function to
use. (default: :obj:`None`)
norm_kwargs (Dict[str, Any], optional): Arguments passed to the
respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
**kwargs (optional): Additional arguments of the underlying
:class:`torch_geometric.nn.conv.MessagePassing` layers.
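
    Example (a hedged sketch of a custom subclass; override
    :meth:`init_conv` to plug in any
    :class:`~torch_geometric.nn.conv.MessagePassing` layer):

    .. code-block:: python

        class MySAGE(BasicGNN):
            supports_edge_weight = False
            supports_edge_attr = False

            def init_conv(self, in_channels, out_channels, **kwargs):
                # Any MessagePassing layer works here; SAGEConv is just an
                # illustrative choice.
                return SAGEConv(in_channels, out_channels, **kwargs)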
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
num_layers: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
act: Union[str, Callable, None] = "relu",
act_first: bool = False,
act_kwargs: Optional[Dict[str, Any]] = None,
norm: Union[str, Callable, None] = None,
norm_kwargs: Optional[Dict[str, Any]] = None,
jk: Optional[str] = None,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.num_layers = num_layers
self.dropout = dropout
self.act = activation_resolver(act, **(act_kwargs or {}))
self.jk_mode = jk
self.act_first = act_first
self.norm = norm if isinstance(norm, str) else None
self.norm_kwargs = norm_kwargs
if out_channels is not None:
self.out_channels = out_channels
else:
self.out_channels = hidden_channels
self.convs = ModuleList()
if num_layers > 1:
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))
if isinstance(in_channels, (tuple, list)):
in_channels = (hidden_channels, hidden_channels)
else:
in_channels = hidden_channels
for _ in range(num_layers - 2):
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))
if isinstance(in_channels, (tuple, list)):
in_channels = (hidden_channels, hidden_channels)
else:
in_channels = hidden_channels
if out_channels is not None and jk is None:
self._is_conv_to_out = True
self.convs.append(
self.init_conv(in_channels, out_channels, **kwargs))
else:
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))
self.norms = None
if norm is not None:
norm_layer = normalization_resolver(
norm,
hidden_channels,
**(norm_kwargs or {}),
)
self.norms = ModuleList()
for _ in range(num_layers - 1):
self.norms.append(copy.deepcopy(norm_layer))
if jk is not None:
self.norms.append(copy.deepcopy(norm_layer))
if jk is not None and jk != 'last':
self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)
if jk is not None:
if jk == 'cat':
in_channels = num_layers * hidden_channels
else:
in_channels = hidden_channels
self.lin = Linear(in_channels, self.out_channels)
        # We define `trim_to_layer` functionality as a module such that we can
        # still use `to_hetero` on top.
self._trim = TrimToLayer()
def init_conv(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int, **kwargs) -> MessagePassing:
raise NotImplementedError
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
for conv in self.convs:
conv.reset_parameters()
for norm in self.norms or []:
norm.reset_parameters()
if hasattr(self, 'jk'):
self.jk.reset_parameters()
if hasattr(self, 'lin'):
self.lin.reset_parameters()
def forward(
self,
x: Tensor,
edge_index: Adj,
*,
edge_weight: OptTensor = None,
edge_attr: OptTensor = None,
num_sampled_nodes_per_hop: Optional[List[int]] = None,
num_sampled_edges_per_hop: Optional[List[int]] = None,
) -> Tensor:
r"""
Args:
x (torch.Tensor): The input node features.
edge_index (torch.Tensor): The edge indices.
edge_weight (torch.Tensor, optional): The edge weights (if
supported by the underlying GNN layer). (default: :obj:`None`)
edge_attr (torch.Tensor, optional): The edge features (if supported
by the underlying GNN layer). (default: :obj:`None`)
num_sampled_nodes_per_hop (List[int], optional): The number of
sampled nodes per hop.
                Useful in :class:`~torch_geometric.loader.NeighborLoader`
scenarios to only operate on minimal-sized representations.
(default: :obj:`None`)
num_sampled_edges_per_hop (List[int], optional): The number of
sampled edges per hop.
                Useful in :class:`~torch_geometric.loader.NeighborLoader`
scenarios to only operate on minimal-sized representations.
(default: :obj:`None`)
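
        Example (a hedged sketch of per-layer trimming; it assumes a
        :class:`~torch_geometric.loader.NeighborLoader` that attaches
        :obj:`num_sampled_nodes` and :obj:`num_sampled_edges` to each
        batch, as recent loader versions do):

        .. code-block:: python

            for batch in loader:
                out = model(
                    batch.x,
                    batch.edge_index,
                    num_sampled_nodes_per_hop=batch.num_sampled_nodes,
                    num_sampled_edges_per_hop=batch.num_sampled_edges,
                )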
"""
if (num_sampled_nodes_per_hop is not None
and isinstance(edge_weight, Tensor)
and isinstance(edge_attr, Tensor)):
raise NotImplementedError("'trim_to_layer' functionality does not "
"yet support trimming of both "
"'edge_weight' and 'edge_attr'")
xs: List[Tensor] = []
for i in range(self.num_layers):
if num_sampled_nodes_per_hop is not None:
x, edge_index, value = self._trim(
i,
num_sampled_nodes_per_hop,
num_sampled_edges_per_hop,
x,
edge_index,
edge_weight if edge_weight is not None else edge_attr,
)
if edge_weight is not None:
edge_weight = value
else:
edge_attr = value
# Tracing the module is not allowed with *args and **kwargs :(
# As such, we rely on a static solution to pass optional edge
# weights and edge attributes to the module.
if self.supports_edge_weight and self.supports_edge_attr:
x = self.convs[i](x, edge_index, edge_weight=edge_weight,
edge_attr=edge_attr)
elif self.supports_edge_weight:
x = self.convs[i](x, edge_index, edge_weight=edge_weight)
elif self.supports_edge_attr:
x = self.convs[i](x, edge_index, edge_attr=edge_attr)
else:
x = self.convs[i](x, edge_index)
if i == self.num_layers - 1 and self.jk_mode is None:
break
if self.act is not None and self.act_first:
x = self.act(x)
if self.norms is not None:
x = self.norms[i](x)
if self.act is not None and not self.act_first:
x = self.act(x)
x = F.dropout(x, p=self.dropout, training=self.training)
if hasattr(self, 'jk'):
xs.append(x)
x = self.jk(xs) if hasattr(self, 'jk') else x
x = self.lin(x) if hasattr(self, 'lin') else x
return x
@torch.no_grad()
def inference(self, loader: NeighborLoader,
device: Optional[torch.device] = None,
progress_bar: bool = False) -> Tensor:
r"""Performs layer-wise inference on large-graphs using a
:class:`~torch_geometric.loader.NeighborLoader`, where
:class:`~torch_geometric.loader.NeighborLoader` should sample the
full neighborhood for only one layer.
This is an efficient way to compute the output embeddings for all
nodes in the graph.
        Only applicable if :obj:`jk=None` or :obj:`jk='last'`.
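
        Example (a minimal sketch; :obj:`data` is assumed to hold the full
        graph, and :obj:`num_neighbors=[-1]` samples the complete 1-hop
        neighborhood for the single layer this method requires):

        .. code-block:: python

            loader = NeighborLoader(data, num_neighbors=[-1],
                                    batch_size=1024, shuffle=False)
            model.eval()
            emb = model.inference(loader, device=torch.device('cuda'))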
"""
assert self.jk_mode is None or self.jk_mode == 'last'
assert isinstance(loader, NeighborLoader)
assert len(loader.dataset) == loader.data.num_nodes
assert len(loader.node_sampler.num_neighbors) == 1
assert not self.training
# assert not loader.shuffle # TODO (matthias) does not work :(
if progress_bar:
pbar = tqdm(total=len(self.convs) * len(loader))
pbar.set_description('Inference')
x_all = loader.data.x.cpu()
loader.data.n_id = torch.arange(x_all.size(0))
for i in range(self.num_layers):
xs: List[Tensor] = []
for batch in loader:
x = x_all[batch.n_id].to(device)
if hasattr(batch, 'adj_t'):
edge_index = batch.adj_t.to(device)
else:
edge_index = batch.edge_index.to(device)
x = self.convs[i](x, edge_index)[:batch.batch_size]
if i == self.num_layers - 1 and self.jk_mode is None:
xs.append(x.cpu())
if progress_bar:
pbar.update(1)
continue
if self.act is not None and self.act_first:
x = self.act(x)
if self.norms is not None:
x = self.norms[i](x)
if self.act is not None and not self.act_first:
x = self.act(x)
if i == self.num_layers - 1 and hasattr(self, 'lin'):
x = self.lin(x)
xs.append(x.cpu())
if progress_bar:
pbar.update(1)
x_all = torch.cat(xs, dim=0)
if progress_bar:
pbar.close()
del loader.data.n_id
return x_all
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, num_layers={self.num_layers})')
class GCN(BasicGNN):
r"""The Graph Neural Network from the `"Semi-supervised
Classification with Graph Convolutional Networks"
<https://arxiv.org/abs/1609.02907>`_ paper, using the
:class:`~torch_geometric.nn.conv.GCNConv` operator for message passing.
Args:
in_channels (int): Size of each input sample, or :obj:`-1` to derive
the size from the first input(s) to the forward method.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
norm (str or Callable, optional): The normalization function to
use. (default: :obj:`None`)
norm_kwargs (Dict[str, Any], optional): Arguments passed to the
respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GCNConv`.
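
    Example (a minimal usage sketch; the sizes below are illustrative, and
    the optional :obj:`edge_weight` shows the layer's edge-weight support):

    .. code-block:: python

        model = GCN(in_channels=16, hidden_channels=32, num_layers=2,
                    out_channels=7)
        x = torch.randn(100, 16)  # 100 nodes, 16 features each
        edge_index = torch.randint(0, 100, (2, 500))  # 500 random edges
        edge_weight = torch.rand(500)
        out = model(x, edge_index, edge_weight=edge_weight)  # [100, 7]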
"""
supports_edge_weight = True
supports_edge_attr = False
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
return GCNConv(in_channels, out_channels, **kwargs)
class GraphSAGE(BasicGNN):
r"""The Graph Neural Network from the `"Inductive Representation Learning
on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, using the
:class:`~torch_geometric.nn.SAGEConv` operator for message passing.
Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
norm (str or Callable, optional): The normalization function to
use. (default: :obj:`None`)
norm_kwargs (Dict[str, Any], optional): Arguments passed to the
respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.SAGEConv`.
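
    Example (an illustrative sketch; the :obj:`aggr` keyword is forwarded
    to :class:`~torch_geometric.nn.conv.SAGEConv`):

    .. code-block:: python

        model = GraphSAGE(in_channels=16, hidden_channels=32, num_layers=2,
                          aggr='max')
        out = model(x, edge_index)  # x: [N, 16] -> out: [N, 32]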
"""
supports_edge_weight = False
supports_edge_attr = False
def init_conv(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int, **kwargs) -> MessagePassing:
return SAGEConv(in_channels, out_channels, **kwargs)
class GIN(BasicGNN):
r"""The Graph Neural Network from the `"How Powerful are Graph Neural
Networks?" <https://arxiv.org/abs/1810.00826>`_ paper, using the
:class:`~torch_geometric.nn.GINConv` operator for message passing.
Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
norm (str or Callable, optional): The normalization function to
use. (default: :obj:`None`)
norm_kwargs (Dict[str, Any], optional): Arguments passed to the
respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GINConv`.
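
    Example (a minimal sketch with arbitrary sizes; :obj:`jk='cat'`
    concatenates all layer embeddings before the final projection):

    .. code-block:: python

        model = GIN(in_channels=16, hidden_channels=32, num_layers=3,
                    out_channels=7, jk='cat')
        out = model(x, edge_index)  # [N, 7]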
"""
supports_edge_weight = False
supports_edge_attr = False
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
mlp = MLP(
[in_channels, out_channels, out_channels],
act=self.act,
act_first=self.act_first,
norm=self.norm,
norm_kwargs=self.norm_kwargs,
)
return GINConv(mlp, **kwargs)
class GAT(BasicGNN):
r"""The Graph Neural Network from `"Graph Attention Networks"
<https://arxiv.org/abs/1710.10903>`_ or `"How Attentive are Graph Attention
Networks?" <https://arxiv.org/abs/2105.14491>`_ papers, using the
:class:`~torch_geometric.nn.GATConv` or
:class:`~torch_geometric.nn.GATv2Conv` operator for message passing,
respectively.
Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
v2 (bool, optional): If set to :obj:`True`, will make use of
:class:`~torch_geometric.nn.conv.GATv2Conv` rather than
:class:`~torch_geometric.nn.conv.GATConv`. (default: :obj:`False`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
norm (str or Callable, optional): The normalization function to
use. (default: :obj:`None`)
norm_kwargs (Dict[str, Any], optional): Arguments passed to the
respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GATConv` or
:class:`torch_geometric.nn.conv.GATv2Conv`.
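
    Example (an illustrative sketch; note that with head concatenation the
    hidden size must be divisible by :obj:`heads`):

    .. code-block:: python

        model = GAT(in_channels=16, hidden_channels=64, num_layers=2,
                    heads=4, v2=True)  # uses GATv2Conv under the hood
        out = model(x, edge_index)  # [N, 64]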
"""
supports_edge_weight = False
supports_edge_attr = True
def init_conv(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int, **kwargs) -> MessagePassing:
v2 = kwargs.pop('v2', False)
heads = kwargs.pop('heads', 1)
concat = kwargs.pop('concat', True)
        # Do not use concatenation when the `GATConv` layer maps to the
        # desired output channels (out_channels != None and jk == None):
if getattr(self, '_is_conv_to_out', False):
concat = False
if concat and out_channels % heads != 0:
raise ValueError(f"Ensure that the number of output channels of "
f"'GATConv' (got '{out_channels}') is divisible "
f"by the number of heads (got '{heads}')")
if concat:
out_channels = out_channels // heads
Conv = GATConv if not v2 else GATv2Conv
return Conv(in_channels, out_channels, heads=heads, concat=concat,
dropout=self.dropout, **kwargs)
class PNA(BasicGNN):
r"""The Graph Neural Network from the `"Principal Neighbourhood Aggregation
for Graph Nets" <https://arxiv.org/abs/2004.05718>`_ paper, using the
:class:`~torch_geometric.nn.conv.PNAConv` operator for message passing.
Args:
in_channels (int): Size of each input sample, or :obj:`-1` to derive
the size from the first input(s) to the forward method.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
norm (str or Callable, optional): The normalization function to
use. (default: :obj:`None`)
norm_kwargs (Dict[str, Any], optional): Arguments passed to the
respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.PNAConv`.
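
    Example (a hedged sketch; :class:`~torch_geometric.nn.conv.PNAConv`
    additionally requires :obj:`aggregators`, :obj:`scalers`, and an
    in-degree histogram :obj:`deg`, typically computed over the training
    set):

    .. code-block:: python

        from torch_geometric.utils import degree

        # In-degree histogram over N nodes (illustrative computation):
        d = degree(edge_index[1], num_nodes=N, dtype=torch.long)
        deg = torch.bincount(d)
        model = PNA(in_channels=16, hidden_channels=32, num_layers=2,
                    aggregators=['mean', 'min', 'max', 'std'],
                    scalers=['identity', 'amplification', 'attenuation'],
                    deg=deg)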
"""
supports_edge_weight = False
supports_edge_attr = True
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
return PNAConv(in_channels, out_channels, **kwargs)
class EdgeCNN(BasicGNN):
r"""The Graph Neural Network from the `"Dynamic Graph CNN for Learning on
Point Clouds" <https://arxiv.org/abs/1801.07829>`_ paper, using the
:class:`~torch_geometric.nn.conv.EdgeConv` operator for message passing.
Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
num_layers (int): Number of message passing layers.
out_channels (int, optional): If not set to :obj:`None`, will apply a
final linear transformation to convert hidden node embeddings to
output size :obj:`out_channels`. (default: :obj:`None`)
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
act_first (bool, optional): If set to :obj:`True`, activation is
applied before normalization. (default: :obj:`False`)
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
norm (str or Callable, optional): The normalization function to
use. (default: :obj:`None`)
norm_kwargs (Dict[str, Any], optional): Arguments passed to the
respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.EdgeConv`.
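
    Example (a minimal sketch on point-cloud-like features; a k-NN graph
    via :obj:`knn_graph` is one common way to build :obj:`edge_index`, and
    requires :obj:`torch-cluster` to be installed):

    .. code-block:: python

        from torch_geometric.nn import knn_graph

        pos = torch.randn(100, 3)  # 100 points in 3D
        edge_index = knn_graph(pos, k=16)
        model = EdgeCNN(in_channels=3, hidden_channels=64, num_layers=3)
        out = model(pos, edge_index)  # [100, 64]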
"""
supports_edge_weight = False
supports_edge_attr = False
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
mlp = MLP(
[2 * in_channels, out_channels, out_channels],
act=self.act,
act_first=self.act_first,
norm=self.norm,
norm_kwargs=self.norm_kwargs,
)
return EdgeConv(mlp, **kwargs)
__all__ = ['GCN', 'GraphSAGE', 'GIN', 'GAT', 'PNA', 'EdgeCNN']