# Source code for torch_geometric.nn.models.basic_gnn

import copy
import inspect
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Linear, ModuleList
from tqdm import tqdm

from torch_geometric.data import Data
from torch_geometric.loader import CachedLoader, NeighborLoader
from torch_geometric.nn.conv import (
    EdgeConv,
    GATConv,
    GATv2Conv,
    GCNConv,
    GINConv,
    MessagePassing,
    PNAConv,
    SAGEConv,
)
from torch_geometric.nn.models import MLP
from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils._trim_to_layer import TrimToLayer


class BasicGNN(torch.nn.Module):
    r"""An abstract class for implementing basic GNN models.

    Subclasses implement :meth:`init_conv` to create their message passing
    layers. They also set the class attributes :obj:`supports_edge_weight`
    and :obj:`supports_edge_attr`, which decide which optional edge inputs
    :meth:`forward` hands to each layer.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, the model
            outputs node embeddings of size :obj:`out_channels`. With
            :obj:`jk=None`, the last message passing layer maps directly to
            :obj:`out_channels`. Otherwise, the final linear transformation
            applied after Jumping Knowledge maps to :obj:`out_channels`.
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of the underlying
            :class:`torch_geometric.nn.conv.MessagePassing` layers.
    """
    # Whether the conv layers accept `edge_weight` / `edge_attr`; each
    # subclass sets these as class attributes and `forward` dispatches on
    # them statically (see the tracing note in `forward`).
    supports_edge_weight: Final[bool]
    supports_edge_attr: Final[bool]
    # Whether the normalization layer's `forward` takes a `batch` argument;
    # computed in `__init__` by inspecting its signature.
    supports_norm_batch: Final[bool]

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        hidden_channels: int,
        num_layers: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        act: Union[str, Callable, None] = "relu",
        act_first: bool = False,
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Union[str, Callable, None] = None,
        norm_kwargs: Optional[Dict[str, Any]] = None,
        jk: Optional[str] = None,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers

        self.dropout = torch.nn.Dropout(p=dropout)
        self.act = activation_resolver(act, **(act_kwargs or {}))
        self.jk_mode = jk
        self.act_first = act_first
        # Only string identifiers are kept on `self.norm` (subclasses such as
        # GIN/EdgeCNN forward it to their MLPs); a callable `norm` is not.
        self.norm = norm if isinstance(norm, str) else None
        self.norm_kwargs = norm_kwargs

        # Size of the model output; falls back to the hidden size.
        if out_channels is not None:
            self.out_channels = out_channels
        else:
            self.out_channels = hidden_channels

        # All layers except the last map to `hidden_channels`. After the
        # first layer, the inputs are hidden embeddings (kept as a
        # (source, target) pair if a tuple was given).
        self.convs = ModuleList()
        if num_layers > 1:
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
            if isinstance(in_channels, (tuple, list)):
                in_channels = (hidden_channels, hidden_channels)
            else:
                in_channels = hidden_channels
        for _ in range(num_layers - 2):
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
            if isinstance(in_channels, (tuple, list)):
                in_channels = (hidden_channels, hidden_channels)
            else:
                in_channels = hidden_channels
        # The last layer maps straight to `out_channels` only without JK.
        # The flag is set *before* building that layer so `init_conv` can
        # detect it (GAT reads it via `getattr`).
        if out_channels is not None and jk is None:
            self._is_conv_to_out = True
            self.convs.append(
                self.init_conv(in_channels, out_channels, **kwargs))
        else:
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))

        self.norms = ModuleList()
        norm_layer = normalization_resolver(
            norm,
            hidden_channels,
            **(norm_kwargs or {}),
        )
        if norm_layer is None:
            norm_layer = torch.nn.Identity()

        self.supports_norm_batch = False
        if hasattr(norm_layer, 'forward'):
            norm_params = inspect.signature(norm_layer.forward).parameters
            self.supports_norm_batch = 'batch' in norm_params

        # One independent norm copy per hidden layer.
        for _ in range(num_layers - 1):
            self.norms.append(copy.deepcopy(norm_layer))

        # The last layer is only normalized when JK is used; this mirrors
        # `forward`, which skips act/norm/dropout after the last layer
        # otherwise.
        if jk is not None:
            self.norms.append(copy.deepcopy(norm_layer))
        else:
            self.norms.append(torch.nn.Identity())

        # 'last' needs no JK module: `forward` simply keeps the final output.
        if jk is not None and jk != 'last':
            self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)

        if jk is not None:
            # 'cat' concatenates all layer outputs; other modes keep width.
            if jk == 'cat':
                in_channels = num_layers * hidden_channels
            else:
                in_channels = hidden_channels
            self.lin = Linear(in_channels, self.out_channels)

        # We define `trim_to_layer` functionality as a module such that we can
        # still use `to_hetero` on-top.
        self._trim = TrimToLayer()

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        r"""Creates a single message passing layer mapping
        :obj:`in_channels` to :obj:`out_channels`.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.convs:
            conv.reset_parameters()
        for norm in self.norms:
            # `torch.nn.Identity` has no `reset_parameters`.
            if hasattr(norm, 'reset_parameters'):
                norm.reset_parameters()
        if hasattr(self, 'jk'):
            self.jk.reset_parameters()
        if hasattr(self, 'lin'):
            self.lin.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        edge_weight: OptTensor = None,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        batch_size: Optional[int] = None,
        num_sampled_nodes_per_hop: Optional[List[int]] = None,
        num_sampled_edges_per_hop: Optional[List[int]] = None,
    ) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_weight (torch.Tensor, optional): The edge weights (if
                supported by the underlying GNN layer). (default: :obj:`None`)
            edge_attr (torch.Tensor, optional): The edge features (if supported
                by the underlying GNN layer). (default: :obj:`None`)
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
                Only needs to be passed in case the underlying normalization
                layers require the :obj:`batch` information.
                (default: :obj:`None`)
            batch_size (int, optional): The number of examples :math:`B`.
                Automatically calculated if not given.
                Only needs to be passed in case the underlying normalization
                layers require the :obj:`batch` information.
                (default: :obj:`None`)
            num_sampled_nodes_per_hop (List[int], optional): The number of
                sampled nodes per hop.
                Useful in :class:`~torch_geometric.loader.NeighborLoader`
                scenarios to only operate on minimal-sized representations.
                (default: :obj:`None`)
            num_sampled_edges_per_hop (List[int], optional): The number of
                sampled edges per hop.
                Useful in :class:`~torch_geometric.loader.NeighborLoader`
                scenarios to only operate on minimal-sized representations.
                (default: :obj:`None`)
        """
        if (num_sampled_nodes_per_hop is not None
                and isinstance(edge_weight, Tensor)
                and isinstance(edge_attr, Tensor)):
            raise NotImplementedError("'trim_to_layer' functionality does not "
                                      "yet support trimming of both "
                                      "'edge_weight' and 'edge_attr'")

        xs: List[Tensor] = []
        assert len(self.convs) == len(self.norms)
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            if (not torch.jit.is_scripting()
                    and num_sampled_nodes_per_hop is not None):
                # Only one per-edge tensor is trimmed together with the
                # graph: `edge_weight` if given, otherwise `edge_attr`.
                x, edge_index, value = self._trim(
                    i,
                    num_sampled_nodes_per_hop,
                    num_sampled_edges_per_hop,
                    x,
                    edge_index,
                    edge_weight if edge_weight is not None else edge_attr,
                )
                if edge_weight is not None:
                    edge_weight = value
                else:
                    edge_attr = value

            # Tracing the module is not allowed with *args and **kwargs :(
            # As such, we rely on a static solution to pass optional edge
            # weights and edge attributes to the module.
            if self.supports_edge_weight and self.supports_edge_attr:
                x = conv(x, edge_index, edge_weight=edge_weight,
                         edge_attr=edge_attr)
            elif self.supports_edge_weight:
                x = conv(x, edge_index, edge_weight=edge_weight)
            elif self.supports_edge_attr:
                x = conv(x, edge_index, edge_attr=edge_attr)
            else:
                x = conv(x, edge_index)

            # The last layer's output is left raw unless JK is used.
            if i < self.num_layers - 1 or self.jk_mode is not None:
                if self.act is not None and self.act_first:
                    x = self.act(x)
                if self.supports_norm_batch:
                    x = norm(x, batch, batch_size)
                else:
                    x = norm(x)
                if self.act is not None and not self.act_first:
                    x = self.act(x)
                x = self.dropout(x)
                if hasattr(self, 'jk'):
                    xs.append(x)

        x = self.jk(xs) if hasattr(self, 'jk') else x
        x = self.lin(x) if hasattr(self, 'lin') else x

        return x

    @torch.no_grad()
    def inference_per_layer(
        self,
        layer: int,
        x: Tensor,
        edge_index: Adj,
        batch_size: int,
    ) -> Tensor:
        r"""Applies layer :obj:`layer` to one mini-batch during
        :meth:`inference`.

        It keeps only the first :obj:`batch_size` rows of the convolution
        output (presumably the seed nodes of the loader batch; verify against
        the loader). It mirrors :meth:`forward` for a single layer, except
        that dropout is not applied and normalization is always called
        without :obj:`batch` information.

        For the last layer with :obj:`jk=None`, the raw convolution output is
        returned. With :obj:`jk='last'`, activation, normalization and the
        final linear transformation are applied.
        """
        x = self.convs[layer](x, edge_index)[:batch_size]

        if layer == self.num_layers - 1 and self.jk_mode is None:
            return x

        if self.act is not None and self.act_first:
            x = self.act(x)
        # `self.norms` is always a `ModuleList` holding one entry per layer.
        x = self.norms[layer](x)
        if self.act is not None and not self.act_first:
            x = self.act(x)
        if layer == self.num_layers - 1 and hasattr(self, 'lin'):
            x = self.lin(x)

        return x

    @torch.no_grad()
    def inference(
        self,
        loader: NeighborLoader,
        device: Optional[Union[str, torch.device]] = None,
        embedding_device: Union[str, torch.device] = 'cpu',
        progress_bar: bool = False,
        cache: bool = False,
    ) -> Tensor:
        r"""Performs layer-wise inference on large-graphs using a
        :class:`~torch_geometric.loader.NeighborLoader`, where
        :class:`~torch_geometric.loader.NeighborLoader` should sample the
        full neighborhood for only one layer.
        This is an efficient way to compute the output embeddings for all
        nodes in the graph.
        Only applicable in case :obj:`jk=None` or :obj:`jk='last'`.

        Args:
            loader (torch_geometric.loader.NeighborLoader): A neighbor loader
                object that generates full 1-hop subgraphs, *i.e.*,
                :obj:`loader.num_neighbors = [-1]`.
            device (torch.device, optional): The device to run the GNN on.
                (default: :obj:`None`)
            embedding_device (torch.device, optional): The device to store
                intermediate embeddings on. If intermediate embeddings fit on
                GPU, this option helps to avoid unnecessary device transfers.
                (default: :obj:`"cpu"`)
            progress_bar (bool, optional): If set to :obj:`True`, will print a
                progress bar during computation. (default: :obj:`False`)
            cache (bool, optional): If set to :obj:`True`, caches intermediate
                sampler outputs for usage in later epochs.
                This will avoid repeated sampling to accelerate inference.
                (default: :obj:`False`)
        """
        assert self.jk_mode is None or self.jk_mode == 'last'
        assert isinstance(loader, NeighborLoader)
        assert len(loader.dataset) == loader.data.num_nodes
        assert len(loader.node_sampler.num_neighbors) == 1
        assert not self.training
        # Per-batch outputs are concatenated in loader order, so the loader
        # should not shuffle; this check is currently disabled:
        # assert not loader.shuffle  # TODO (matthias) does not work :(

        pbar = None
        if progress_bar:
            pbar = tqdm(total=len(self.convs) * len(loader))
            pbar.set_description('Inference')

        try:
            x_all = loader.data.x.to(embedding_device)

            if cache:

                # Only cache necessary attributes:
                def transform(data: Data) -> Data:
                    kwargs = dict(n_id=data.n_id, batch_size=data.batch_size)
                    if hasattr(data, 'adj_t'):
                        kwargs['adj_t'] = data.adj_t
                    else:
                        kwargs['edge_index'] = data.edge_index

                    return Data.from_dict(kwargs)

                loader = CachedLoader(loader, device=device,
                                      transform=transform)

            # Each pass computes layer `i` for every node, feeding on the
            # previous pass's embeddings gathered via `n_id`.
            for i in range(self.num_layers):
                xs: List[Tensor] = []
                for batch in loader:
                    x = x_all[batch.n_id].to(device)
                    batch_size = batch.batch_size
                    if hasattr(batch, 'adj_t'):
                        edge_index = batch.adj_t.to(device)
                    else:
                        edge_index = batch.edge_index.to(device)

                    x = self.inference_per_layer(i, x, edge_index, batch_size)
                    xs.append(x.to(embedding_device))

                    if pbar is not None:
                        pbar.update(1)

                x_all = torch.cat(xs, dim=0)
        finally:
            # Close the progress bar even if a layer or batch raises.
            if pbar is not None:
                pbar.close()

        return x_all

    def __repr__(self) -> str:
        # e.g. "GCN(16, 7, num_layers=2)"; uses the *given* in_channels.
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_layers={self.num_layers})')


class GCN(BasicGNN):
    r"""The Graph Neural Network from the `"Semi-supervised
    Classification with Graph Convolutional Networks"
    <https://arxiv.org/abs/1609.02907>`_ paper, using the
    :class:`~torch_geometric.nn.conv.GCNConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GCNConv`.
    """
    # `BasicGNN.forward` hands only `edge_weight` (never `edge_attr`) to the
    # GCNConv layers.
    supports_edge_weight: Final[bool] = True
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        """Builds one GCNConv layer from ``in_channels`` to ``out_channels``."""
        layer = GCNConv(in_channels, out_channels, **kwargs)
        return layer
class GraphSAGE(BasicGNN):
    r"""The Graph Neural Network from the `"Inductive Representation Learning
    on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, using the
    :class:`~torch_geometric.nn.SAGEConv` operator for message passing.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.SAGEConv`.
    """
    # `BasicGNN.forward` calls the SAGEConv layers with node features and
    # connectivity only.
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        """Builds one SAGEConv layer; tuple ``in_channels`` pass through."""
        layer = SAGEConv(in_channels, out_channels, **kwargs)
        return layer
class GIN(BasicGNN):
    r"""The Graph Neural Network from the `"How Powerful are Graph Neural
    Networks?" <https://arxiv.org/abs/1810.00826>`_ paper, using the
    :class:`~torch_geometric.nn.GINConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GINConv`.
    """
    # `BasicGNN.forward` calls the GINConv layers with node features and
    # connectivity only.
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        """Wraps a two-layer MLP (in -> out -> out) into a GINConv layer.

        The MLP reuses the model's activation module, activation order and
        (string-valued) normalization settings.
        """
        widths = [in_channels, out_channels, out_channels]
        inner = MLP(
            widths,
            act=self.act,
            act_first=self.act_first,
            norm=self.norm,
            norm_kwargs=self.norm_kwargs,
        )
        return GINConv(inner, **kwargs)
class GAT(BasicGNN):
    r"""The Graph Neural Network from `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ or `"How Attentive are Graph
    Attention Networks?" <https://arxiv.org/abs/2105.14491>`_ papers, using
    the :class:`~torch_geometric.nn.GATConv` or
    :class:`~torch_geometric.nn.GATv2Conv` operator for message passing,
    respectively.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, the model
            outputs node embeddings of size :obj:`out_channels`. With
            :obj:`jk=None`, the last attention layer maps directly to
            :obj:`out_channels` (built with :obj:`concat=False`). Otherwise,
            the final linear transformation applied after Jumping Knowledge
            maps to :obj:`out_channels`. (default: :obj:`None`)
        v2 (bool, optional): If set to :obj:`True`, will make use of
            :class:`~torch_geometric.nn.conv.GATv2Conv` rather than
            :class:`~torch_geometric.nn.conv.GATConv`. (default: :obj:`False`)
        dropout (float, optional): Dropout probability. It is also passed as
            the :obj:`dropout` argument of every attention layer.
            (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GATConv` or
            :class:`torch_geometric.nn.conv.GATv2Conv`.
    """
    # `BasicGNN.forward` hands only `edge_attr` (never `edge_weight`) to the
    # attention layers.
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = True
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        r"""Builds one :class:`GATConv` (or :class:`GATv2Conv` if
        :obj:`v2=True`) layer whose total output width is
        :obj:`out_channels`.

        When heads are concatenated, each head gets
        :obj:`out_channels // heads` channels.

        Raises:
            ValueError: If heads are concatenated and :obj:`out_channels` is
                not divisible by :obj:`heads`.
        """
        v2 = kwargs.pop('v2', False)
        heads = kwargs.pop('heads', 1)
        concat = kwargs.pop('concat', True)

        # Do not use concatenation in case the `GATConv` layer maps to the
        # desired output channels (out_channels != None and jk == None):
        if getattr(self, '_is_conv_to_out', False):
            concat = False

        if concat and out_channels % heads != 0:
            raise ValueError(f"Ensure that the number of output channels of "
                             f"'GATConv' (got '{out_channels}') is divisible "
                             f"by the number of heads (got '{heads}')")

        # Split the requested width evenly across the concatenated heads.
        if concat:
            out_channels = out_channels // heads

        Conv = GATConv if not v2 else GATv2Conv
        return Conv(in_channels, out_channels, heads=heads, concat=concat,
                    dropout=self.dropout.p, **kwargs)
class PNA(BasicGNN):
    r"""The Graph Neural Network from the `"Principal Neighbourhood
    Aggregation for Graph Nets" <https://arxiv.org/abs/2004.05718>`_ paper,
    using the :class:`~torch_geometric.nn.conv.PNAConv` operator for message
    passing.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.PNAConv`.
    """
    # `BasicGNN.forward` hands only `edge_attr` (never `edge_weight`) to the
    # PNAConv layers.
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = True
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        """Builds one PNAConv layer from ``in_channels`` to ``out_channels``."""
        layer = PNAConv(in_channels, out_channels, **kwargs)
        return layer
class EdgeCNN(BasicGNN):
    r"""The Graph Neural Network from the `"Dynamic Graph CNN for Learning on
    Point Clouds" <https://arxiv.org/abs/1801.07829>`_ paper, using the
    :class:`~torch_geometric.nn.conv.EdgeConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        num_layers (int): Number of message passing layers.
        out_channels (int, optional): If not set to :obj:`None`, will apply a
            final linear transformation to convert hidden node embeddings to
            output size :obj:`out_channels`. (default: :obj:`None`)
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
            applied before normalization. (default: :obj:`False`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`None`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. If specified, the model
            will additionally apply a final linear transformation to transform
            node embeddings to the expected output feature dimensionality.
            (:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
            :obj:`"lstm"`). (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.EdgeConv`.
    """
    # `BasicGNN.forward` calls the EdgeConv layers with node features and
    # connectivity only.
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = False
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels: int, out_channels: int,
                  **kwargs) -> MessagePassing:
        """Wraps a two-layer MLP (2*in -> out -> out) into an EdgeConv layer.

        The input width is doubled, presumably because EdgeConv feeds the MLP
        two concatenated node-feature vectors per edge (see the EdgeConv
        docs). Note that this means a lazy ``in_channels=-1`` is not handled
        here.
        """
        widths = [2 * in_channels, out_channels, out_channels]
        inner = MLP(
            widths,
            act=self.act,
            act_first=self.act_first,
            norm=self.norm,
            norm_kwargs=self.norm_kwargs,
        )
        return EdgeConv(inner, **kwargs)
# Public GNN model classes exported by this module.
__all__ = [
    'GCN',
    'GraphSAGE',
    'GIN',
    'GAT',
    'PNA',
    'EdgeCNN',
]