# Source code for torch_geometric.nn.models.basic_gnn

import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, ModuleList
from tqdm import tqdm

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.conv import (
    EdgeConv,
    GATConv,
    GATv2Conv,
    GCNConv,
    GINConv,
    MessagePassing,
    PNAConv,
    SAGEConv,
)
from torch_geometric.nn.models import MLP
from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj, OptTensor


class BasicGNN(torch.nn.Module):
    r"""An abstract class for implementing basic GNN models.

    Subclasses must implement :meth:`init_conv` and define the boolean class
    attributes :obj:`supports_edge_weight` and :obj:`supports_edge_attr`,
    which :meth:`forward` reads to decide which optional edge arguments are
    passed on to each message passing layer.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, node
            embeddings are mapped to output size :obj:`out_channels`: by the
            last message passing layer in case :obj:`jk` is :obj:`None`, and
            by a final linear transformation otherwise.
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of the underlying
            :class:`torch_geometric.nn.conv.MessagePassing` layers.
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        hidden_channels: int,
        num_layers: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        act: Union[str, Callable, None] = "relu",
        act_first: bool = False,
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Union[str, Callable, None] = None,
        norm_kwargs: Optional[Dict[str, Any]] = None,
        jk: Optional[str] = None,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers

        self.dropout = dropout
        self.act = activation_resolver(act, **(act_kwargs or {}))
        self.jk_mode = jk
        self.act_first = act_first
        # Only string identifiers are kept here; subclasses forward
        # `self.norm` / `self.norm_kwargs` to helper modules they build.
        self.norm = norm if isinstance(norm, str) else None
        self.norm_kwargs = norm_kwargs

        if out_channels is not None:
            self.out_channels = out_channels
        else:
            self.out_channels = hidden_channels

        self.convs = ModuleList()
        # First layer: maps the raw input size to the hidden size.
        if num_layers > 1:
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
            if isinstance(in_channels, (tuple, list)):
                in_channels = (hidden_channels, hidden_channels)
            else:
                in_channels = hidden_channels
        # Intermediate layers: hidden size to hidden size.
        for _ in range(num_layers - 2):
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
            if isinstance(in_channels, (tuple, list)):
                in_channels = (hidden_channels, hidden_channels)
            else:
                in_channels = hidden_channels
        # Last layer: maps directly to `out_channels` only when no Jumping
        # Knowledge head follows; otherwise `self.lin` produces the output.
        if out_channels is not None and jk is None:
            # Flag must be set *before* `init_conv` is called, since
            # subclasses may inspect it while building the final layer.
            self._is_conv_to_out = True
            self.convs.append(
                self.init_conv(in_channels, out_channels, **kwargs))
        else:
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))

        self.norms = None
        if norm is not None:
            norm_layer = normalization_resolver(
                norm,
                hidden_channels,
                **(norm_kwargs or {}),
            )
            self.norms = ModuleList()
            # Without `jk`, the last layer's output is returned as-is, so it
            # needs no normalization layer.
            for _ in range(num_layers - 1):
                self.norms.append(copy.deepcopy(norm_layer))
            if jk is not None:
                self.norms.append(copy.deepcopy(norm_layer))

        if jk is not None and jk != 'last':
            self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)

        if jk is not None:
            if jk == 'cat':
                in_channels = num_layers * hidden_channels
            else:
                in_channels = hidden_channels
            self.lin = Linear(in_channels, self.out_channels)

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        r"""Creates a single message passing layer mapping
        :obj:`in_channels` to :obj:`out_channels`.

        Must be overridden by subclasses; the base implementation raises
        :class:`NotImplementedError`.
        """
        raise NotImplementedError

    def reset_parameters(self):
        r"""Resets the parameters of all message passing layers, of all
        normalization layers, and of the Jumping Knowledge and final linear
        modules in case they exist.
        """
        for conv in self.convs:
            conv.reset_parameters()
        for norm in self.norms or []:
            norm.reset_parameters()
        if hasattr(self, 'jk'):
            self.jk.reset_parameters()
        if hasattr(self, 'lin'):
            self.lin.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        *,
        edge_weight: OptTensor = None,
        edge_attr: OptTensor = None,
    ) -> Tensor:
        r"""Runs all message passing layers on the given input.

        After every layer (except the last one in case :obj:`jk` is
        :obj:`None`), activation, normalization and dropout are applied.

        Args:
            x (Tensor): The input node features.
            edge_index (Adj): The graph connectivity passed to every layer.
            edge_weight (Tensor, optional): Edge weights. Only passed to the
                layers if the class sets :obj:`supports_edge_weight`.
                (default: :obj:`None`)
            edge_attr (Tensor, optional): Edge features. Only passed to the
                layers if the class sets :obj:`supports_edge_attr`.
                (default: :obj:`None`)

        Returns:
            Tensor: The output node embeddings.
        """
        xs: List[Tensor] = []
        for i in range(self.num_layers):
            # Tracing the module is not allowed with *args and **kwargs :(
            # As such, we rely on a static solution to pass optional edge
            # weights and edge attributes to the module.
            if self.supports_edge_weight and self.supports_edge_attr:
                x = self.convs[i](x, edge_index, edge_weight=edge_weight,
                                  edge_attr=edge_attr)
            elif self.supports_edge_weight:
                x = self.convs[i](x, edge_index, edge_weight=edge_weight)
            elif self.supports_edge_attr:
                x = self.convs[i](x, edge_index, edge_attr=edge_attr)
            else:
                x = self.convs[i](x, edge_index)
            # Without `jk`, the output of the last layer is returned raw.
            if i == self.num_layers - 1 and self.jk_mode is None:
                break
            if self.act is not None and self.act_first:
                x = self.act(x)
            if self.norms is not None:
                x = self.norms[i](x)
            if self.act is not None and not self.act_first:
                x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            # `self.jk` only exists for modes other than "last".
            if hasattr(self, 'jk'):
                xs.append(x)

        x = self.jk(xs) if hasattr(self, 'jk') else x
        x = self.lin(x) if hasattr(self, 'lin') else x
        return x

    @torch.no_grad()
    def inference(self, loader: NeighborLoader,
                  device: Optional[torch.device] = None,
                  progress_bar: bool = False) -> Tensor:
        r"""Performs layer-wise inference on large-graphs using
        :class:`~torch_geometric.loader.NeighborLoader`.
        :class:`~torch_geometric.loader.NeighborLoader` should sample the
        full neighborhood for only one layer.
        This is an efficient way to compute the output embeddings for all
        nodes in the graph.
        Only applicable in case :obj:`jk=None` or :obj:`jk='last'`.

        Only node features and :obj:`edge_index` are passed to the layers;
        edge weights and edge attributes are not used here.

        Args:
            loader (NeighborLoader): A loader iterating over all nodes of the
                graph with a single entry in :obj:`num_neighbors`.
            device (torch.device, optional): The device each mini-batch is
                moved to before applying a layer. (default: :obj:`None`)
            progress_bar (bool, optional): If set to :obj:`True`, displays a
                progress bar over all layers and batches.
                (default: :obj:`False`)

        Returns:
            Tensor: The output of the last layer for all nodes, concatenated
            in loader order and stored on the CPU.
        """
        assert self.jk_mode is None or self.jk_mode == 'last'
        assert isinstance(loader, NeighborLoader)
        assert len(loader.dataset) == loader.data.num_nodes
        assert len(loader.num_neighbors) == 1
        assert not self.training
        # assert not loader.shuffle  # TODO (matthias) does not work :(

        x_all = loader.data.x.cpu()

        pbar = None
        if progress_bar:
            pbar = tqdm(total=len(self.convs) * len(loader))
            pbar.set_description('Inference')

        # Temporarily attach global node indices to the data object so that
        # every sampled batch carries them along as `batch.n_id`.
        loader.data.n_id = torch.arange(x_all.size(0))
        try:
            for i in range(self.num_layers):
                xs: List[Tensor] = []
                for batch in loader:
                    x = x_all[batch.n_id].to(device)
                    edge_index = batch.edge_index.to(device)
                    # Only the first `batch_size` rows are kept per batch.
                    x = self.convs[i](x, edge_index)[:batch.batch_size]
                    if i == self.num_layers - 1 and self.jk_mode is None:
                        xs.append(x.cpu())
                        if pbar is not None:
                            pbar.update(1)
                        continue
                    if self.act is not None and self.act_first:
                        x = self.act(x)
                    if self.norms is not None:
                        x = self.norms[i](x)
                    if self.act is not None and not self.act_first:
                        x = self.act(x)
                    if i == self.num_layers - 1 and hasattr(self, 'lin'):
                        x = self.lin(x)
                    xs.append(x.cpu())
                    if pbar is not None:
                        pbar.update(1)
                # The outputs of this layer are the inputs of the next one.
                x_all = torch.cat(xs, dim=0)
        finally:
            # Always undo the temporary state, even if a batch failed.
            if pbar is not None:
                pbar.close()
            del loader.data.n_id

        return x_all

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_layers={self.num_layers})')


class GCN(BasicGNN):
    r"""The Graph Neural Network from the `"Semi-supervised
    Classification with Graph Convolutional Networks"
    <https://arxiv.org/abs/1609.02907>`_ paper, using the
    :class:`~torch_geometric.nn.conv.GCNConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GCNConv`.
    """
    # Capability flags describing which optional edge inputs the layers of
    # this model are given.
    supports_edge_weight = True
    supports_edge_attr = False

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        r"""Creates a :class:`GCNConv` layer mapping :obj:`in_channels` to
        :obj:`out_channels`, forwarding :obj:`**kwargs` to it.
        """
        return GCNConv(in_channels, out_channels, **kwargs)


class GraphSAGE(BasicGNN):
    r"""The Graph Neural Network from the `"Inductive Representation Learning
    on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, using the
    :class:`~torch_geometric.nn.SAGEConv` operator for message passing.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.SAGEConv`.
    """
    # Capability flags describing which optional edge inputs the layers of
    # this model are given.
    supports_edge_weight = False
    supports_edge_attr = False

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        r"""Creates a :class:`SAGEConv` layer mapping :obj:`in_channels` to
        :obj:`out_channels`, forwarding :obj:`**kwargs` to it.
        """
        return SAGEConv(in_channels, out_channels, **kwargs)


class GIN(BasicGNN):
    r"""The Graph Neural Network from the `"How Powerful are Graph Neural
    Networks?" <https://arxiv.org/abs/1810.00826>`_ paper, using the
    :class:`~torch_geometric.nn.GINConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GINConv`.
    """
    # Capability flags describing which optional edge inputs the layers of
    # this model are given.
    supports_edge_weight = False
    supports_edge_attr = False

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        r"""Creates a :class:`GINConv` layer wrapping an :class:`MLP` with
        channel sizes :obj:`[in_channels, out_channels, out_channels]`.

        The MLP is configured with the activation and normalization settings
        stored on the model; :obj:`**kwargs` are forwarded to
        :class:`GINConv`.
        """
        mlp = MLP(
            [in_channels, out_channels, out_channels],
            act=self.act,
            act_first=self.act_first,
            norm=self.norm,
            norm_kwargs=self.norm_kwargs,
        )
        return GINConv(mlp, **kwargs)


class GAT(BasicGNN):
    r"""The Graph Neural Network from `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ or `"How Attentive are Graph Attention
    Networks?" <https://arxiv.org/abs/2105.14491>`_ papers, using the
    :class:`~torch_geometric.nn.GATConv` or
    :class:`~torch_geometric.nn.GATv2Conv` operator for message passing,
    respectively.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        v2 (bool, optional): If set to :obj:`True`, will make use of
            :class:`~torch_geometric.nn.conv.GATv2Conv` rather than
            :class:`~torch_geometric.nn.conv.GATConv`. (default: :obj:`False`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GATConv` or
            :class:`torch_geometric.nn.conv.GATv2Conv`.
    """
    # Capability flags describing which optional edge inputs the layers of
    # this model are given.
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        r"""Creates a :class:`GATConv` (or :class:`GATv2Conv` if :obj:`v2` is
        set) layer producing :obj:`out_channels` features in total.

        The keys :obj:`v2`, :obj:`heads` and :obj:`concat` are consumed from
        :obj:`**kwargs`; the remaining entries are forwarded to the layer.

        Raises:
            ValueError: If heads are concatenated and :obj:`out_channels` is
                not divisible by :obj:`heads`.
        """
        v2 = kwargs.pop('v2', False)
        heads = kwargs.pop('heads', 1)
        concat = kwargs.pop('concat', True)

        # Do not use concatenation in case the `GATConv` layer maps to the
        # desired output channels (out_channels is not None and jk is None):
        if getattr(self, '_is_conv_to_out', False):
            concat = False

        if concat and out_channels % heads != 0:
            raise ValueError(f"Ensure that the number of output channels of "
                             f"'GATConv' (got '{out_channels}') is divisible "
                             f"by the number of heads (got '{heads}')")

        # With concatenation, every head contributes an equal share of the
        # requested output size.
        if concat:
            out_channels = out_channels // heads

        Conv = GATConv if not v2 else GATv2Conv
        return Conv(in_channels, out_channels, heads=heads, concat=concat,
                    dropout=self.dropout, **kwargs)


class PNA(BasicGNN):
    r"""The Graph Neural Network from the `"Principal Neighbourhood Aggregation
    for Graph Nets" <https://arxiv.org/abs/2004.05718>`_ paper, using the
    :class:`~torch_geometric.nn.conv.PNAConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.PNAConv`.
    """
    # Capability flags describing which optional edge inputs the layers of
    # this model are given.
    supports_edge_weight = False
    supports_edge_attr = True

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        r"""Creates a :class:`PNAConv` layer mapping :obj:`in_channels` to
        :obj:`out_channels`, forwarding :obj:`**kwargs` to it.
        """
        return PNAConv(in_channels, out_channels, **kwargs)


class EdgeCNN(BasicGNN):
    r"""The Graph Neural Network from the `"Dynamic Graph CNN for Learning on
    Point Clouds" <https://arxiv.org/abs/1801.07829>`_ paper, using the
    :class:`~torch_geometric.nn.conv.EdgeConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.EdgeConv`.
    """
    # Capability flags describing which optional edge inputs the layers of
    # this model are given.
    supports_edge_weight = False
    supports_edge_attr = False

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        r"""Creates an :class:`EdgeConv` layer wrapping an :class:`MLP` with
        channel sizes :obj:`[2 * in_channels, out_channels, out_channels]`.

        The MLP is configured with the activation and normalization settings
        stored on the model; :obj:`**kwargs` are forwarded to
        :class:`EdgeConv`.
        """
        mlp = MLP(
            [2 * in_channels, out_channels, out_channels],
            act=self.act,
            act_first=self.act_first,
            norm=self.norm,
            norm_kwargs=self.norm_kwargs,
        )
        return EdgeConv(mlp, **kwargs)


# Names of the model classes exported by this module.
__all__ = [
    'GCN',
    'GraphSAGE',
    'GIN',
    'GAT',
    'PNA',
    'EdgeCNN',
]