torch_geometric.nn

class Sequential(input_args: str, modules: List[Union[Tuple[Callable, str], Callable]])[source]

An extension of the torch.nn.Sequential container in order to define a sequential GNN model. Since GNN operators take in multiple input arguments, torch_geometric.nn.Sequential expects both global input arguments and function header definitions of individual operators. If a function header is omitted, an intermediate module will operate on the output of its preceding module:

from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv

model = Sequential('x, edge_index', [
    (GCNConv(in_channels, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, out_channels),
])

where 'x, edge_index' defines the input arguments of model, and 'x, edge_index -> x' defines the function header, i.e., the input arguments and return types, of GCNConv.

In particular, this also allows for creating more sophisticated models, such as ones utilizing JumpingKnowledge:

from torch.nn import Linear, ReLU, Dropout
from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
from torch_geometric.nn import global_mean_pool

model = Sequential('x, edge_index, batch', [
    (Dropout(p=0.5), 'x -> x'),
    (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    ReLU(inplace=True),
    (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
    (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
    (global_mean_pool, 'x, batch -> x'),
    Linear(2 * 64, dataset.num_classes),
])
Parameters
  • input_args (str) – The input arguments of the model.

  • modules ([(Callable, str) or Callable]) – A list of modules (with optional function header definitions). Alternatively, an OrderedDict of modules (and function header definitions) can be passed (see the sketch below).
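
For example, a minimal sketch of the OrderedDict variant mentioned above (the channel sizes and module names are illustrative, and the exact accepted structure is an assumption based on this description):

from collections import OrderedDict
from torch.nn import ReLU
from torch_geometric.nn import Sequential, GCNConv

# Hypothetical sketch: keys become module names, values are either a module
# or a (module, function header) tuple, mirroring the list-based API above.
model = Sequential('x, edge_index', OrderedDict([
    ('conv1', (GCNConv(16, 64), 'x, edge_index -> x')),
    ('act1', ReLU(inplace=True)),
    ('conv2', (GCNConv(64, 7), 'x, edge_index -> x')),
]))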

class Linear(in_channels: int, out_channels: int, bias: bool = True, weight_initializer: Optional[str] = None, bias_initializer: Optional[str] = None)[source]

Applies a linear transformation to the incoming data

\[\mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}\]

similar to torch.nn.Linear. It supports lazy initialization and customizable weight and bias initialization.

Parameters
  • in_channels (int) – Size of each input sample. Will be initialized lazily in case it is given as -1.

  • out_channels (int) – Size of each output sample.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • weight_initializer (str, optional) – The initializer for the weight matrix ("glorot", "uniform", "kaiming_uniform" or None). If set to None, will match default weight initialization of torch.nn.Linear. (default: None)

  • bias_initializer (str, optional) – The initializer for the bias vector ("zeros" or None). If set to None, will match default bias initialization of torch.nn.Linear. (default: None)

Shapes:
  • input: features \((*, F_{in})\)

  • output: features \((*, F_{out})\)

reset_parameters()[source]
forward(x: Tensor) Tensor[source]
Parameters

x (Tensor) – The features.
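
A minimal usage sketch (the feature sizes below are illustrative), showing lazy initialization via in_channels=-1:

import torch
from torch_geometric.nn import Linear

# Lazy initialization: in_channels=-1 defers weight allocation until the
# first forward pass, when the input feature size becomes known.
lin = Linear(-1, 32, weight_initializer='glorot')
x = torch.randn(10, 16)   # 10 samples with 16 features each
out = lin(x)              # weight is materialized with shape [32, 16] here
print(out.shape)          # torch.Size([10, 32])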

class HeteroLinear(in_channels: int, out_channels: int, num_types: int, is_sorted: bool = False, **kwargs)[source]

Applies separate linear transformations to the incoming data according to types

\[\mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}\]

for type \(\kappa\). It supports lazy initialization and customizable weight and bias initialization.

Parameters
  • in_channels (int) – Size of each input sample. Will be initialized lazily in case it is given as -1.

  • out_channels (int) – Size of each output sample.

  • num_types (int) – The number of types.

  • is_sorted (bool, optional) – If set to True, assumes that type_vec is sorted. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.Linear.

Shapes:
  • input: features \((*, F_{in})\), type vector \((*)\)

  • output: features \((*, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, type_vec: Tensor) Tensor[source]
Parameters
  • x (Tensor) – The input features.

  • type_vec (LongTensor) – A vector that maps each entry to a type.
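
A minimal usage sketch (the sizes and number of types are illustrative):

import torch
from torch_geometric.nn import HeteroLinear

# Three node types, each transformed by its own weight matrix and bias.
lin = HeteroLinear(in_channels=16, out_channels=32, num_types=3)
x = torch.randn(10, 16)                 # features of 10 entities
type_vec = torch.randint(0, 3, (10,))   # type of each entity
out = lin(x, type_vec)                  # shape [10, 32]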

Convolutional Layers

MessagePassing

Base class for creating message passing layers of the form

GCNConv

The graph convolutional operator from the "Semi-supervised Classification with Graph Convolutional Networks" paper

ChebConv

The Chebyshev spectral graph convolutional operator from the "Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering" paper

SAGEConv

The GraphSAGE operator from the "Inductive Representation Learning on Large Graphs" paper

GraphConv

The graph neural network operator from the "Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks" paper

GravNetConv

The GravNet operator from the "Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks" paper, where the graph is dynamically constructed using nearest neighbors.

GatedGraphConv

The gated graph convolution operator from the "Gated Graph Sequence Neural Networks" paper

ResGatedGraphConv

The residual gated graph convolutional operator from the "Residual Gated Graph ConvNets" paper

GATConv

The graph attentional operator from the "Graph Attention Networks" paper

FusedGATConv

The fused graph attention operator from the "Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective" paper.

GATv2Conv

The GATv2 operator from the "How Attentive are Graph Attention Networks?" paper, which fixes the static attention problem of the standard GATConv layer.

TransformerConv

The graph transformer operator from the "Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification" paper

AGNNConv

The graph attentional propagation layer from the "Attention-based Graph Neural Network for Semi-Supervised Learning" paper

TAGConv

The topology adaptive graph convolutional networks operator from the "Topology Adaptive Graph Convolutional Networks" paper

GINConv

The graph isomorphism operator from the "How Powerful are Graph Neural Networks?" paper

GINEConv

The modified GINConv operator from the "Strategies for Pre-training Graph Neural Networks" paper

ARMAConv

The ARMA graph convolutional operator from the "Graph Neural Networks with Convolutional ARMA Filters" paper

SGConv

The simple graph convolutional operator from the "Simplifying Graph Convolutional Networks" paper

SSGConv

The simple spectral graph convolutional operator from the "Simple Spectral Graph Convolution" paper

APPNP

The approximate personalized propagation of neural predictions layer from the "Predict then Propagate: Graph Neural Networks meet Personalized PageRank" paper

MFConv

The graph neural network operator from the "Convolutional Networks on Graphs for Learning Molecular Fingerprints" paper

RGCNConv

The relational graph convolutional operator from the "Modeling Relational Data with Graph Convolutional Networks" paper

FastRGCNConv

See RGCNConv.

RGATConv

The relational graph attentional operator from the "Relational Graph Attention Networks" paper.

SignedConv

The signed graph convolutional operator from the "Signed Graph Convolutional Network" paper

DNAConv

The dynamic neighborhood aggregation operator from the "Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks" paper

PointNetConv

The PointNet set layer from the "PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation" and "PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" papers

PointConv

alias of PointNetConv

GMMConv

The Gaussian mixture model convolutional operator from the "Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs" paper

SplineConv

The spline-based convolutional operator from the "SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels" paper

NNConv

The continuous kernel-based convolutional operator from the "Neural Message Passing for Quantum Chemistry" paper.

ECConv

alias of NNConv

CGConv

The crystal graph convolutional operator from the "Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties" paper

EdgeConv

The edge convolutional operator from the "Dynamic Graph CNN for Learning on Point Clouds" paper

DynamicEdgeConv

The dynamic edge convolutional operator from the "Dynamic Graph CNN for Learning on Point Clouds" paper (see torch_geometric.nn.conv.EdgeConv), where the graph is dynamically constructed using nearest neighbors in the feature space.

XConv

The convolutional operator on \(\mathcal{X}\)-transformed points from the "PointCNN: Convolution On X-Transformed Points" paper

PPFConv

The PPFNet operator from the "PPFNet: Global Context Aware Local Features for Robust 3D Point Matching" paper

FeaStConv

The (translation-invariant) feature-steered convolutional operator from the "FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis" paper

PointTransformerConv

The Point Transformer layer from the "Point Transformer" paper

HypergraphConv

The hypergraph convolutional operator from the "Hypergraph Convolution and Hypergraph Attention" paper

LEConv

The local extremum graph neural network operator from the "ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations" paper, which finds the importance of nodes with respect to their neighbors using the difference operator:

PNAConv

The Principal Neighbourhood Aggregation graph convolution operator from the "Principal Neighbourhood Aggregation for Graph Nets" paper

ClusterGCNConv

The ClusterGCN graph convolutional operator from the "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" paper

GENConv

The GENeralized Graph Convolution (GENConv) from the "DeeperGCN: All You Need to Train Deeper GCNs" paper.

GCN2Conv

The graph convolutional operator with initial residual connections and identity mapping (GCNII) from the "Simple and Deep Graph Convolutional Networks" paper

PANConv

The path integral based convolutional operator from the "Path Integral Based Convolution and Pooling for Graph Neural Networks" paper

WLConv

The Weisfeiler Lehman operator from the "A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction" paper, which iteratively refines node colorings:

WLConvContinuous

The Weisfeiler Lehman operator from the "Wasserstein Weisfeiler-Lehman Graph Kernels" paper.

FiLMConv

The FiLM graph convolutional operator from the "GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation" paper

SuperGATConv

The self-supervised graph attentional operator from the "How to Find Your Friendly Neighborhood: Graph Attention Design with Self-Supervision" paper

FAConv

The Frequency Adaptive Graph Convolution operator from the "Beyond Low-Frequency Information in Graph Convolutional Networks" paper

EGConv

The Efficient Graph Convolution from the "Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions" paper.

PDNConv

The pathfinder discovery network convolutional operator from the "Pathfinder Discovery Networks for Neural Message Passing" paper

GeneralConv

A general GNN layer adapted from the "Design Space for Graph Neural Networks" paper.

HGTConv

The Heterogeneous Graph Transformer (HGT) operator from the "Heterogeneous Graph Transformer" paper.

HEATConv

The heterogeneous edge-enhanced graph attentional operator from the "Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction" paper, which enhances GATConv by:

HeteroConv

A generic wrapper for computing graph convolution on heterogeneous graphs.

HANConv

The heterogeneous graph attention operator from the "Heterogeneous Graph Attention Network" paper.

LGConv

The Light Graph Convolution (LGC) operator from the "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation" paper

class MessagePassing(aggr: Optional[Union[str, List[str], Aggregation]] = 'add', *, aggr_kwargs: Optional[Dict[str, Any]] = None, flow: str = 'source_to_target', node_dim: int = -2, decomposed_layers: int = 1, **kwargs)[source]

Base class for creating message passing layers of the form

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]

where \(\square\) denotes a differentiable, permutation invariant function, e.g., sum, mean, min, max or mul, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs. See here for the accompanying tutorial.

Parameters
  • aggr (string or list or Aggregation, optional) – The aggregation scheme to use, e.g., "add", "sum", "mean", "min", "max" or "mul". In addition, can be any Aggregation module (or any string that automatically resolves to it). If given as a list, will make use of multiple aggregations in which different outputs will get concatenated in the last dimension. If set to None, the MessagePassing instantiation is expected to implement its own aggregation logic via aggregate(). (default: "add")

  • aggr_kwargs (Dict[str, Any], optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)

  • flow (string, optional) – The flow direction of message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

  • node_dim (int, optional) – The axis along which to propagate. (default: -2)

  • decomposed_layers (int, optional) – The number of feature decomposition layers, as introduced in the “Optimizing Memory Efficiency of Graph Neural Networks on Edge Computing Platforms” paper. Feature decomposition reduces the peak memory usage by slicing the feature dimensions into separated feature decomposition layers during GNN aggregation. This method can accelerate GNN execution on CPU-based platforms (e.g., 2-3x speedup on the Reddit dataset) for common GNN models such as GCN, GraphSAGE, GIN, etc. However, this method is not applicable to all GNN operators available, in particular for operators in which message computation can not easily be decomposed, e.g. in attention-based GNNs. The selection of the optimal value of decomposed_layers depends both on the specific graph dataset and available hardware resources. A value of 2 is suitable in most cases. Although the peak memory usage is directly associated with the granularity of feature decomposition, the same is not necessarily true for execution speedups. (default: 1)

propagate(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs)[source]

The initial call to start propagating messages.

Parameters
  • edge_index (Tensor or SparseTensor) – A torch.LongTensor, a torch_sparse.SparseTensor or a torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. edge_index holds the indices of a general (sparse) assignment matrix of shape [N, M]. If edge_index is of type torch.LongTensor, its shape must be defined as [2, num_messages], where messages from nodes in edge_index[0] are sent to nodes in edge_index[1] (in case flow="source_to_target"). If edge_index is of type torch_sparse.SparseTensor or torch.sparse.Tensor, its sparse indices (row, col) should relate to row = edge_index[1] and col = edge_index[0]. The major difference between both formats is that we need to input the transposed sparse adjacency matrix into propagate().

  • size (tuple, optional) – The size (N, M) of the assignment matrix in case edge_index is a LongTensor. If set to None, the size will be automatically inferred and assumed to be quadratic. This argument is ignored in case edge_index is a torch_sparse.SparseTensor or a torch.sparse.Tensor. (default: None)

  • **kwargs – Any additional data which is needed to construct and aggregate messages, and to update node embeddings.

edge_updater(edge_index: Union[Tensor, SparseTensor], **kwargs)[source]

The initial call to compute or update features for each edge in the graph.

Parameters
  • edge_index (Tensor or SparseTensor) – A torch.LongTensor, a torch_sparse.SparseTensor or a torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. See propagate() for more information.

  • **kwargs – Any additional data which is needed to compute or update features for each edge in the graph.

message(x_j: Tensor) Tensor[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.

aggregate(inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None) Tensor[source]

Aggregates messages from neighbors as \(\square_{j \in \mathcal{N}(i)}\).

Takes in the output of message computation as first argument and any argument which was initially passed to propagate().

By default, this function will delegate its call to the underlying Aggregation module to reduce messages as specified in __init__() by the aggr argument.

message_and_aggregate(adj_t: Union[SparseTensor, Tensor]) Tensor[source]

Fuses computations of message() and aggregate() into a single function. If applicable, this saves both time and memory since messages do not explicitly need to be materialized. This function will only get called in case it is implemented and propagation takes place based on a torch_sparse.SparseTensor or a torch.sparse.Tensor.

update(inputs: Tensor) Tensor[source]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().

edge_update() Tensor[source]

Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
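
A minimal sketch of a custom layer built on this interface (the layer itself is hypothetical): messages are the linearly transformed source node features, aggregated by their mean:

import torch
from torch import Tensor
from torch.nn import Linear
from torch_geometric.nn import MessagePassing

class MeanConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='mean')   # use mean aggregation in aggregate()
        self.lin = Linear(in_channels, out_channels)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # x: [num_nodes, in_channels], edge_index: [2, num_edges]
        return self.propagate(edge_index, x=self.lin(x))

    def message(self, x_j: Tensor) -> Tensor:
        # x_j holds the transformed features of the source node of each edge.
        return x_j

conv = MeanConv(16, 32)
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 2]])
out = conv(x, edge_index)   # shape [4, 32]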

register_propagate_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before propagate() is invoked. It should have the following signature:

hook(module, inputs) -> None or modified input

The hook can modify the input. Input keyword arguments are passed to the hook as a dictionary in inputs[-1].

Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().
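
A small sketch of registering such a pre-hook on an existing layer (the hook body is illustrative and only inspects the inputs):

from torch_geometric.nn import GCNConv

def pre_hook(module, inputs):
    # Per the description above, inputs[-1] is a dictionary holding the
    # keyword arguments passed to propagate().
    print('propagate() kwargs:', list(inputs[-1].keys()))

conv = GCNConv(16, 32)
handle = conv.register_propagate_forward_pre_hook(pre_hook)
# ... run conv(x, edge_index) as usual; the hook fires before propagate().
handle.remove()   # detach the hook when it is no longer needed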

register_propagate_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after propagate() has computed an output. It should have the following signature:

hook(module, inputs, output) -> None or modified output

The hook can modify the output. Input keyword arguments are passed to the hook as a dictionary in inputs[-1].

Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().

register_message_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before message() is invoked. See register_propagate_forward_pre_hook() for more information.

register_message_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after message() has computed an output. See register_propagate_forward_hook() for more information.

register_aggregate_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before aggregate() is invoked. See register_propagate_forward_pre_hook() for more information.

register_aggregate_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after aggregate() has computed an output. See register_propagate_forward_hook() for more information.

register_message_and_aggregate_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before message_and_aggregate() is invoked. See register_propagate_forward_pre_hook() for more information.

register_message_and_aggregate_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after message_and_aggregate() has computed an output. See register_propagate_forward_hook() for more information.

register_edge_update_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before edge_update() is invoked. See register_propagate_forward_pre_hook() for more information.

register_edge_update_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after edge_update() has computed an output. See register_propagate_forward_hook() for more information.

jittable(typing: Optional[str] = None) MessagePassing[source]

Analyzes the MessagePassing instance and produces a new jittable module.

Parameters

typing (string, optional) – If given, will generate a concrete instance with forward() types based on typing, e.g.: "(Tensor, Optional[Tensor]) -> Tensor".

class GCNConv(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)[source]

The graph convolutional operator from the “Semi-supervised Classification with Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},\]

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Its node-wise formulation is given by:

\[\mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j\]

with \(\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}\), where \(e_{j,i}\) denotes the edge weight from source node j to target node i (default: 1.0)

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • normalize (bool, optional) – Whether to add self-loops and compute symmetric normalization coefficients on the fly. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
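
A minimal usage sketch (the graph and feature sizes are illustrative):

import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(in_channels=16, out_channels=32)
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_weight = torch.rand(edge_index.size(1))   # optional per-edge weights
out = conv(x, edge_index, edge_weight)         # shape [4, 32]
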
class ChebConv(in_channels: int, out_channels: int, K: int, normalization: Optional[str] = 'sym', bias: bool = True, **kwargs)[source]

The Chebyshev spectral graph convolutional operator from the “Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering” paper

\[\mathbf{X}^{\prime} = \sum_{k=1}^{K} \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)}\]

where \(\mathbf{Z}^{(k)}\) is computed recursively by

\[ \begin{align}\begin{aligned}\mathbf{Z}^{(1)} &= \mathbf{X}\\\mathbf{Z}^{(2)} &= \mathbf{\hat{L}} \cdot \mathbf{X}\\\mathbf{Z}^{(k)} &= 2 \cdot \mathbf{\hat{L}} \cdot \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)}\end{aligned}\end{align} \]

and \(\mathbf{\hat{L}}\) denotes the scaled and normalized Laplacian \(\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}\).

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • K (int) – Chebyshev filter size \(K\).

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: "sym"):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional), batch vector \((|\mathcal{V}|)\) (optional), maximum lambda value \((|\mathcal{G}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, batch: Optional[Tensor] = None, lambda_max: Optional[Tensor] = None) Tensor[source]
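
A minimal usage sketch (the graph and feature sizes are illustrative):

import torch
from torch_geometric.nn import ChebConv

# K=3 Chebyshev filter; lambda_max is omitted here and can alternatively be
# pre-computed via torch_geometric.transforms.LaplacianLambdaMax.
conv = ChebConv(in_channels=16, out_channels=32, K=3)
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out = conv(x, edge_index)   # shape [4, 32]
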
class SAGEConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: Optional[Union[str, List[str], Aggregation]] = 'mean', normalize: bool = False, root_weight: bool = True, project: bool = False, bias: bool = True, **kwargs)[source]

The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper

\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j\]

If project = True, then \(\mathbf{x}_j\) will first get projected via

\[\mathbf{x}_j \leftarrow \sigma ( \mathbf{W}_3 \mathbf{x}_j + \mathbf{b})\]

as described in Eq. (3) of the paper.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • aggr (string or Aggregation, optional) – The aggregation scheme to use. Any aggregation of torch_geometric.nn.aggr can be used, e.g., "mean", "max", or "lstm". (default: "mean")

  • normalize (bool, optional) – If set to True, output features will be \(\ell_2\)-normalized, i.e., \(\frac{\mathbf{x}^{\prime}_i} {\| \mathbf{x}^{\prime}_i \|_2}\). (default: False)

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • project (bool, optional) – If set to True, the layer will apply a linear transformation followed by an activation function before aggregation (as described in Eq. (3) of the paper). (default: False)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • inputs: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • outputs: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V_t}|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None) Tensor[source]
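
A minimal usage sketch for the bipartite case (the sizes are illustrative):

import torch
from torch_geometric.nn import SAGEConv

# Separate source and target feature dimensionalities via a tuple.
conv = SAGEConv(in_channels=(16, 8), out_channels=32)
x_src = torch.randn(10, 16)   # source node features
x_dst = torch.randn(4, 8)     # target node features
edge_index = torch.tensor([[0, 2, 5, 9],    # source node indices
                           [0, 1, 2, 3]])   # target node indices
out = conv((x_src, x_dst), edge_index)      # shape [4, 32]
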
class GraphConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]

The graph neural network operator from the “Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks” paper

\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j\]

where \(e_{j,i}\) denotes the edge weight from source node j to target node i (default: 1)

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None) Tensor[source]
class GravNetConv(in_channels: int, out_channels: int, space_dimensions: int, propagate_dimensions: int, k: int, num_workers: Optional[int] = None, **kwargs)[source]

The GravNet operator from the “Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks” paper, where the graph is dynamically constructed using nearest neighbors. The neighbors are constructed in a learnable low-dimensional projection of the feature space. A second projection of the input feature space is then propagated from the neighbors to each vertex using distance weights that are derived by applying a Gaussian function to the distances.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – The number of output channels.

  • space_dimensions (int) – The dimensionality of the space used to construct the neighbors; referred to as \(S\) in the paper.

  • propagate_dimensions (int) – The number of features to be propagated between the vertices; referred to as \(F_{\textrm{LR}}\) in the paper.

  • k (int) – The number of nearest neighbors.

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))\) if bipartite, batch vector \((|\mathcal{V}|)\) or \(((|\mathcal{V}_s|), (|\mathcal{V}_t|))\) if bipartite (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], batch: Union[Tensor, None, Tuple[Tensor, Tensor]] = None) Tensor[source]
class GatedGraphConv(out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]

The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper

\[ \begin{align}\begin{aligned}\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}\\\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}\\\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)})\end{aligned}\end{align} \]

up to representation \(\mathbf{h}_i^{(L)}\). The number of input channels of \(\mathbf{x}_i\) needs to be less than or equal to out_channels. \(e_{j,i}\) denotes the edge weight from source node j to target node i (default: 1)

Parameters
  • out_channels (int) – Size of each output sample.

  • num_layers (int) – The sequence length \(L\).

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
class ResGatedGraphConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, act: Optional[Callable] = Sigmoid(), root_weight: bool = True, bias: bool = True, **kwargs)[source]

The residual gated graph convolutional operator from the “Residual Gated Graph ConvNets” paper

\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \eta_{i,j} \odot \mathbf{W}_2 \mathbf{x}_j\]

where the gate \(\eta_{i,j}\) is defined as

\[\eta_{i,j} = \sigma(\mathbf{W}_3 \mathbf{x}_i + \mathbf{W}_4 \mathbf{x}_j)\]

with \(\sigma\) denoting the sigmoid function.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • act (callable, optional) – Gating function \(\sigma\). (default: torch.nn.Sigmoid())

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • inputs: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • outputs: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V_t}|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor]) Tensor[source]
class GATConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, edge_dim: Optional[int] = None, fill_value: Union[float, Tensor, str] = 'mean', bias: bool = True, **kwargs)[source]

The graph attentional operator from the “Graph Attention Networks” paper

\[\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},\]

where the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] \right)\right)}.\]

If the graph has multi-dimensional edge features \(\mathbf{e}_{i,j}\), the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,j}]\right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,k}]\right)\right)}.\]
Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

  • fill_value (float or Tensor or str, optional) – The way to generate edge features of self-loops (in case edge_dim != None). If given as float or torch.Tensor, edge features of self-loops will be directly given by fill_value. If given as str, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation. ("add", "mean", "min", "max", "mul"). (default: "mean")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, H * F_{out})\) or \((|\mathcal{V}_t|, H * F_{out})\) if bipartite. If return_attention_weights=True, then \(((|\mathcal{V}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))\) or \(((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None, return_attention_weights=None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)

edge_update(alpha_j: Tensor, alpha_i: Optional[Tensor], edge_attr: Optional[Tensor], index: Tensor, ptr: Optional[Tensor], size_i: Optional[int]) Tensor[source]

Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
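
A minimal usage sketch that also retrieves the attention weights (the sizes are illustrative):

import torch
from torch_geometric.nn import GATConv

conv = GATConv(in_channels=16, out_channels=8, heads=4, dropout=0.2)
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out, (att_edge_index, alpha) = conv(x, edge_index,
                                    return_attention_weights=True)
# out: [4, 4 * 8]; alpha holds one attention coefficient per edge
# (including added self-loops) and per head.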

class FusedGATConv(*args, **kwargs)[source]

The fused graph attention operator from the “Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective” paper.

FusedGATConv is an optimized version of GATConv that fuses message passing computation for accelerated execution and a lower memory footprint.

Note

This implementation is based on the dgNN package. See here for instructions on how to install.

static to_graph_format(edge_index: Tensor, size: Optional[Tuple[int, int]] = None) Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor][source]
Parameters
  • edge_index (Tensor) – The edge indices.

  • size ((int, int), optional) – The size (N, M) of the assignment matrix defined by edge_index. (default: None)

forward(x: Tensor, csr: Tuple[Tensor, Tensor], csc: Tuple[Tensor, Tensor], perm: Tensor) Tensor[source]
Parameters
  • x (Tensor) – The node features.

  • csr ((Tensor, Tensor)) – A tuple containing the CSR representation of the graph, given as (rowptr, col).

  • csc ((Tensor, Tensor)) – A tuple containing the CSC representation of the graph, given as (row, colptr).

  • perm (Tensor) – Permutation tensor to map the CSR representation to the CSC representation.

Note

Use the torch_geometric.nn.conv.FusedGATConv.to_graph_format() method to obtain the (csr, csc, perm) graph format from an existing edge_index representation.

class GATv2Conv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, edge_dim: Optional[int] = None, fill_value: Union[float, Tensor, str] = 'mean', bias: bool = True, share_weights: bool = False, **kwargs)[source]

The GATv2 operator from the “How Attentive are Graph Attention Networks?” paper, which fixes the static attention problem of the standard GATConv layer. Since the linear layers in the standard GAT are applied right after each other, the ranking of attended nodes is unconditioned on the query node. In contrast, in GATv2, every node can attend to any other node.

\[\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},\]

where the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_k] \right)\right)}.\]

If the graph has multi-dimensional edge features \(\mathbf{e}_{i,j}\), the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_j \, \Vert \, \mathbf{e}_{i,j}] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_k \, \Vert \, \mathbf{e}_{i,k}] \right)\right)}.\]
Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

  • fill_value (float or Tensor or str, optional) – The way to generate edge features of self-loops (in case edge_dim != None). If given as float or torch.Tensor, edge features of self-loops will be directly given by fill_value. If given as str, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation. ("add", "mean", "min", "max", "mul"). (default: "mean")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • share_weights (bool, optional) – If set to True, the same matrix will be applied to the source and the target node of every edge. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, H * F_{out})\) or \((|\mathcal{V}_t|, H * F_{out})\) if bipartite. If return_attention_weights=True, then \(((|\mathcal{V}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))\) or \(((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, return_attention_weights: Optional[bool] = None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)

class TransformerConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, beta: bool = False, dropout: float = 0.0, edge_dim: Optional[int] = None, bias: bool = True, root_weight: bool = True, **kwargs)[source]

The graph transformer operator from the “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification” paper

\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},\]

where the attention coefficients \(\alpha_{i,j}\) are computed via multi-head dot product attention:

\[\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)} {\sqrt{d}} \right)\]
Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • beta (bool, optional) –

    If set, will combine aggregation and skip information via

    \[\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}\]

    with \(\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top} [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i - \mathbf{m}_i ])\) (default: False)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • edge_dim (int, optional) –

    Edge feature dimensionality (in case there are any). Edge features are added to the keys after linear transformation, that is, prior to computing the attention dot product. They are also added to final values after the same linear transformation. The model is:

    \[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left( \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij} \right),\]

    where the attention coefficients \(\alpha_{i,j}\) are now computed via:

    \[\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})} {\sqrt{d}} \right)\]

    (default: None)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • root_weight (bool, optional) – If set to False, the layer will not add the transformed root node features to the output and the option beta is set to False. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, return_attention_weights=None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)

class AGNNConv(requires_grad: bool = True, add_self_loops: bool = True, **kwargs)[source]

The graph attentional propagation layer from the “Attention-based Graph Neural Network for Semi-Supervised Learning” paper

\[\mathbf{X}^{\prime} = \mathbf{P} \mathbf{X},\]

where the propagation matrix \(\mathbf{P}\) is computed as

\[P_{i,j} = \frac{\exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_j))} {\sum_{k \in \mathcal{N}(i)\cup \{ i \}} \exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_k))}\]

with trainable parameter \(\beta\).

Parameters
  • requires_grad (bool, optional) – If set to False, \(\beta\) will not be trainable. (default: True)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F)\), edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F)\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) Tensor[source]
class TAGConv(in_channels: int, out_channels: int, K: int = 3, bias: bool = True, normalize: bool = True, **kwargs)[source]

The topology adaptive graph convolutional networks operator from the “Topology Adaptive Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = \sum_{k=0}^K \left( \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \right)^k \mathbf{X} \mathbf{W}_{k},\]

where \(\mathbf{A}\) denotes the adjacency matrix and \(D_{ii} = \sum_{j=0} A_{ij}\) its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • K (int, optional) – Number of hops \(K\). (default: 3)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • normalize (bool, optional) – Whether to apply symmetric normalization. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
class GINConv(nn: Callable, eps: float = 0.0, train_eps: bool = False, **kwargs)[source]

The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper

\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)\]

or

\[\mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP.

Parameters
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • eps (float, optional) – (Initial) \(\epsilon\)-value. (default: 0.)

  • train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None) Tensor[source]
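
A minimal usage sketch, with a two-layer MLP as \(h_{\mathbf{\Theta}}\) (the sizes are illustrative):

import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv

mlp = Sequential(Linear(16, 32), ReLU(), Linear(32, 32))
conv = GINConv(mlp, train_eps=True)   # epsilon becomes a trainable parameter
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out = conv(x, edge_index)             # shape [4, 32]
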
class GINEConv(nn: Module, eps: float = 0.0, train_eps: bool = False, edge_dim: Optional[int] = None, **kwargs)[source]

The modified GINConv operator from the “Strategies for Pre-training Graph Neural Networks” paper

\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU} ( \mathbf{x}_j + \mathbf{e}_{j,i} ) \right)\]

that is able to incorporate edge features \(\mathbf{e}_{j,i}\) into the aggregation procedure.

Parameters
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • eps (float, optional) – (Initial) \(\epsilon\)-value. (default: 0.)

  • train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)

  • edge_dim (int, optional) – Edge feature dimensionality. If set to None, node and edge feature dimensionality is expected to match. Otherwise, edge features are linearly transformed to match node feature dimensionality. (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None) Tensor[source]
class ARMAConv(in_channels: int, out_channels: int, num_stacks: int = 1, num_layers: int = 1, shared_weights: bool = False, act: Optional[Callable] = ReLU(), dropout: float = 0.0, bias: bool = True, **kwargs)[source]

The ARMA graph convolutional operator from the “Graph Neural Networks with Convolutional ARMA Filters” paper

\[\mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K \mathbf{X}_k^{(T)},\]

with \(\mathbf{X}_k^{(T)}\) being recursively defined by

\[\mathbf{X}_k^{(t+1)} = \sigma \left( \mathbf{\hat{L}} \mathbf{X}_k^{(t)} \mathbf{W} + \mathbf{X}^{(0)} \mathbf{V} \right),\]

where \(\mathbf{\hat{L}} = \mathbf{I} - \mathbf{L} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\) denotes the modified Laplacian, with \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\) the symmetrically normalized graph Laplacian.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample \(\mathbf{x}^{(t+1)}\).

  • num_stacks (int, optional) – Number of parallel stacks \(K\). (default: 1).

  • num_layers (int, optional) – Number of layers \(T\). (default: 1)

  • act (callable, optional) – Activation function \(\sigma\). (default: torch.nn.ReLU())

  • shared_weights (bool, optional) – If set to True, the layers in each stack will share the same parameters. (default: False)

  • dropout (float, optional) – Dropout probability of the skip connection. (default: 0.)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
class SGConv(in_channels: int, out_channels: int, K: int = 1, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The simple graph convolutional operator from the “Simplifying Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta},\]

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • K (int, optional) – Number of hops \(K\). (default: 1)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \({\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
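
For illustration, a minimal usage sketch with random toy data (cached=True assumes a transductive setting):

import torch
from torch_geometric.nn import SGConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])

conv = SGConv(16, 7, K=2, cached=True)  # cache the K-hop propagated features
out = conv(x, edge_index)               # shape [4, 7]
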
class SSGConv(in_channels: int, out_channels: int, alpha: float, K: int = 1, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The simple spectral graph convolutional operator from the “Simple Spectral Graph Convolution” paper

\[\mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K\left((1-\alpha) {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^k \mathbf{X}+\alpha \mathbf{X}\right) \mathbf{\Theta},\]

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor. SSGConv improves upon SGConv by introducing the alpha parameter, which mitigates the oversmoothing issue.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • alpha (float) – Teleport probability \(\alpha \in [0, 1]\).

  • K (int, optional) – Number of hops \(K\). (default: 1)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\frac{1}{K} \sum_{k=1}^K\left((1-\alpha) {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^k \mathbf{X}+ \alpha \mathbf{X}\right)\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
class APPNP(K: int, alpha: float, dropout: float = 0.0, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs)[source]

The approximate personalized propagation of neural predictions layer from the “Predict then Propagate: Graph Neural Networks meet Personalized PageRank” paper

\[ \begin{align}\begin{aligned}\mathbf{X}^{(0)} &= \mathbf{X}\\\mathbf{X}^{(k)} &= (1 - \alpha) \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X}^{(k-1)} + \alpha \mathbf{X}^{(0)}\\\mathbf{X}^{\prime} &= \mathbf{X}^{(K)},\end{aligned}\end{align} \]

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Parameters
  • K (int) – Number of iterations \(K\).

  • alpha (float) – Teleport probability \(\alpha\).

  • dropout (float, optional) – Dropout probability of edges during training. (default: 0)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F)\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F)\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
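
For illustration, a minimal sketch of the predict-then-propagate pattern with random toy data: the prediction step is an ordinary linear layer, and APPNP only propagates its output:

import torch
from torch.nn import Linear
from torch_geometric.nn import APPNP

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])

lin = Linear(16, 7)              # "predict": a feature transformation
prop = APPNP(K=10, alpha=0.1)    # "propagate": parameter-free smoothing
out = prop(lin(x), edge_index)   # shape [4, 7]
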
class MFConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, max_degree: int = 10, bias=True, **kwargs)[source]

The graph neural network operator from the “Convolutional Networks on Graphs for Learning Molecular Fingerprints” paper

\[\mathbf{x}^{\prime}_i = \mathbf{W}^{(\deg(i))}_1 \mathbf{x}_i + \mathbf{W}^{(\deg(i))}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j\]

which trains a distinct weight matrix for each possible vertex degree.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • max_degree (int, optional) – The maximum node degree to consider when updating weights. (default: 10)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • inputs: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • outputs: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V_t}|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None) Tensor[source]
class RGCNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int, num_bases: Optional[int] = None, num_blocks: Optional[int] = None, aggr: str = 'mean', root_weight: bool = True, is_sorted: bool = False, bias: bool = True, **kwargs)[source]

The relational graph convolutional operator from the “Modeling Relational Data with Graph Convolutional Networks” paper

\[\mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,\]

where \(\mathcal{R}\) denotes the set of relations, i.e. edge types. Edge type needs to be a one-dimensional torch.long tensor which stores a relation identifier \(\in \{ 0, \ldots, |\mathcal{R}| - 1\}\) for each edge.

Note

This implementation is as memory-efficient as possible by iterating over each individual relation type. Therefore, it may result in low GPU utilization in case the graph has a large number of relations. As an alternative approach, FastRGCNConv does not iterate over each individual type, but may consume a large amount of memory to compensate. We advise to check out both implementations to see which one fits your needs.

Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities. In case no input features are given, this argument should correspond to the number of nodes in your graph.

  • out_channels (int) – Size of each output sample.

  • num_relations (int) – Number of relations.

  • num_bases (int, optional) – If set, this layer will use the basis-decomposition regularization scheme where num_bases denotes the number of bases to use. (default: None)

  • num_blocks (int, optional) – If set, this layer will use the block-diagonal-decomposition regularization scheme where num_blocks denotes the number of blocks to use. (default: None)

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "mean")

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • is_sorted (bool, optional) – If set to True, assumes that edge_index is sorted by edge_type. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: False)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x: Union[Tensor, None, Tuple[Optional[Tensor], Tensor]], edge_index: Union[Tensor, SparseTensor], edge_type: Optional[Tensor] = None)[source]
Parameters
  • x – The input node features. Can be either a [num_nodes, in_channels] node feature matrix, or an optional one-dimensional node index tensor (in which case input features are treated as trainable node embeddings). Furthermore, x can be of type tuple denoting source and destination node features.

  • edge_index (LongTensor or SparseTensor) – The edge indices.

  • edge_type – The one-dimensional relation type/index for each edge in edge_index. Should only be None in case edge_index is of type torch_sparse.tensor.SparseTensor. (default: None)
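
For illustration, a minimal usage sketch with random toy data and two relation types:

import torch
from torch_geometric.nn import RGCNConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
edge_type = torch.tensor([0, 0, 1, 1])   # relation identifier per edge

conv = RGCNConv(16, 32, num_relations=2, num_bases=2)
out = conv(x, edge_index, edge_type)     # shape [4, 32]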

class FastRGCNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int, num_bases: Optional[int] = None, num_blocks: Optional[int] = None, aggr: str = 'mean', root_weight: bool = True, is_sorted: bool = False, bias: bool = True, **kwargs)[source]

See RGCNConv.

forward(x: Union[Tensor, None, Tuple[Optional[Tensor], Tensor]], edge_index: Union[Tensor, SparseTensor], edge_type: Optional[Tensor] = None)[source]
class RGATConv(in_channels: int, out_channels: int, num_relations: int, num_bases: Optional[int] = None, num_blocks: Optional[int] = None, mod: Optional[str] = None, attention_mechanism: str = 'across-relation', attention_mode: str = 'additive-self-attention', heads: int = 1, dim: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, edge_dim: Optional[int] = None, bias: bool = True, **kwargs)[source]

The relational graph attentional operator from the “Relational Graph Attention Networks” paper. Here, attention logits \(\mathbf{a}^{(r)}_{i,j}\) are computed for each relation type \(r\) with the help of both query and key kernels, i.e.

\[\mathbf{q}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot \mathbf{Q}^{(r)} \quad \textrm{and} \quad \mathbf{k}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot \mathbf{K}^{(r)}.\]

Two schemes have been proposed to compute attention logits \(\mathbf{a}^{(r)}_{i,j}\) for each relation type \(r\):

Additive attention

\[\mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + \mathbf{k}^{(r)}_j)\]

or multiplicative attention

\[\mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j.\]

If the graph has multi-dimensional edge features \(\mathbf{e}^{(r)}_{i,j}\), the attention logits \(\mathbf{a}^{(r)}_{i,j}\) for each relation type \(r\) are computed as

\[\mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + \mathbf{k}^{(r)}_j + \mathbf{W}_2^{(r)}\mathbf{e}^{(r)}_{i,j})\]

or

\[\mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j \cdot \mathbf{W}_2^{(r)} \mathbf{e}^{(r)}_{i,j},\]

respectively. The attention coefficients \(\alpha^{(r)}_{i,j}\) for each relation type \(r\) are then obtained via two different attention mechanisms: The within-relation attention mechanism

\[\alpha^{(r)}_{i,j} = \frac{\exp(\mathbf{a}^{(r)}_{i,j})} {\sum_{k \in \mathcal{N}_r(i)} \exp(\mathbf{a}^{(r)}_{i,k})}\]

or the across-relation attention mechanism

\[\alpha^{(r)}_{i,j} = \frac{\exp(\mathbf{a}^{(r)}_{i,j})} {\sum_{r^{\prime} \in \mathcal{R}} \sum_{k \in \mathcal{N}_{r^{\prime}}(i)} \exp(\mathbf{a}^{(r^{\prime})}_{i,k})}\]

where \(\mathcal{R}\) denotes the set of relations, i.e. edge types. Edge type needs to be a one-dimensional torch.long tensor which stores a relation identifier \(\in \{ 0, \ldots, |\mathcal{R}| - 1\}\) for each edge.

To enhance the discriminative power of attention-based GNNs, this layer further implements four different cardinality preservation options as proposed in the “Improving Attention Mechanism in Graph Neural Networks via Cardinality Preservation” paper:

\[ \begin{align}\begin{aligned}\text{additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + \mathcal{W} \odot \sum_{j \in \mathcal{N}_r(i)} \mathbf{x}^{(r)}_j\\\text{scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= \psi(|\mathcal{N}_r(i)|) \odot \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j\\\text{f-additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= \sum_{j \in \mathcal{N}_r(i)} (\alpha^{(r)}_{i,j} + 1) \cdot \mathbf{x}^{(r)}_j\\\text{f-scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= |\mathcal{N}_r(i)| \odot \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j\end{aligned}\end{align} \]
  • If attention_mode="additive-self-attention" and concat=True, the layer outputs heads * out_channels features for each node.

  • If attention_mode="multiplicative-self-attention" and concat=True, the layer outputs heads * dim * out_channels features for each node.

  • If attention_mode="additive-self-attention" and concat=False, the layer outputs out_channels features for each node.

  • If attention_mode="multiplicative-self-attention" and concat=False, the layer outputs dim * out_channels features for each node.

Please make sure to set the in_channels argument of the next layer accordingly if more than one instance of this layer is used.

Note

For an example of using RGATConv, see examples/rgat.py.

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • num_relations (int) – Number of relations.

  • num_bases (int, optional) – If set, this layer will use the basis-decomposition regularization scheme where num_bases denotes the number of bases to use. (default: None)

  • num_blocks (int, optional) – If set, this layer will use the block-diagonal-decomposition regularization scheme where num_blocks denotes the number of blocks to use. (default: None)

  • mod (str, optional) – The cardinality preservation option to use. ("additive", "scaled", "f-additive", "f-scaled", None). (default: None)

  • attention_mechanism (str, optional) – The attention mechanism to use ("within-relation", "across-relation"). (default: "across-relation")

  • attention_mode (str, optional) – The mode to calculate attention logits. ("additive-self-attention", "multiplicative-self-attention"). (default: "additive-self-attention")

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • dim (int) – Number of dimensions for query and key kernels. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_type: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None, return_attention_weights=None)[source]
Parameters
  • x (Tensor) – The input node features. Can be either a [num_nodes, in_channels] node feature matrix, or an optional one-dimensional node index tensor (in which case input features are treated as trainable node embeddings).

  • edge_index (LongTensor or SparseTensor) – The edge indices.

  • edge_type – The one-dimensional relation type/index for each edge in edge_index. Should only be None in case edge_index is of type torch_sparse.tensor.SparseTensor. (default: None)

  • edge_attr (Tensor, optional) – Edge feature matrix. (default: None)

  • return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)

class SignedConv(in_channels: int, out_channels: int, first_aggr: bool, bias: bool = True, **kwargs)[source]

The signed graph convolutional operator from the “Signed Graph Convolutional Network” paper

\[ \begin{align}\begin{aligned}\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w , \mathbf{x}_v \right]\\\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})} \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w , \mathbf{x}_v \right]\end{aligned}\end{align} \]

if first_aggr is set to True, and

\[ \begin{align}\begin{aligned}\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})}, \mathbf{x}_v^{(\textrm{pos})} \right]\\\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})}, \mathbf{x}_v^{(\textrm{neg})} \right]\end{aligned}\end{align} \]

otherwise. In case first_aggr is False, the layer expects x to be a tensor where x[:, :in_channels] denotes the positive node features \(\mathbf{X}^{(\textrm{pos})}\) and x[:, in_channels:] denotes the negative node features \(\mathbf{X}^{(\textrm{neg})}\).

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • first_aggr (bool) – Denotes which aggregation formula to use.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))\) if bipartite, positive edge indices \((2, |\mathcal{E}^{(+)}|)\), negative edge indices \((2, |\mathcal{E}^{(-)}|)\)

  • outputs: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V_t}|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], pos_edge_index: Union[Tensor, SparseTensor], neg_edge_index: Union[Tensor, SparseTensor])[source]
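
For illustration, a minimal sketch of stacking two layers on random toy data; note that the output concatenates positive and negative representations, i.e. it has 2 * out_channels features per node:

import torch
from torch_geometric.nn import SignedConv

x = torch.randn(4, 16)
pos_edge_index = torch.tensor([[0, 1], [1, 0]])
neg_edge_index = torch.tensor([[2, 3], [3, 2]])

conv1 = SignedConv(16, 32, first_aggr=True)
h = conv1(x, pos_edge_index, neg_edge_index)     # shape [4, 2 * 32]
conv2 = SignedConv(32, 32, first_aggr=False)     # expects concatenated pos/neg features
out = conv2(h, pos_edge_index, neg_edge_index)   # shape [4, 2 * 32]
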
class DNAConv(channels: int, heads: int = 1, groups: int = 1, dropout: float = 0.0, cached: bool = False, normalize: bool = True, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The dynamic neighborhood aggregation operator from the “Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks” paper

\[\mathbf{x}_v^{(t)} = h_{\mathbf{\Theta}}^{(t)} \left( \mathbf{x}_{v \leftarrow v}^{(t)}, \left\{ \mathbf{x}_{v \leftarrow w}^{(t)} : w \in \mathcal{N}(v) \right\} \right)\]

based on (multi-head) dot-product attention

\[\mathbf{x}_{v \leftarrow w}^{(t)} = \textrm{Attention} \left( \mathbf{x}^{(t-1)}_v \, \mathbf{\Theta}_Q^{(t)}, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_K^{(t)}, \, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_V^{(t)} \right)\]

with \(\mathbf{\Theta}_Q^{(t)}, \mathbf{\Theta}_K^{(t)}, \mathbf{\Theta}_V^{(t)}\) denoting (grouped) projection matrices for query, key and value information, respectively. \(h^{(t)}_{\mathbf{\Theta}}\) is implemented as a non-trainable version of torch_geometric.nn.conv.GCNConv.

Note

In contrast to other layers, this operator expects node features as shape [num_nodes, num_layers, channels].

Parameters
  • channels (int) – Size of each input/output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • groups (int, optional) – Number of groups to use for all linear projections. (default: 1)

  • dropout (float, optional) – Dropout probability of attention coefficients. (default: 0.)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, L, F)\) where \(L\) is the number of layers, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F)\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
Parameters
  • x – The input node features of shape [num_nodes, num_layers, channels].
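
For illustration, a minimal usage sketch with random toy data, where each node carries the stacked representations of three previous layers:

import torch
from torch_geometric.nn import DNAConv

num_nodes, num_layers, channels = 4, 3, 16
x = torch.randn(num_nodes, num_layers, channels)  # stacked per-layer representations
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])

conv = DNAConv(channels, heads=4, groups=1, dropout=0.1)
out = conv(x, edge_index)                         # shape [4, 16]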

class PointNetConv(local_nn: Optional[Callable] = None, global_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs)[source]

The PointNet set layer from the “PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation” and “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” papers

\[\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \mathbf{p}_j - \mathbf{p}_i) \right),\]

where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote neural networks, i.e. MLPs, and \(\mathbf{P} \in \mathbb{R}^{N \times D}\) defines the position of each point.

Parameters
  • local_nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x and relative spatial coordinates pos_j - pos_i of shape [-1, in_channels + num_dimensions] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential. (default: None)

  • global_nn (torch.nn.Module, optional) – A neural network \(\gamma_{\mathbf{\Theta}}\) that maps aggregated node features of shape [-1, out_channels] to shape [-1, final_out_channels], e.g., defined by torch.nn.Sequential. (default: None)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, positions \((|\mathcal{V}|, 3)\) or \(((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, None, Tuple[Optional[Tensor], Optional[Tensor]]], pos: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor]) Tensor[source]
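
For illustration, a minimal usage sketch with random toy point features and 3D positions:

import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import PointNetConv

x = torch.randn(4, 16)                   # per-point features
pos = torch.randn(4, 3)                  # 3D point positions
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])

local_nn = Sequential(Linear(16 + 3, 32), ReLU())  # features + relative coordinates
global_nn = Sequential(Linear(32, 64), ReLU())
conv = PointNetConv(local_nn, global_nn)
out = conv(x, pos, edge_index)           # shape [4, 64]
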
PointConv

alias of PointNetConv

class GMMConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, dim: int, kernel_size: int, separate_gaussians: bool = False, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The gaussian mixture model convolutional operator from the “Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs” paper

\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \frac{1}{K} \sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{i,j}) \odot \mathbf{\Theta}_k \mathbf{x}_j,\]

where

\[\mathbf{w}_k(\mathbf{e}) = \exp \left( -\frac{1}{2} {\left( \mathbf{e} - \mathbf{\mu}_k \right)}^{\top} \Sigma_k^{-1} \left( \mathbf{e} - \mathbf{\mu}_k \right) \right)\]

denotes a weighting function based on trainable mean vector \(\mathbf{\mu}_k\) and diagonal covariance matrix \(\mathbf{\Sigma}_k\).

Note

The edge attribute \(\mathbf{e}_{ij}\) is usually given by \(\mathbf{e}_{ij} = \mathbf{p}_j - \mathbf{p}_i\), where \(\mathbf{p}_i\) denotes the position of node \(i\) (see torch_geometric.transforms.Cartesian).

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • dim (int) – Pseudo-coordinate dimensionality.

  • kernel_size (int) – Number of kernels \(K\).

  • separate_gaussians (bool, optional) – If set to True, will learn separate GMMs for every pair of input and output channel, inspired by traditional CNNs. (default: False)

  • aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "mean")

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None)[source]
class SplineConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, dim: int, kernel_size: Union[int, List[int]], is_open_spline: bool = True, degree: int = 1, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The spline-based convolutional operator from the “SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels” paper

\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),\]

where \(h_{\mathbf{\Theta}}\) denotes a kernel function defined over the weighted B-Spline tensor product basis.

Note

Pseudo-coordinates must lie in the fixed interval \([0, 1]\) for this method to work as intended.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • dim (int) – Pseudo-coordinate dimensionality.

  • kernel_size (int or [int]) – Size of the convolving kernel.

  • is_open_spline (bool or [bool], optional) – If set to False, the operator will use a closed B-spline basis in this dimension. (default: True)

  • degree (int, optional) – B-spline basis degrees. (default: 1)

  • aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "mean")

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None) Tensor[source]
class NNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, nn: Callable, aggr: str = 'add', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The continuous kernel-based convolutional operator from the “Neural Message Passing for Quantum Chemistry” paper. This convolution is also known as the edge-conditioned convolution from the “Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs” paper (see torch_geometric.nn.conv.ECConv for an alias):

\[\mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e. an MLP.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps edge features edge_attr of shape [-1, num_edge_features] to shape [-1, in_channels * out_channels], e.g., defined by torch.nn.Sequential.

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

  • root_weight (bool, optional) – If set to False, the layer will not add the transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None) Tensor[source]
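
For illustration, a minimal usage sketch with random toy data; the edge network maps each 8-dimensional edge feature to a 16 x 32 weight matrix:

import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import NNConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
edge_attr = torch.randn(edge_index.size(1), 8)

edge_nn = Sequential(Linear(8, 64), ReLU(), Linear(64, 16 * 32))
conv = NNConv(16, 32, edge_nn, aggr='mean')
out = conv(x, edge_index, edge_attr)      # shape [4, 32]
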
ECConv

alias of NNConv

class CGConv(channels: Union[int, Tuple[int, int]], dim: int = 0, aggr: str = 'add', batch_norm: bool = False, bias: bool = True, **kwargs)[source]

The crystal graph convolutional operator from the “Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties” paper

\[\mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right)\]

where \(\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]\) denotes the concatenation of central node features, neighboring node features and edge features. In addition, \(\sigma\) and \(g\) denote the sigmoid and softplus functions, respectively.

Parameters
  • channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • dim (int, optional) – Edge feature dimensionality. (default: 0)

  • aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "add")

  • batch_norm (bool, optional) – If set to True, will make use of batch normalization. (default: False)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F)\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F)\) or \((|\mathcal{V_t}|, F_{t})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None) Tensor[source]
class EdgeConv(nn: Callable, aggr: str = 'max', **kwargs)[source]

The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e. an MLP.

Parameters
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps pair-wise concatenated node features x of shape [-1, 2 * in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "max")

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor]) Tensor[source]
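
For illustration, a minimal usage sketch with random toy data; the MLP operates on pair-wise concatenated features of size 2 * in_channels:

import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import EdgeConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])

mlp = Sequential(Linear(2 * 16, 32), ReLU(), Linear(32, 32))
conv = EdgeConv(mlp, aggr='max')
out = conv(x, edge_index)                 # shape [4, 32]
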
class DynamicEdgeConv(nn: Callable, k: int, aggr: str = 'max', num_workers: int = 1, **kwargs)[source]

The dynamic edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper (see torch_geometric.nn.conv.EdgeConv), where the graph is dynamically constructed using nearest neighbors in the feature space.

Parameters
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps pair-wise concatenated node features x of shape [-1, 2 * in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • k (int) – Number of nearest neighbors.

  • aggr (string) – The aggregation operator to use ("add", "mean", "max"). (default: "max")

  • num_workers (int) – Number of workers to use for k-NN computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))\) if bipartite, batch vector \((|\mathcal{V}|)\) or \(((|\mathcal{V}|), (|\mathcal{V}|))\) if bipartite (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], batch: Union[Tensor, None, Tuple[Tensor, Tensor]] = None) Tensor[source]
class XConv(in_channels: int, out_channels: int, dim: int, kernel_size: int, hidden_channels: Optional[int] = None, dilation: int = 1, bias: bool = True, num_workers: int = 1)[source]

The convolutional operator on \(\mathcal{X}\)-transformed points from the “PointCNN: Convolution On X-Transformed Points” paper

\[\mathbf{x}^{\prime}_i = \mathrm{Conv}\left(\mathbf{K}, \gamma_{\mathbf{\Theta}}(\mathbf{P}_i - \mathbf{p}_i) \times \left( h_\mathbf{\Theta}(\mathbf{P}_i - \mathbf{p}_i) \, \Vert \, \mathbf{x}_i \right) \right),\]

where \(\mathbf{K}\) and \(\mathbf{P}_i\) denote the trainable filter and neighboring point positions of \(\mathbf{x}_i\), respectively. \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) describe neural networks, i.e. MLPs, where \(h_{\mathbf{\Theta}}\) individually lifts each point into a higher-dimensional space, and \(\gamma_{\mathbf{\Theta}}\) computes the \(\mathcal{X}\)-transformation matrix based on all points in a neighborhood.

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • dim (int) – Point cloud dimensionality.

  • kernel_size (int) – Size of the convolving kernel, i.e. number of neighbors including self-loops.

  • hidden_channels (int, optional) – Output size of \(h_{\mathbf{\Theta}}\), i.e. dimensionality of lifted points. If set to None, will be automatically set to in_channels / 4. (default: None)

  • dilation (int, optional) – The factor by which the neighborhood is extended, from which kernel_size neighbors are then uniformly sampled. Can be interpreted as the dilation rate of classical convolutional operators. (default: 1)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • num_workers (int) – Number of workers to use for k-NN computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), positions \((|\mathcal{V}|, D)\), batch vector \((|\mathcal{V}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, pos: Tensor, batch: Optional[Tensor] = None)[source]
class PPFConv(local_nn: Optional[Callable] = None, global_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs)[source]

The PPFNet operator from the “PPFNet: Global Context Aware Local Features for Robust 3D Point Matching” paper

\[\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \| \mathbf{d_{j,i}} \|, \angle(\mathbf{n}_i, \mathbf{d_{j,i}}), \angle(\mathbf{n}_j, \mathbf{d_{j,i}}), \angle(\mathbf{n}_i, \mathbf{n}_j) \right)\]

where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote neural networks, i.e. MLPs, which take in node features and torch_geometric.transforms.PointPairFeatures.

Parameters
  • local_nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x and relative spatial coordinates pos_j - pos_i of shape [-1, in_channels + num_dimensions] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential. (default: None)

  • global_nn (torch.nn.Module, optional) – A neural network \(\gamma_{\mathbf{\Theta}}\) that maps aggregated node features of shape [-1, out_channels] to shape [-1, final_out_channels], e.g., defined by torch.nn.Sequential. (default: None)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, positions \((|\mathcal{V}|, 3)\) or \(((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))\) if bipartite, point normals \((|\mathcal{V}|, 3)\) or \(((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, None, Tuple[Optional[Tensor], Optional[Tensor]]], pos: Union[Tensor, Tuple[Tensor, Tensor]], normal: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor]) Tensor[source]
class FeaStConv(in_channels: int, out_channels: int, heads: int = 1, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The (translation-invariant) feature-steered convolutional operator from the “FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis” paper

\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \sum_{h=1}^H q_h(\mathbf{x}_i, \mathbf{x}_j) \mathbf{W}_h \mathbf{x}_j\]

with \(q_h(\mathbf{x}_i, \mathbf{x}_j) = \mathrm{softmax}_j (\mathbf{u}_h^{\top} (\mathbf{x}_j - \mathbf{x}_i) + c_h)\), where \(H\) denotes the number of attention heads, and \(\mathbf{W}_h\), \(\mathbf{u}_h\) and \(c_h\) are trainable parameters.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • heads (int, optional) – Number of attention heads \(H\). (default: 1)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V_t}|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor]) Tensor[source]
class PointTransformerConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, pos_nn: Optional[Callable] = None, attn_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs)[source]

The Point Transformer layer from the “Point Transformer” paper

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3 \mathbf{x}_j + \delta_{ij} \right),\]

where the attention coefficients \(\alpha_{i,j}\) and positional embedding \(\delta_{ij}\) are computed as

\[\alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta} (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + \delta_{i,j}) \right)\]

and

\[\delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j),\]

with \(\gamma_\mathbf{\Theta}\) and \(h_\mathbf{\Theta}\) denoting neural networks, i.e. MLPs, and \(\mathbf{P} \in \mathbb{R}^{N \times D}\) defines the position of each point.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • pos_nn (torch.nn.Module, optional) – A neural network \(h_\mathbf{\Theta}\) which maps relative spatial coordinates pos_j - pos_i of shape [-1, 3] to shape [-1, out_channels]. Will default to a torch.nn.Linear transformation if not further specified. (default: None)

  • attn_nn (torch.nn.Module, optional) – A neural network \(\gamma_\mathbf{\Theta}\) which maps transformed node features of shape [-1, out_channels] to shape [-1, out_channels]. (default: None)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, positions \((|\mathcal{V}|, 3)\) or \(((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], pos: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor]) Tensor[source]
class HypergraphConv(in_channels, out_channels, use_attention=False, heads=1, concat=True, negative_slope=0.2, dropout=0, bias=True, **kwargs)[source]

The hypergraph convolutional operator from the “Hypergraph Convolution and Hypergraph Attention” paper

\[\mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W} \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta}\]

where \(\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}\) is the incidence matrix, \(\mathbf{W} \in \mathbb{R}^M\) is the diagonal hyperedge weight matrix, and \(\mathbf{D}\) and \(\mathbf{B}\) are the corresponding degree matrices.

For example, in the hypergraph scenario \(\mathcal{G} = (\mathcal{V}, \mathcal{E})\) with \(\mathcal{V} = \{ 0, 1, 2, 3 \}\) and \(\mathcal{E} = \{ \{ 0, 1, 2 \}, \{ 1, 2, 3 \} \}\), the hyperedge_index is represented as:

hyperedge_index = torch.tensor([
    [0, 1, 2, 1, 2, 3],
    [0, 0, 0, 1, 1, 1],
])
Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • use_attention (bool, optional) – If set to True, attention will be added to this layer. (default: False)

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), hyperedge indices \((|\mathcal{V}|, |\mathcal{E}|)\), hyperedge weights \((|\mathcal{E}|)\) (optional), hyperedge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, hyperedge_index: Tensor, hyperedge_weight: Optional[Tensor] = None, hyperedge_attr: Optional[Tensor] = None) Tensor[source]
Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • hyperedge_index (LongTensor) – The hyperedge indices, i.e. the sparse incidence matrix \(\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}\) mapping from nodes to edges.

  • hyperedge_weight (Tensor, optional) – Hyperedge weights \(\mathbf{W} \in \mathbb{R}^M\). (default: None)

  • hyperedge_attr (Tensor, optional) – Hyperedge feature matrix in \(\mathbb{R}^{M \times F}\). These features only need to get passed in case use_attention=True. (default: None)
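
For illustration, a minimal usage sketch that reuses the hyperedge_index from above with random toy node features:

import torch
from torch_geometric.nn import HypergraphConv

x = torch.randn(4, 16)             # 4 nodes, 16 features each
hyperedge_index = torch.tensor([   # two hyperedges: {0, 1, 2} and {1, 2, 3}
    [0, 1, 2, 1, 2, 3],
    [0, 0, 0, 1, 1, 1],
])

conv = HypergraphConv(16, 32)
out = conv(x, hyperedge_index)     # shape [4, 32]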

class LEConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, bias: bool = True, **kwargs)[source]

The local extremum graph neural network operator from the “ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations” paper, which finds the importance of nodes with respect to their neighbors using the difference operator:

\[\mathbf{x}^{\prime}_i = \mathbf{x}_i \cdot \mathbf{\Theta}_1 + \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot (\mathbf{\Theta}_2 \mathbf{x}_i - \mathbf{\Theta}_3 \mathbf{x}_j)\]

where \(e_{j,i}\) denotes the edge weight from source node j to target node i (default: 1)

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True).

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
class PNAConv(in_channels: int, out_channels: int, aggregators: List[str], scalers: List[str], deg: Tensor, edge_dim: Optional[int] = None, towers: int = 1, pre_layers: int = 1, post_layers: int = 1, divide_input: bool = False, act: Optional[Union[str, Callable]] = 'relu', act_kwargs: Optional[Dict[str, Any]] = None, train_norm: bool = False, **kwargs)[source]

The Principal Neighbourhood Aggregation graph convolution operator from the “Principal Neighbourhood Aggregation for Graph Nets” paper

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus} h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right) \right)\]

with

\[\begin{split}\bigoplus = \underbrace{\begin{bmatrix} 1 \\ S(\mathbf{D}, \alpha=1) \\ S(\mathbf{D}, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}},\end{split}\]

where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote MLPs.

Note

For an example of using PNAConv, see examples/pna.py.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • aggregators (list of str) – Set of aggregation function identifiers, namely "sum", "mean", "min", "max", "var" and "std".

  • scalers (list of str) – Set of scaling function identifiers, namely "identity", "amplification", "attenuation", "linear" and "inverse_linear".

  • deg (Tensor) – Histogram of in-degrees of nodes in the training set, used by scalers to normalize.

  • edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default None)

  • towers (int, optional) – Number of towers (default: 1).

  • pre_layers (int, optional) – Number of transformation layers before aggregation (default: 1).

  • post_layers (int, optional) – Number of transformation layers after aggregation (default: 1).

  • divide_input (bool, optional) – Whether the input features should be split between towers or not (default: False).

  • act (str or Callable, optional) – Pre- and post-layer activation function to use. (default: "relu")

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • train_norm (bool, optional) – Whether normalization parameters are trainable. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None) Tensor[source]
static get_degree_histogram(loader) Tensor[source]
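
For illustration, a minimal usage sketch with random toy data. Here the in-degree histogram deg is computed directly from the toy graph; for a real dataset, get_degree_histogram can compute it from a data loader instead:

import torch
from torch_geometric.nn import PNAConv
from torch_geometric.utils import degree

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])

# histogram of in-degrees over the (toy) training graph
d = degree(edge_index[1], num_nodes=x.size(0), dtype=torch.long)
deg = torch.bincount(d)

conv = PNAConv(16, 32, aggregators=['mean', 'min', 'max', 'std'],
               scalers=['identity', 'amplification', 'attenuation'], deg=deg)
out = conv(x, edge_index)          # shape [4, 32]
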
class ClusterGCNConv(in_channels: int, out_channels: int, diag_lambda: float = 0.0, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The ClusterGCN graph convolutional operator from the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = \left( \mathbf{\hat{A}} + \lambda \cdot \textrm{diag}(\mathbf{\hat{A}}) \right) \mathbf{X} \mathbf{W}_1 + \mathbf{X} \mathbf{W}_2\]

where \(\mathbf{\hat{A}} = {(\mathbf{D} + \mathbf{I})}^{-1}(\mathbf{A} + \mathbf{I})\).

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • diag_lambda (float, optional) – Diagonal enhancement value \(\lambda\). (default: 0.)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) Tensor[source]
class GENConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: Optional[Union[str, List[str], Aggregation]] = 'softmax', t: float = 1.0, learn_t: bool = False, p: float = 1.0, learn_p: bool = False, msg_norm: bool = False, learn_msg_scale: bool = False, norm: str = 'batch', num_layers: int = 2, expansion: int = 2, eps: float = 1e-07, bias: bool = False, edge_dim: Optional[int] = None, **kwargs)[source]

The GENeralized Graph Convolution (GENConv) from the “DeeperGCN: All You Need to Train Deeper GCNs” paper. Supports SoftMax & PowerMean aggregation. The message construction is:

\[\mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_i + \mathrm{AGG} \left( \left\{ \mathrm{ReLU} \left( \mathbf{x}_j + \mathbf{e_{ji}} \right) +\epsilon : j \in \mathcal{N}(i) \right\} \right) \right)\]

Note

For an example of using GENConv, see examples/ogbn_proteins_deepgcn.py.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • aggr (string or Aggregation, optional) – The aggregation scheme to use. Any aggregation of torch_geometric.nn.aggr can be used ("softmax", "powermean", "add", "mean", "max"). (default: "softmax")

  • t (float, optional) – Initial inverse temperature for softmax aggregation. (default: 1.0)

  • learn_t (bool, optional) – If set to True, will learn the value t for softmax aggregation dynamically. (default: False)

  • p (float, optional) – Initial power for power mean aggregation. (default: 1.0)

  • learn_p (bool, optional) – If set to True, will learn the value p for power mean aggregation dynamically. (default: False)

  • msg_norm (bool, optional) – If set to True, will use message normalization. (default: False)

  • learn_msg_scale (bool, optional) – If set to True, will learn the scaling factor of message normalization. (default: False)

  • norm (str, optional) – Norm layer of the MLP layers ("batch", "layer", "instance"). (default: "batch")

  • num_layers (int, optional) – The number of MLP layers. (default: 2)

  • expansion (int, optional) – The expansion factor of hidden channels in MLP layers. (default: 2)

  • eps (float, optional) – The epsilon value of the message construction function. (default: 1e-7)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • edge_dim (int, optional) – Edge feature dimensionality. If set to None, edge feature dimensionality is expected to match out_channels. Otherwise, edge features are linearly transformed to match out_channels. (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.GenMessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge attributes \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None) Tensor[source]
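
A minimal usage sketch under assumed sizes (a toy graph with 6 nodes, 16 input channels and 32 output channels); the edge features are chosen to match out_channels since edge_dim is left unset:

import torch
from torch_geometric.nn import GENConv

x = torch.randn(6, 16)                      # hypothetical node features
edge_index = torch.randint(0, 6, (2, 10))   # hypothetical edges
edge_attr = torch.randn(10, 32)             # edge features matching out_channels

conv = GENConv(in_channels=16, out_channels=32, aggr='softmax', learn_t=True)
out = conv(x, edge_index, edge_attr)        # Output shape: [6, 32]
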
class GCN2Conv(channels: int, alpha: float, theta: Optional[float] = None, layer: Optional[int] = None, shared_weights: bool = True, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs)[source]

The graph convolutional operator with initial residual connections and identity mapping (GCNII) from the “Simple and Deep Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = \left( (1 - \alpha) \mathbf{\hat{P}}\mathbf{X} + \alpha \mathbf{X^{(0)}}\right) \left( (1 - \beta) \mathbf{I} + \beta \mathbf{\Theta} \right)\]

with \(\mathbf{\hat{P}} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\), where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix, and \(\mathbf{X}^{(0)}\) being the initial feature representation. Here, \(\alpha\) models the strength of the initial residual connection, while \(\beta\) models the strength of the identity mapping. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Parameters
  • channels (int) – Size of each input and output sample.

  • alpha (float) – The strength of the initial residual connection \(\alpha\).

  • theta (float, optional) – The hyperparameter \(\theta\) to compute the strength of the identity mapping \(\beta = \log \left( \frac{\theta}{\ell} + 1 \right)\). (default: None)

  • layer (int, optional) – The layer \(\ell\) in which this module is executed. (default: None)

  • shared_weights (bool, optional) – If set to False, will use different weight matrices for the smoothed representation and the initial residual (“GCNII*”). (default: True)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F)\), initial node features \((|\mathcal{V}|, F)\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F)\)

reset_parameters()[source]
forward(x: Tensor, x_0: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
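
A minimal sketch of stacking GCN2Conv layers (hypothetical channel size and hyperparameters), passing the initial representation x_0 alongside the current features and the 1-based layer index:

import torch
from torch_geometric.nn import GCN2Conv

x = torch.randn(6, 64)                      # hypothetical node features
edge_index = torch.randint(0, 6, (2, 10))

x_0 = x                                     # initial feature representation
convs = torch.nn.ModuleList([
    GCN2Conv(channels=64, alpha=0.1, theta=0.5, layer=i + 1)
    for i in range(4)
])
for conv in convs:
    x = conv(x, x_0, edge_index).relu()     # Output shape: [6, 64]
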
class PANConv(in_channels: int, out_channels: int, filter_size: int, **kwargs)[source]

The path integral based convolutional operator from the “Path Integral Based Convolution and Pooling for Graph Neural Networks” paper

\[\mathbf{X}^{\prime} = \mathbf{M} \mathbf{X} \mathbf{W}\]

where \(\mathbf{M}\) denotes the normalized and learned maximal entropy transition (MET) matrix that includes neighbors up to filter_size hops:

\[\mathbf{M} = \mathbf{Z}^{-1/2} \sum_{n=0}^L e^{-\frac{E(n)}{T}} \mathbf{A}^n \mathbf{Z}^{-1/2}\]
Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • filter_size (int) – The filter size \(L\).

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\),

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) Tuple[Tensor, SparseTensor][source]
panentropy(adj_t: SparseTensor, dtype: Optional[int] = None) SparseTensor[source]
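
A minimal usage sketch (hypothetical sizes); note that forward() returns both the updated node features and the normalized MET matrix \(\mathbf{M}\) as a SparseTensor:

import torch
from torch_geometric.nn import PANConv

x = torch.randn(6, 16)                      # hypothetical node features
edge_index = torch.randint(0, 6, (2, 10))

conv = PANConv(in_channels=16, out_channels=32, filter_size=3)
out, M = conv(x, edge_index)                # out: [6, 32], M: SparseTensor holding the MET matrix
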
class WLConv[source]

The Weisfeiler-Lehman operator from the “A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction” paper, which iteratively refines node colorings:

\[\mathbf{x}^{\prime}_i = \textrm{hash} \left( \mathbf{x}_i, \{ \mathbf{x}_j \colon j \in \mathcal{N}(i) \} \right)\]
Shapes:
  • input: node coloring \((|\mathcal{V}|, F_{in})\) (one-hot encodings) or \((|\mathcal{V}|)\) (integer-based), edge indices \((2, |\mathcal{E}|)\)

  • output: node coloring \((|\mathcal{V}|)\) (integer-based)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) Tensor[source]
histogram(x: Tensor, batch: Optional[Tensor] = None, norm: bool = False) Tensor[source]

Given a node coloring x, computes the color histograms of the respective graphs (separated by batch).
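
A minimal sketch of iterative color refinement followed by per-graph histograms (toy integer colorings and batch assignment assumed):

import torch
from torch_geometric.nn import WLConv

color = torch.tensor([0, 0, 1, 1, 0, 1])    # initial integer-based node coloring
edge_index = torch.randint(0, 6, (2, 10))
batch = torch.tensor([0, 0, 0, 1, 1, 1])    # two graphs in the batch

wl = WLConv()
for _ in range(3):                          # three refinement iterations
    color = wl(color, edge_index)
hist = wl.histogram(color, batch, norm=True)  # per-graph color histograms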

class WLConvContinuous(**kwargs)[source]

The Weisfeiler-Lehman operator from the “Wasserstein Weisfeiler-Lehman Graph Kernels” paper. Refinement is done through a degree-scaled mean aggregation and works on nodes with continuous attributes:

\[\mathbf{x}^{\prime}_i = \frac{1}{2}\big(\mathbf{x}_i + \frac{1}{\textrm{deg}(i)} \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j \big)\]

where \(e_{j,i}\) denotes the edge weight from source node j to target node i (default: 1)

Parameters

**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F)\) or \(((|\mathcal{V_s}|, F), (|\mathcal{V_t}|, F))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F)\) or \((|\mathcal{V}_t|, F)\) if bipartite

forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Tensor, edge_weight: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None) Tensor[source]
class FiLMConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int = 1, nn: Optional[Callable] = None, act: Optional[Callable] = ReLU(), aggr: str = 'mean', **kwargs)[source]

The FiLM graph convolutional operator from the “GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation” paper

\[\mathbf{x}^{\prime}_i = \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}(i)} \sigma \left( \boldsymbol{\gamma}_{r,i} \odot \mathbf{W}_r \mathbf{x}_j + \boldsymbol{\beta}_{r,i} \right)\]

where \(\boldsymbol{\beta}_{r,i}, \boldsymbol{\gamma}_{r,i} = g(\mathbf{x}_i)\) with \(g\) being a single linear layer by default. Self-loops are automatically added to the input graph and represented as their own relation type.

Note

For an example of using FiLM, see examples/gcn.py.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • num_relations (int, optional) – Number of relations. (default: 1)

  • nn (torch.nn.Module, optional) – The neural network \(g\) that maps node features x_i of shape [-1, in_channels] to shape [-1, 2 * out_channels]. If set to None, \(g\) will be implemented as a single linear layer. (default: None)

  • act (callable, optional) – Activation function \(\sigma\). (default: torch.nn.ReLU())

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "mean")

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge types \((|\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V_t}|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor], edge_type: Optional[Tensor] = None) Tensor[source]
class SuperGATConv(in_channels: int, out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, bias: bool = True, attention_type: str = 'MX', neg_sample_ratio: float = 0.5, edge_sample_ratio: float = 1.0, is_undirected: bool = False, **kwargs)[source]

The self-supervised graph attentional operator from the “How to Find Your Friendly Neighborhood: Graph Attention Design with Self-Supervision” paper

\[\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},\]

where the two types of attention \(\alpha_{i,j}^{\mathrm{MX\ or\ SD}}\) are computed as:

\[ \begin{align}\begin{aligned}\alpha_{i,j}^{\mathrm{MX\ or\ SD}} &= \frac{ \exp\left(\mathrm{LeakyReLU}\left( e_{i,j}^{\mathrm{MX\ or\ SD}} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( e_{i,k}^{\mathrm{MX\ or\ SD}} \right)\right)}\\e_{i,j}^{\mathrm{MX}} &= \mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \cdot \sigma \left( \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j \right)\\e_{i,j}^{\mathrm{SD}} &= \frac{ \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j }{ \sqrt{d} }\end{aligned}\end{align} \]

The self-supervised task is a link prediction using the attention values as input to predict the likelihood \(\phi_{i,j}^{\mathrm{MX\ or\ SD}}\) that an edge exists between nodes:

\[ \begin{align}\begin{aligned}\phi_{i,j}^{\mathrm{MX}} &= \sigma \left( \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j \right)\\\phi_{i,j}^{\mathrm{SD}} &= \sigma \left( \frac{ \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j }{ \sqrt{d} } \right)\end{aligned}\end{align} \]

Note

For an example of using SuperGAT, see examples/super_gat.py.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • attention_type (string, optional) – Type of attention to use. ('MX', 'SD'). (default: 'MX')

  • neg_sample_ratio (float, optional) – The ratio of the number of sampled negative edges to the number of positive edges. (default: 0.5)

  • edge_sample_ratio (float, optional) – The ratio of samples to use for training among the number of training edges. (default: 1.0)

  • is_undirected (bool, optional) – Whether the input graph is undirected. If not given, will be automatically computed with the input graph when negative sampling is performed. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), negative edge indices \((2, |\mathcal{E}^{(-)}|)\) (optional)

  • output: node features \((|\mathcal{V}|, H * F_{out})\)

att_x: Optional[Tensor]
att_y: Optional[Tensor]
reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], neg_edge_index: Optional[Tensor] = None, batch: Optional[Tensor] = None) Tensor[source]
Parameters

neg_edge_index (Tensor, optional) – The negative edges to train against. If not given, uses negative sampling to calculate negative edges. (default: None)

negative_sampling(edge_index: Tensor, num_nodes: int, batch: Optional[Tensor] = None) Tensor[source]
positive_sampling(edge_index: Tensor) Tensor[source]
get_attention(edge_index_i: Tensor, x_i: Tensor, x_j: Tensor, num_nodes: Optional[int], return_logits: bool = False) Tensor[source]
get_attention_loss() Tensor[source]

Compute the self-supervised graph attention loss.
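
A minimal training sketch (hypothetical sizes, class count and loss weighting) that adds the self-supervised attention loss on top of a downstream objective:

import torch
import torch.nn.functional as F
from torch_geometric.nn import SuperGATConv

x = torch.randn(6, 16)
edge_index = torch.randint(0, 6, (2, 10))
y = torch.randint(0, 3, (6, ))              # hypothetical node labels

conv = SuperGATConv(16, 8, heads=2, attention_type='MX')
out = conv(x, edge_index)                   # Output shape: [6, 2 * 8]
loss = F.cross_entropy(out, y)
loss = loss + 4.0 * conv.get_attention_loss()  # weighting factor is an assumption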

class FAConv(channels: int, eps: float = 0.1, dropout: float = 0.0, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs)[source]

The Frequency Adaptive Graph Convolution operator from the “Beyond Low-Frequency Information in Graph Convolutional Networks” paper

\[\mathbf{x}^{\prime}_i= \epsilon \cdot \mathbf{x}^{(0)}_i + \sum_{j \in \mathcal{N}(i)} \frac{\alpha_{i,j}}{\sqrt{d_i d_j}} \mathbf{x}_{j}\]

where \(\mathbf{x}^{(0)}_i\) and \(d_i\) denote the initial feature representation and node degree of node \(i\), respectively. The attention coefficients \(\alpha_{i,j}\) are computed as

\[\mathbf{\alpha}_{i,j} = \textrm{tanh}(\mathbf{a}^{\top}[\mathbf{x}_i, \mathbf{x}_j])\]

based on the trainable parameter vector \(\mathbf{a}\).

Parameters
  • channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • eps (float, optional) – \(\epsilon\)-value. (default: 0.1)

  • dropout (float, optional) – Dropout probability of the normalized coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0).

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\sqrt{d_i d_j}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • normalize (bool, optional) – Whether to add self-loops (if add_self_loops is True) and compute symmetric normalization coefficients on the fly. If set to False, edge_weight needs to be provided in the layer’s forward() method. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F)\), initial node features \((|\mathcal{V}|, F)\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F)\) or \(((|\mathcal{V}|, F), ((2, |\mathcal{E}|), (|\mathcal{E}|)))\) if return_attention_weights=True

reset_parameters()[source]
forward(x: Tensor, x_0: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None, return_attention_weights=None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)
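
A minimal usage sketch (hypothetical sizes) that passes both the current and the initial node features, optionally returning the attention weights:

import torch
from torch_geometric.nn import FAConv

x = torch.randn(6, 16)
edge_index = torch.randint(0, 6, (2, 10))
x_0 = x                                     # initial feature representation

conv = FAConv(channels=16, eps=0.1)
out = conv(x, x_0, edge_index)              # Output shape: [6, 16]

# Additionally retrieve the attention coefficients for each edge:
out, (edge_index_out, alpha) = conv(x, x_0, edge_index,
                                    return_attention_weights=True)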

class EGConv(in_channels: int, out_channels: int, aggregators: List[str] = ['symnorm'], num_heads: int = 8, num_bases: int = 4, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The Efficient Graph Convolution from the “Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions” paper.

Its node-wise formulation is given by:

\[\mathbf{x}_i^{\prime} = {\LARGE ||}_{h=1}^H \sum_{\oplus \in \mathcal{A}} \sum_{b = 1}^B w_{i, h, \oplus, b} \; \underset{j \in \mathcal{N}(i) \cup \{i\}}{\bigoplus} \mathbf{W}_b \mathbf{x}_{j}\]

with \(\mathbf{W}_b\) denoting a basis weight, \(\oplus\) denoting an aggregator, and \(w\) denoting per-vertex weighting coefficients across different heads, bases and aggregators.

EGC retains \(\mathcal{O}(|\mathcal{V}|)\) memory usage, making it a sensible alternative to GCNConv, SAGEConv or GINConv.

Note

For an example of using EGConv, see examples/egc.py.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • aggregators (List[str], optional) – Aggregators to be used. Supported aggregators are "sum", "mean", "symnorm", "max", "min", "std", "var". Multiple aggregators can be used to improve the performance. (default: ["symnorm"])

  • num_heads (int, optional) – Number of heads \(H\) to use. Must have out_channels % num_heads == 0. It is recommended to set num_heads >= num_bases. (default: 8)

  • num_bases (int, optional) – Number of basis weights \(B\) to use. (default: 4)

  • cached (bool, optional) – If set to True, the layer will cache the computation of the edge index with added self loops on first execution, along with caching the calculation of the symmetric normalized edge weights if the "symnorm" aggregator is being used. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) Tensor[source]
class PDNConv(in_channels: int, out_channels: int, edge_dim: int, hidden_channels: int, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)[source]

The pathfinder discovery network convolutional operator from the “Pathfinder Discovery Networks for Neural Message Passing” paper

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{i\}} f_{\Theta}(\mathbf{e}_{(j,i)}) \cdot f_{\Omega}(\mathbf{x}_{j})\]

where \(\mathbf{e}_{(j,i)}\) denotes the edge feature vector from source node \(j\) to target node \(i\), and \(\mathbf{x}_{j}\) denotes the node feature vector of node \(j\).

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • edge_dim (int) – Edge feature dimensionality.

  • hidden_channels (int) – Hidden edge feature dimensionality.

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • normalize (bool, optional) – Whether to add self-loops and compute symmetric normalization coefficients on the fly. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None) Tensor[source]
class GeneralConv(in_channels: Union[int, Tuple[int, int]], out_channels: Optional[int], in_edge_channels: Optional[int] = None, aggr: str = 'add', skip_linear: bool = False, directed_msg: bool = True, heads: int = 1, attention: bool = False, attention_type: str = 'additive', l2_normalize: bool = False, bias: bool = True, **kwargs)[source]

A general GNN layer adapted from the “Design Space for Graph Neural Networks” paper.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • in_edge_channels (int, optional) – Size of each input edge. (default: None)

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "mean")

  • skip_linear (bool, optional) – Whether to apply a linear transformation in the skip connection. (default: False)

  • directed_msg (bool, optional) – If set to True, message passing is directed; otherwise, message passing is bi-directed. (default: True)

  • heads (int, optional) – Number of message passing ensembles. If heads > 1, the GNN layer will output an ensemble of multiple messages. If attention is used (attention=True), this corresponds to multi-head attention. (default: 1)

  • attention (bool, optional) – Whether to add attention to message computation. (default: False)

  • attention_type (str, optional) – Type of attention: "additive", "dot_product". (default: "additive")

  • l2_normalize (bool, optional) – If set to True, output features will be \(\ell_2\)-normalized, i.e., \(\frac{\mathbf{x}^{\prime}_i} {\| \mathbf{x}^{\prime}_i \|_2}\). (default: False)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge attributes \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

reset_parameters()[source]
forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None) Tensor[source]
message_basic(x_i: Tensor, x_j: Tensor, edge_attr: Optional[Tensor])[source]
class HGTConv(in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Tuple[List[str], List[Tuple[str, str, str]]], heads: int = 1, group: str = 'sum', **kwargs)[source]

The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.

Note

For an example of using HGT, see examples/hetero/hgt_dblp.py.

Parameters
  • in_channels (int or Dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • group (string, optional) – The aggregation scheme to use for grouping node embeddings generated by different relations. ("sum", "mean", "min", "max"). (default: "sum")

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x_dict: Dict[str, Tensor], edge_index_dict: Union[Dict[Tuple[str, str, str], Tensor], Dict[Tuple[str, str, str], SparseTensor]]) Dict[str, Optional[Tensor]][source]
Parameters
  • x_dict (Dict[str, Tensor]) – A dictionary holding input node features for each individual node type.

  • edge_index_dict (Dict[str, Union[Tensor, SparseTensor]]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.LongTensor of shape [2, num_edges] or a torch_sparse.SparseTensor.

Return type

Dict[str, Optional[Tensor]] - The output node embeddings for each node type. In case a node type does not receive any message, its output will be set to None.
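
A minimal sketch on a toy heterogeneous graph with assumed node and edge types; in_channels=-1 relies on lazy initialization, and node types without incoming edges ('author' here) receive a None output:

import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import HGTConv

data = HeteroData()
data['paper'].x = torch.randn(8, 16)
data['author'].x = torch.randn(4, 32)
row = torch.randint(0, 4, (20, ))           # source: author indices
col = torch.randint(0, 8, (20, ))           # target: paper indices
data['author', 'writes', 'paper'].edge_index = torch.stack([row, col], dim=0)

conv = HGTConv(in_channels=-1, out_channels=64, metadata=data.metadata(), heads=2)
out_dict = conv(data.x_dict, data.edge_index_dict)
# out_dict['paper']: [8, 64]; out_dict['author'] is None (no incoming messages)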

class HEATConv(in_channels: int, out_channels: int, num_node_types: int, num_edge_types: int, edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, root_weight: bool = True, bias: bool = True, **kwargs)[source]

The heterogeneous edge-enhanced graph attentional operator from the “Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction” paper, which enhances GATConv by:

  1. type-specific transformations of nodes of different types

  2. edge type and edge feature incorporation, in which edges are assumed to have different types but contain the same kind of attributes

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • num_node_types (int) – The number of node types.

  • num_edge_types (int) – The number of edge types.

  • edge_type_emb_dim (int) – The embedding size of edge types.

  • edge_dim (int) – Edge feature dimensionality.

  • edge_attr_emb_dim (int) – The embedding size of edge features.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), node types \((|\mathcal{V}|)\), edge types \((|\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], node_type: Tensor, edge_type: Tensor, edge_attr: Optional[Tensor] = None) Tensor[source]
class HeteroConv(convs: Dict[Tuple[str, str, str], Module], aggr: Optional[str] = 'sum')[source]

A generic wrapper for computing graph convolution on heterogeneous graphs. This layer will pass messages from source nodes to target nodes based on the bipartite GNN layer given for a specific edge type. If multiple relations point to the same destination, their results will be aggregated according to aggr. In comparison to torch_geometric.nn.to_hetero(), this layer is especially useful if you want to apply different message passing modules for different edge types.

hetero_conv = HeteroConv({
    ('paper', 'cites', 'paper'): GCNConv(-1, 64),
    ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
    ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
}, aggr='sum')

out_dict = hetero_conv(x_dict, edge_index_dict)

print(list(out_dict.keys()))
>>> ['paper', 'author']
Parameters
  • convs (Dict[Tuple[str, str, str], Module]) – A dictionary holding a bipartite MessagePassing layer for each individual edge type.

  • aggr (string, optional) – The aggregation scheme to use for grouping node embeddings generated by different relations. ("sum", "mean", "min", "max", None). (default: "sum")

reset_parameters()[source]
forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor, SparseTensor]], *args_dict, **kwargs_dict) Dict[str, Tensor][source]
Parameters
  • x_dict (Dict[str, Tensor]) – A dictionary holding node feature information for each individual node type.

  • edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – A dictionary holding graph connectivity information for each individual edge type.

  • *args_dict (optional) – Additional forward arguments of individual torch_geometric.nn.conv.MessagePassing layers.

  • **kwargs_dict (optional) – Additional forward arguments of individual torch_geometric.nn.conv.MessagePassing layers. For example, if a specific GNN layer at edge type edge_type expects edge attributes edge_attr as a forward argument, then you can pass them to forward() via edge_attr_dict = { edge_type: edge_attr }.

class HANConv(in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Tuple[List[str], List[Tuple[str, str, str]]], heads: int = 1, negative_slope=0.2, dropout: float = 0.0, **kwargs)[source]

The Heterogeneous Graph Attention operator from the “Heterogeneous Graph Attention Network” paper.

Note

For an example of using HANConv, see examples/hetero/han_imdb.py.

Parameters
  • in_channels (int or Dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor, SparseTensor]], return_semantic_attention_weights: bool = False) Union[Dict[str, Optional[Tensor]], Tuple[Dict[str, Optional[Tensor]], Dict[str, Optional[Tensor]]]][source]
Parameters
  • x_dict (Dict[str, Tensor]) – A dictionary holding input node features for each individual node type.

  • edge_index_dict (Dict[str, Union[Tensor, SparseTensor]]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.LongTensor of shape [2, num_edges] or a torch_sparse.SparseTensor.

  • return_semantic_attention_weights (bool, optional) – If set to True, will additionally return the semantic-level attention weights for each destination node type. (default: False)

class LGConv(normalize: bool = True, **kwargs)[source]

The Light Graph Convolution (LGC) operator from the “LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation” paper

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \frac{e_{j,i}}{\sqrt{\deg(i)\deg(j)}} \mathbf{x}_j\]
Parameters
  • normalize (bool, optional) – If set to False, output features will not be normalized via symmetric degree normalization. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F)\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F)\)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
class MetaLayer(edge_model: Optional[Module] = None, node_model: Optional[Module] = None, global_model: Optional[Module] = None)[source]

A meta layer for building any kind of graph network, inspired by the “Relational Inductive Biases, Deep Learning, and Graph Networks” paper.

A graph network takes a graph as input and returns an updated graph as output (with same connectivity). The input graph has node features x, edge features edge_attr as well as global-level features u. The output graph has the same structure, but updated features.

Edge features, node features as well as global features are updated by calling the modules edge_model, node_model and global_model, respectively.

To allow for batch-wise graph processing, all callable functions take an additional argument batch, which determines the assignment of edges or nodes to their specific graphs.

Parameters
  • edge_model (Module, optional) – A callable which updates a graph’s edge features based on its source and target node features, its current edge features and its global features. (default: None)

  • node_model (Module, optional) – A callable which updates a graph’s node features based on its current node features, its graph connectivity, its edge features and its global features. (default: None)

  • global_model (Module, optional) – A callable which updates a graph’s global features based on its node features, its graph connectivity, its edge features and its current global features. (default: None)

import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer

class EdgeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, src, dest, edge_attr, u, batch):
        # src, dest: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dest, edge_attr, u[batch]], 1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
        self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out, u[batch]], dim=1)
        return self.node_mlp_2(out)

class GlobalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        out = torch.cat([u, scatter_mean(x, batch, dim=0)], dim=1)
        return self.global_mlp(out)

op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
reset_parameters()[source]
forward(x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None, u: Optional[Tensor] = None, batch: Optional[Tensor] = None) Tuple[Tensor, Optional[Tensor], Optional[Tensor]][source]

Aggregation Operators

Aggregation functions play an important role in the message passing framework and the readout functions of Graph Neural Networks. Specifically, many works in the literature (Hamilton et al. (2017), Xu et al. (2018), Corso et al. (2020), Li et al. (2020), Tailor et al. (2021)) demonstrate that the choice of aggregation functions contributes significantly to the representational power and performance of the model. For example, mean aggregation captures the distribution (or proportions) of elements, max aggregation proves to be advantageous to identify representative elements, and sum aggregation enables the learning of structural graph properties (Xu et al. (2018)). Recent works also show that using multiple aggregations (Corso et al. (2020), Tailor et al. (2021)) and learnable aggregations (Li et al. (2020)) can potentially provide substantial improvements. Another line of research studies optimization-based and implicitly-defined aggregations (Bartunov et al. (2022)).

To facilitate further experimentation and unify the concepts of aggregation within GNNs across both MessagePassing and global readouts, we have made the concept of Aggregation a first-class principle in PyG. As of now, PyG provides support for various aggregations — from simple ones (e.g., mean, max, sum), to advanced ones (e.g., median, var, std), learnable ones (e.g., SoftmaxAggregation, PowerMeanAggregation), and exotic ones (e.g., LSTMAggregation, SortAggregation, EquilibriumAggregation):

from torch_geometric.nn import aggr

# Simple aggregations:
mean_aggr = aggr.MeanAggregation()
max_aggr = aggr.MaxAggregation()

# Advanced aggregations:
median_aggr = aggr.MedianAggregation()

# Learnable aggregations:
softmax_aggr = aggr.SoftmaxAggregation(learn=True)
powermean_aggr = aggr.PowerMeanAggregation(learn=True)

# Exotic aggregations:
lstm_aggr = aggr.LSTMAggregation(in_channels=..., out_channels=...)
sort_aggr = aggr.SortAggregation(k=4)

We can then easily apply these aggregations over a batch of sets of potentially varying size. For this, an index vector defines the mapping from input elements to their location in the output:

# Feature matrix holding 1000 elements with 64 features each:
x = torch.randn(1000, 64)

# Randomly assign elements to 100 sets:
index = torch.randint(0, 100, (1000, ))

output = mean_aggr(x, index)  #  Output shape: [100, 64]

Notably, all aggregations share the same set of forward arguments, as described in detail in torch_geometric.nn.aggr.Aggregation.

Each of the provided aggregations can be used within MessagePassing as well as for hierarchical/global pooling to obtain graph-level representations:

import torch
from torch_geometric.nn import MessagePassing, aggr

class MyConv(MessagePassing):
    def __init__(self, ...):
        # Use a learnable softmax neighborhood aggregation:
        super().__init__(aggr=aggr.SoftmaxAggregation(learn=True))

    def forward(self, x, edge_index):
        ...


class MyGNN(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()

        self.conv = MyConv(...)
        # Use a global sort aggregation:
        self.global_pool = aggr.SortAggregation(k=4)
        self.classifier = torch.nn.Linear(...)

    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index).relu()
        x = self.global_pool(x, batch)
        x = self.classifier(x)
        return x

In addition, the aggregation package of PyG introduces two new concepts: First, aggregations can be resolved from pure strings via a lookup table, following the design principles of the class-resolver library, e.g., by simply passing in "median" to the MessagePassing module. This will automatically resolve to the MedianAggregation class:

class MyConv(MessagePassing):
    def __init__(self, ...):
        super().__init__(aggr="median")

Secondly, multiple aggregations can be combined and stacked via the MultiAggregation module in order to enhance the representational power of GNNs (Corso et al. (2020), Tailor et al. (2021)):

class MyConv(MessagePassing):
    def __init__(self, ...):
        # Combines a set of aggregations and concatenates their results,
        # i.e. its output will be `[num_nodes, 3 * out_channels]` here.
        # Note that the interface also supports automatic resolution.
        super().__init__(aggr=aggr.MultiAggregation(
            ['mean', 'std', aggr.SoftmaxAggregation(learn=True)]))

Importantly, MultiAggregation provides various options to combine the outputs of its underlying aggregations (e.g., using concatenation, summation, attention, …) via its mode argument. The default mode performs concatenation ("cat"). For combining via attention, we need to additionally specify the in_channels, out_channels, and num_heads:

multi_aggr = aggr.MultiAggregation(['mean', 'std'], mode='attn',
                                   mode_kwargs=dict(in_channels=64,
                                                    out_channels=64,
                                                    num_heads=4))

If aggregations are given as a list, they will be automatically resolved to a MultiAggregation, e.g., aggr=['mean', 'std', 'median'].

Finally, we added full support for customization of aggregations into the SAGEConv layer — simply override its aggr argument and utilize the power of aggregation within your GNN.
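
For instance, a short sketch (channel sizes assumed) in which a list of aggregation names passed to SAGEConv is resolved to a MultiAggregation:

from torch_geometric.nn import SAGEConv, aggr

# A list of aggregations is resolved to a MultiAggregation under the hood:
conv = SAGEConv(in_channels=64, out_channels=64, aggr=['mean', 'std', 'median'])

# A learnable aggregation can be passed directly as well:
conv = SAGEConv(in_channels=64, out_channels=64,
                aggr=aggr.SoftmaxAggregation(learn=True))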

Aggregation

An abstract base class for implementing custom aggregations.

MultiAggregation

Performs aggregations with one or more aggregators and combines aggregated results, as described in the "Principal Neighbourhood Aggregation for Graph Nets" and "Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions" papers.

SumAggregation

An aggregation operator that sums up features across a set of elements

MeanAggregation

An aggregation operator that averages features across a set of elements

MaxAggregation

An aggregation operator that takes the feature-wise maximum across a set of elements

MinAggregation

An aggregation operator that takes the feature-wise minimum across a set of elements

MulAggregation

An aggregation operator that multiplies features across a set of elements

VarAggregation

An aggregation operator that takes the feature-wise variance across a set of elements

StdAggregation

An aggregation operator that takes the feature-wise standard deviation across a set of elements

SoftmaxAggregation

The softmax aggregation operator based on a temperature term, as described in the "DeeperGCN: All You Need to Train Deeper GCNs" paper

PowerMeanAggregation

The powermean aggregation operator based on a power term, as described in the "DeeperGCN: All You Need to Train Deeper GCNs" paper

MedianAggregation

An aggregation operator that returns the feature-wise median of a set.

QuantileAggregation

An aggregation operator that returns the feature-wise \(q\)-th quantile of a set \(\mathcal{X}\).

LSTMAggregation

Performs LSTM-style aggregation in which the elements to aggregate are interpreted as a sequence, as described in the "Inductive Representation Learning on Large Graphs" paper.

Set2Set

The Set2Set aggregation operator based on iterative content-based attention, as described in the "Order Matters: Sequence to sequence for Sets" paper

DegreeScalerAggregation

Combines one or more aggregators and transforms its output with one or more scalers as introduced in the "Principal Neighbourhood Aggregation for Graph Nets" paper.

SortAggregation

The pooling operator from the "An End-to-End Deep Learning Architecture for Graph Classification" paper, where node features are sorted in descending order based on their last feature channel.

GraphMultisetTransformer

The Graph Multiset Transformer pooling operator from the "Accurate Learning of Graph Representations with Graph Multiset Pooling" paper.

AttentionalAggregation

The soft attention aggregation layer from the "Graph Matching Networks for Learning the Similarity of Graph Structured Objects" paper

EquilibriumAggregation

The equilibrium aggregation layer from the "Equilibrium Aggregation: Encoding Sets via Optimization" paper.

class Aggregation[source]

An abstract base class for implementing custom aggregations.

Aggregation can be either performed via an index vector, which defines the mapping from input elements to their location in the output:


(Figure: scatter-style aggregation of input elements into their output locations via an index vector; see https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true)

Notably, index does not have to be sorted:

# Feature matrix holding 10 elements with 64 features each:
x = torch.randn(10, 64)

# Assign each element to one of three sets:
index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

output = aggr(x, index)  #  Output shape: [3, 64]

Alternatively, aggregation can be achieved via a “compressed” index vector called ptr. Here, elements within the same set need to be grouped together in the input, and ptr defines their boundaries:

# Feature matrix holding 10 elements with 64 features each:
x = torch.randn(10, 64)

# Define the boundary indices for three sets:
ptr = torch.tensor([0, 4, 7, 10])

output = aggr(x, ptr=ptr)  #  Output shape: [3, 64]

Note that at least one of index or ptr must be defined.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or edge features \((|\mathcal{E}|, F_{in})\), index vector \((|\mathcal{V}|)\) or \((|\mathcal{E}|)\),

  • output: graph features \((|\mathcal{G}|, F_{out})\) or node features \((|\mathcal{V}|, F_{out})\)

forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) Tensor[source]
Parameters
  • x (torch.Tensor) – The source tensor.

  • index (torch.LongTensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

  • ptr (torch.LongTensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

  • dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

  • dim (int, optional) – The dimension in which to aggregate. (default: -2)

static set_validate_args(value: bool)[source]

Sets whether validation is enabled or disabled.

The default behavior mimics Python’s assert statement: validation is on by default, but is disabled if Python is run in optimized mode (via python -O). Validation may be expensive, so you may want to disable it once a model is working.

Parameters

value (bool) – Whether to enable validation.

class MultiAggregation(aggrs: List[Union[Aggregation, str]], aggrs_kwargs: Optional[List[Dict[str, Any]]] = None, mode: Optional[str] = 'cat', mode_kwargs: Optional[Dict[str, Any]] = None)[source]

Performs aggregations with one or more aggregators and combines aggregated results, as described in the “Principal Neighbourhood Aggregation for Graph Nets” and “Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions” papers.

Parameters
  • aggrs (list) – The list of aggregation schemes to use.

  • aggrs_kwargs (dict, optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)

  • mode (string, optional) – The combine mode to use for combining aggregated results from multiple aggregations ("cat", "proj", "sum", "mean", "max", "min", "logsumexp", "std", "var", "attn"). (default: "cat")

  • mode_kwargs (dict, optional) – Arguments passed for the combine mode. When "proj" or "attn" is used as the combine mode, in_channels (int or tuple) and out_channels (int) are needed to be specified respectively for the size of each input sample to combine from the respective aggregation outputs and the size of each output sample after combination. When "attn" mode is used, num_heads (int) is needed to be specified for the number of parallel attention heads. (default: None)

class SumAggregation[source]

An aggregation operator that sums up features across a set of elements

\[\mathrm{sum}(\mathcal{X}) = \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.\]
class MeanAggregation[source]

An aggregation operator that averages features across a set of elements

\[\mathrm{mean}(\mathcal{X}) = \frac{1}{|\mathcal{X}|} \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.\]
class MaxAggregation[source]

An aggregation operator that takes the feature-wise maximum across a set of elements

\[\mathrm{max}(\mathcal{X}) = \max_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.\]
class MinAggregation[source]

An aggregation operator that takes the feature-wise minimum across a set of elements

\[\mathrm{min}(\mathcal{X}) = \min_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.\]
class MulAggregation[source]

An aggregation operator that multiplies features across a set of elements

\[\mathrm{mul}(\mathcal{X}) = \prod_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.\]
class VarAggregation(semi_grad: bool = False)[source]

An aggregation operator that takes the feature-wise variance across a set of elements

\[\mathrm{var}(\mathcal{X}) = \mathrm{mean}(\{ \mathbf{x}_i^2 : \mathbf{x}_i \in \mathcal{X} \}) - \mathrm{mean}(\mathcal{X})^2.\]
Parameters

semi_grad (bool, optional) – If set to True, will turn off gradient calculation during \(E[X^2]\) computation. Therefore, only semi-gradients are used during backpropagation. Useful for saving memory and accelerating backward computation. (default: False)

class StdAggregation(semi_grad: bool = False)[source]

An aggregation operator that takes the feature-wise standard deviation across a set of elements

\[\mathrm{std}(\mathcal{X}) = \sqrt{\mathrm{var}(\mathcal{X})}.\]
Parameters

semi_grad (bool, optional) – If set to True, will turn off gradient calculation during \(E[X^2]\) computation. Therefore, only semi-gradients are used during backpropagation. Useful for saving memory and accelerating backward computation. (default: False)

class SoftmaxAggregation(t: float = 1.0, learn: bool = False, semi_grad: bool = False, channels: int = 1)[source]

The softmax aggregation operator based on a temperature term, as described in the “DeeperGCN: All You Need to Train Deeper GCNs” paper

\[\mathrm{softmax}(\mathcal{X}|t) = \sum_{\mathbf{x}_i\in\mathcal{X}} \frac{\exp(t\cdot\mathbf{x}_i)}{\sum_{\mathbf{x}_j\in\mathcal{X}} \exp(t\cdot\mathbf{x}_j)}\cdot\mathbf{x}_{i},\]

where \(t\) controls the softness of the softmax when aggregating over a set of features \(\mathcal{X}\).

Parameters
  • t (float, optional) – Initial inverse temperature for softmax aggregation. (default: 1.0)

  • learn (bool, optional) – If set to True, will learn the value t for softmax aggregation dynamically. (default: False)

  • semi_grad (bool, optional) – If set to True, will turn off gradient calculation during softmax computation. Therefore, only semi-gradients are used during backpropagation. Useful for saving memory and accelerating backward computation when t is not learnable. (default: False)

  • channels (int, optional) – Number of channels to learn from \(t\). If set to a value greater than 1, \(t\) will be learned per input feature channel. This requires compatible shapes for the input to the forward calculation. (default: 1)
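
A minimal sketch (toy sizes) of a learnable softmax aggregation applied over an index vector:

import torch
from torch_geometric.nn import aggr

x = torch.randn(10, 64)
index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

softmax_aggr = aggr.SoftmaxAggregation(t=0.1, learn=True)
out = softmax_aggr(x, index)                # Output shape: [3, 64]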

class PowerMeanAggregation(p: float = 1.0, learn: bool = False, channels: int = 1)[source]

The powermean aggregation operator based on a power term, as described in the “DeeperGCN: All You Need to Train Deeper GCNs” paper

\[\mathrm{powermean}(\mathcal{X}|p) = \left(\frac{1}{|\mathcal{X}|} \sum_{\mathbf{x}_i\in\mathcal{X}}\mathbf{x}_i^{p}\right)^{1/p},\]

where \(p\) controls the power of the powermean when aggregating over a set of features \(\mathcal{X}\).

Parameters
  • p (float, optional) – Initial power for powermean aggregation. (default: 1.0)

  • learn (bool, optional) – If set to True, will learn the value p for powermean aggregation dynamically. (default: False)

  • channels (int, optional) – Number of channels to learn from \(p\). If set to a value greater than 1, \(p\) will be learned per input feature channel. This requires compatible shapes for the input to the forward calculation. (default: 1)

class MedianAggregation(fill_value: float = 0.0)[source]

An aggregation operator that returns the feature-wise median of a set. That is, for every feature \(d\), it computes

\[{\mathrm{median}(\mathcal{X})}_d = x_{\pi_i,d}\]

where \(x_{\pi_1,d} \le x_{\pi_2,d} \le \dots \le x_{\pi_n,d}\) and \(i = \lfloor \frac{n}{2} \rfloor\).

Note

If the median lies between two values, the lowest one is returned. To compute the midpoint (or other kind of interpolation) of the two values, use QuantileAggregation instead.

Parameters

fill_value (float, optional) – The default value in the case no entry is found for a given index (default: 0.0).

class QuantileAggregation(q: Union[float, List[float]], interpolation: str = 'linear', fill_value: float = 0.0)[source]

An aggregation operator that returns the feature-wise \(q\)-th quantile of a set \(\mathcal{X}\). That is, for every feature \(d\), it computes

\[\begin{split}{\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases} x_{\pi_i,d} & i = q \cdot n, \\ f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot n < i + 1,\\ \end{cases}\end{split}\]

where \(x_{\pi_1,d} \le \dots \le x_{\pi_i,d} \le \dots \le x_{\pi_n,d}\) and \(f(a, b)\) is an interpolation function defined by interpolation.

Parameters
  • q (float or list) – The quantile value(s) \(q\). Can be a scalar or a list of scalars in the range \([0, 1]\). If more than one quantile is passed, the results are concatenated.

  • interpolation (str) –

    Interpolation method applied if the quantile point \(q\cdot n\) lies between two values \(a \le b\). Can be one of the following:

    • "lower": Returns the one with lowest value.

    • "higher": Returns the one with highest value.

    • "midpoint": Returns the average of the two values.

    • "nearest": Returns the one whose index is nearest to the quantile point.

    • "linear": Returns a linear combination of the two elements, defined as \(f(a, b) = a + (b - a)\cdot(q\cdot n - i)\).

    (default: "linear")

  • fill_value (float, optional) – The default value in the case no entry is found for a given index (default: 0.0).
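
A minimal sketch (toy sizes) computing several quantiles at once; following the note above, the per-quantile results are assumed to be concatenated along the feature dimension:

import torch
from torch_geometric.nn import aggr

x = torch.randn(10, 8)
index = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])

quantile_aggr = aggr.QuantileAggregation(q=[0.25, 0.5, 0.75], interpolation='linear')
out = quantile_aggr(x, index)               # Output shape: [3, 3 * 8]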

class LSTMAggregation(in_channels: int, out_channels: int, **kwargs)[source]

Performs LSTM-style aggregation in which the elements to aggregate are interpreted as a sequence, as described in the “Inductive Representation Learning on Large Graphs” paper.

Warning

LSTMAggregation is not a permutation-invariant operator.

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • **kwargs (optional) – Additional arguments of torch.nn.LSTM.

class Set2Set(in_channels: int, processing_steps: int, **kwargs)[source]

The Set2Set aggregation operator based on iterative content-based attention, as described in the “Order Matters: Sequence to sequence for Sets” paper

\[ \begin{align}\begin{aligned}\mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1})\\\alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)\\\mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i\\\mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t,\end{aligned}\end{align} \]

where \(\mathbf{q}^{*}_T\) defines the output of the layer with twice the dimensionality as the input.

Parameters
  • in_channels (int) – Size of each input sample.

  • processing_steps (int) – Number of iterations \(T\).

  • **kwargs (optional) – Additional arguments of torch.nn.LSTM.
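
A minimal sketch (toy sizes) of Set2Set as a global readout; the output has twice the input dimensionality:

import torch
from torch_geometric.nn import aggr

x = torch.randn(6, 32)
batch = torch.tensor([0, 0, 0, 1, 1, 1])    # two graphs

set2set = aggr.Set2Set(in_channels=32, processing_steps=3)
out = set2set(x, batch)                     # Output shape: [2, 2 * 32]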

class DegreeScalerAggregation(aggr: Union[str, List[str], Aggregation], scaler: Union[str, List[str]], deg: Tensor, train_norm: bool = False, aggr_kwargs: Optional[List[Dict[str, Any]]] = None)[source]

Combines one or more aggregators and transforms its output with one or more scalers as introduced in the “Principal Neighbourhood Aggregation for Graph Nets” paper. The scalers are normalised by the in-degree of the training set and so must be provided at time of construction. See torch_geometric.nn.conv.PNAConv for more information.

Parameters
  • aggr (string or list or Aggregation) – The aggregation scheme to use. See MessagePassing for more information.

  • scaler (str or list) – Set of scaling function identifiers, namely one or more of "identity", "amplification", "attenuation", "linear" and "inverse_linear".

  • deg (Tensor) – Histogram of in-degrees of nodes in the training set, used by scalers to normalize.

  • train_norm (bool, optional) – Whether the normalization parameters are trainable. (default: False)

  • aggr_kwargs (Dict[str, Any], optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)

class SortAggregation(k: int)[source]

The pooling operator from the “An End-to-End Deep Learning Architecture for Graph Classification” paper, where node features are sorted in descending order based on their last feature channel. The first \(k\) nodes form the output of the layer.

Parameters

k (int) – The number of nodes to hold for each graph.

class GraphMultisetTransformer(in_channels: int, hidden_channels: int, out_channels: int, Conv: Optional[Type] = None, num_nodes: int = 300, pooling_ratio: float = 0.25, pool_sequences: List[str] = ['GMPool_G', 'SelfAtt', 'GMPool_I'], num_heads: int = 4, layer_norm: bool = False)[source]

The Graph Multiset Transformer pooling operator from the “Accurate Learning of Graph Representations with Graph Multiset Pooling” paper.

The Graph Multiset Transformer clusters nodes of the entire graph via attention-based pooling operations ("GMPool_G" or "GMPool_I"). In addition, self-attention ("SelfAtt") can be used to calculate the inter-relationships among nodes.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • out_channels (int) – Size of each output sample.

  • Conv (Type, optional) – A graph neural network layer for calculating hidden representations of nodes for "GMPool_G" (one of GCNConv, GraphConv or GATConv). (default: GCNConv)

  • num_nodes (int, optional) – The number of average or maximum nodes. (default: 300)

  • pooling_ratio (float, optional) – Graph pooling ratio for each pooling. (default: 0.25)

  • pool_sequences ([str], optional) – A sequence of pooling layers consisting of Graph Multiset Transformer submodules (one of ["GMPool_I"], ["GMPool_G"], ["GMPool_G", "GMPool_I"], ["GMPool_G", "SelfAtt", "GMPool_I"] or ["GMPool_G", "SelfAtt", "SelfAtt", "GMPool_I"]). (default: ["GMPool_G", "SelfAtt", "GMPool_I"])

  • num_heads (int, optional) – Number of attention heads. (default: 4)

  • layer_norm (bool, optional) – If set to True, will make use of layer normalization. (default: False)

class AttentionalAggregation(gate_nn: Module, nn: Optional[Module] = None)[source]

The soft attention aggregation layer from the “Graph Matching Networks for Learning the Similarity of Graph Structured Objects” paper

\[\mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left( h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \cdot h_{\mathbf{\Theta}} ( \mathbf{x}_n ),\]

where \(h_{\mathrm{gate}} \colon \mathbb{R}^F \to \mathbb{R}\) and \(h_{\mathbf{\Theta}}\) denote neural networks, i.e. MLPs.

Parameters
  • gate_nn (torch.nn.Module) – A neural network \(h_{\mathrm{gate}}\) that computes attention scores by mapping node features x of shape [-1, in_channels] to shape [-1, 1] (for node-level gating) or [1, out_channels] (for feature-level gating), e.g., defined by torch.nn.Sequential.

  • nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels] before combining them with the attention scores, e.g., defined by torch.nn.Sequential. (default: None)
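
A minimal usage sketch using a single linear layer as the gating network \(h_{\mathrm{gate}}\) (the feature sizes below are illustrative assumptions):

import torch
from torch_geometric.nn import AttentionalAggregation

gate_nn = torch.nn.Sequential(torch.nn.Linear(64, 1))
aggr = AttentionalAggregation(gate_nn=gate_nn)
x = torch.randn(10, 64)
index = torch.tensor([0] * 5 + [1] * 5)
out = aggr(x, index)                          # shape [2, 64]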

class EquilibriumAggregation(in_channels: int, out_channels: int, num_layers: List[int], grad_iter: int = 5, lamb: float = 0.1)[source]

The equilibrium aggregation layer from the “Equilibrium Aggregation: Encoding Sets via Optimization” paper. The output of this layer \(\mathbf{y}\) is defined implicitly via a potential function \(F(\mathbf{x}, \mathbf{y})\), a regularization term \(R(\mathbf{y})\), and the condition

\[\mathbf{y} = \mathop{\mathrm{arg\,min}}_{\mathbf{y}} \; R(\mathbf{y}) + \sum_{i} F(\mathbf{x}_i, \mathbf{y}).\]

The given implementation uses a ResNet-like model for the potential function and a simple squared \(L_2\) norm, \(R(\mathbf{y}) = \textrm{softplus}(\lambda) \cdot {\| \mathbf{y} \|}^2_2\), as the regularizer, with learnable weight \(\lambda\).

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • num_layers (List[int]) – A list of hidden channel sizes in the potential function.

  • grad_iter (int) – The number of steps to take in the internal gradient descent. (default: 5)

  • lamb (float) – The initial regularization constant. (default: 0.1)

Normalization Layers

BatchNorm

Applies batch normalization over a batch of node features as described in the "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" paper

InstanceNorm

Applies instance normalization over each individual example in a batch of node features as described in the "Instance Normalization: The Missing Ingredient for Fast Stylization" paper

LayerNorm

Applies layer normalization over each individual example in a batch of node features as described in the "Layer Normalization" paper

GraphNorm

Applies graph normalization over individual graphs as described in the "GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training" paper

GraphSizeNorm

Applies Graph Size Normalization over each individual graph in a batch of node features as described in the "Benchmarking Graph Neural Networks" paper

PairNorm

Applies pair normalization over node features as described in the "PairNorm: Tackling Oversmoothing in GNNs" paper

MeanSubtractionNorm

Applies layer normalization by subtracting the mean from the inputs as described in the "Revisiting 'Over-smoothing' in Deep GCNs" paper

MessageNorm

Applies message normalization over the aggregated messages as described in the "DeeperGCN: All You Need to Train Deeper GCNs" paper

DiffGroupNorm

The differentiable group normalization layer from the "Towards Deeper Graph Neural Networks with Differentiable Group Normalization" paper, which normalizes node features group-wise via a learnable soft cluster assignment

class BatchNorm(in_channels: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, allow_single_element: bool = False)[source]

Applies batch normalization over a batch of node features as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension over all nodes inside the mini-batch.

Parameters
  • in_channels (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)

  • allow_single_element (bool, optional) – If set to True, batches with only a single element will work as in evaluation mode, i.e. the running mean and variance will be used. Requires track_running_stats=True. (default: False)

reset_parameters()[source]
forward(x: Tensor) Tensor[source]
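
A minimal usage sketch (the feature size and number of nodes are illustrative assumptions):

import torch
from torch_geometric.nn import BatchNorm

norm = BatchNorm(in_channels=64)
x = torch.randn(10, 64)                       # node features of a mini-batch
out = norm(x)                                 # shape [10, 64]
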
class InstanceNorm(in_channels: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = False, track_running_stats: bool = False)[source]

Applies instance normalization over each individual example in a batch of node features as described in the “Instance Normalization: The Missing Ingredient for Fast Stylization” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch.

Parameters
  • in_channels (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: False)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses instance statistics in both training and eval modes. (default: False)

forward(x: Tensor, batch: Optional[Tensor] = None) Tensor[source]
num_features: int
eps: float
momentum: float
affine: bool
track_running_stats: bool
class LayerNorm(in_channels: int, eps: float = 1e-05, affine: bool = True, mode: str = 'graph')[source]

Applies layer normalization over each individual example in a batch of node features as described in the “Layer Normalization” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard-deviation are calculated across all nodes and all node channels separately for each object in a mini-batch.

Parameters
  • in_channels (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • mode (str, optional) – The normalization mode to use for layer normalization ("graph" or "node"). If "graph" is used, each graph will be considered as an element to be normalized. If "node" is used, each node will be considered as an element to be normalized. (default: "graph")

reset_parameters()[source]
forward(x: Tensor, batch: Optional[Tensor] = None) Tensor[source]
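
A minimal usage sketch (shapes and the batch vector are illustrative assumptions); in "graph" mode, statistics are computed separately per graph as indicated by batch:

import torch
from torch_geometric.nn import LayerNorm

norm = LayerNorm(in_channels=64, mode='graph')
x = torch.randn(6, 64)
batch = torch.tensor([0, 0, 0, 1, 1, 1])
out = norm(x, batch)                          # shape [6, 64]
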
class GraphNorm(in_channels: int, eps: float = 1e-05)[source]

Applies graph normalization over individual graphs as described in the “GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]} {\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]] + \epsilon}} \odot \gamma + \beta\]

where \(\alpha\) denotes parameters that learn how much information to keep in the mean.

Parameters
  • in_channels (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

reset_parameters()[source]
forward(x: Tensor, batch: Optional[Tensor] = None) Tensor[source]
class GraphSizeNorm[source]

Applies Graph Size Normalization over each individual graph in a batch of node features as described in the “Benchmarking Graph Neural Networks” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i}{\sqrt{|\mathcal{V}|}}\]
forward(x: Tensor, batch: Optional[Tensor] = None) Tensor[source]
class PairNorm(scale: float = 1.0, scale_individually: bool = False, eps: float = 1e-05)[source]

Applies pair normalization over node features as described in the “PairNorm: Tackling Oversmoothing in GNNs” paper

\[ \begin{align}\begin{aligned}\begin{split}\mathbf{x}_i^c &= \mathbf{x}_i - \frac{1}{n} \sum_{i=1}^n \mathbf{x}_i \\\end{split}\\\mathbf{x}_i^{\prime} &= s \cdot \frac{\mathbf{x}_i^c}{\sqrt{\frac{1}{n} \sum_{i=1}^n {\| \mathbf{x}_i^c \|}^2_2}}\end{aligned}\end{align} \]
Parameters
  • scale (float, optional) – Scaling factor \(s\) of normalization. (default: 1.0)

  • scale_individually (bool, optional) – If set to True, will compute the scaling step as \(\mathbf{x}^{\prime}_i = s \cdot \frac{\mathbf{x}_i^c}{{\| \mathbf{x}_i^c \|}_2}\). (default: False)

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

forward(x: Tensor, batch: Optional[Tensor] = None) Tensor[source]
class MeanSubtractionNorm[source]

Applies layer normalization by subtracting the mean from the inputs as described in the “Revisiting ‘Over-smoothing’ in Deep GCNs” paper

\[\mathbf{x}_i = \mathbf{x}_i - \frac{1}{|\mathcal{V}|} \sum_{j \in \mathcal{V}} \mathbf{x}_j\]
reset_parameters()[source]
forward(x: Tensor, batch: Optional[Tensor] = None, dim_size: Optional[int] = None) Tensor[source]
class MessageNorm(learn_scale: bool = False)[source]

Applies message normalization over the aggregated messages as described in the “DeeperGCN: All You Need to Train Deeper GCNs” paper

\[\mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_{i} + s \cdot {\| \mathbf{x}_i \|}_2 \cdot \frac{\mathbf{m}_{i}}{{\|\mathbf{m}_i\|}_2} \right)\]
Parameters

learn_scale (bool, optional) – If set to True, will learn the scaling factor \(s\) of message normalization. (default: False)

reset_parameters()[source]
forward(x: Tensor, msg: Tensor, p: float = 2.0) Tensor[source]
class DiffGroupNorm(in_channels: int, groups: int, lamda: float = 0.01, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)[source]

The differentiable group normalization layer from the “Towards Deeper Graph Neural Networks with Differentiable Group Normalization” paper, which normalizes node features group-wise via a learnable soft cluster assignment

\[\mathbf{S} = \text{softmax} (\mathbf{X} \mathbf{W})\]

where \(\mathbf{W} \in \mathbb{R}^{F \times G}\) denotes a trainable weight matrix mapping each node into one of \(G\) clusters. Normalization is then performed group-wise via:

\[\mathbf{X}^{\prime} = \mathbf{X} + \lambda \sum_{i = 1}^G \text{BatchNorm}(\mathbf{S}[:, i] \odot \mathbf{X})\]
Parameters
  • in_channels (int) – Size of each input sample \(F\).

  • groups (int) – The number of groups \(G\).

  • lamda (float, optional) – The balancing factor \(\lambda\) between input embeddings and normalized embeddings. (default: 0.01)

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)

reset_parameters()[source]
forward(x: Tensor) Tensor[source]
static group_distance_ratio(x: Tensor, y: Tensor, eps: float = 1e-05) float[source]

Measures the ratio of inter-group distance over intra-group distance

\[R_{\text{Group}} = \frac{\frac{1}{(C-1)^2} \sum_{i!=j} \frac{1}{|\mathbf{X}_i||\mathbf{X}_j|} \sum_{\mathbf{x}_{iv} \in \mathbf{X}_i } \sum_{\mathbf{x}_{jv^{\prime}} \in \mathbf{X}_j} {\| \mathbf{x}_{iv} - \mathbf{x}_{jv^{\prime}} \|}_2 }{ \frac{1}{C} \sum_{i} \frac{1}{{|\mathbf{X}_i|}^2} \sum_{\mathbf{x}_{iv}, \mathbf{x}_{iv^{\prime}} \in \mathbf{X}_i } {\| \mathbf{x}_{iv} - \mathbf{x}_{iv^{\prime}} \|}_2 }\]

where \(\mathbf{X}_i\) denotes the set of all nodes that belong to class \(i\), and \(C\) denotes the total number of classes in y.

Pooling Layers

global_add_pool

Returns batch-wise graph-level-outputs by adding node features across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

global_mean_pool

Returns batch-wise graph-level-outputs by averaging node features across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

global_max_pool

Returns batch-wise graph-level-outputs by taking the channel-wise maximum across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

TopKPooling

\(\mathrm{top}_k\) pooling operator from the "Graph U-Nets", "Towards Sparse Hierarchical Graph Classifiers" and "Understanding Attention and Generalization in Graph Neural Networks" papers

SAGPooling

The self-attention pooling operator from the "Self-Attention Graph Pooling" and "Understanding Attention and Generalization in Graph Neural Networks" papers

EdgePooling

The edge pooling operator from the "Towards Graph Pooling by Edge Contraction" and "Edge Contraction Pooling for Graph Neural Networks" papers.

ASAPooling

The Adaptive Structure Aware Pooling operator from the "ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations" paper.

PANPooling

The path integral based pooling operator from the "Path Integral Based Convolution and Pooling for Graph Neural Networks" paper.

MemPooling

Memory-based pooling layer from the "Memory-Based Graph Networks" paper, which learns a coarsened graph representation based on soft cluster assignments

max_pool

Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster.

avg_pool

Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster.

max_pool_x

Max-Pools node features according to the clustering defined in cluster.

max_pool_neighbor_x

Max pools neighboring node features, where each feature in data.x is replaced by the maximum feature value of the central node and its neighbors.

avg_pool_x

Average pools node features according to the clustering defined in cluster.

avg_pool_neighbor_x

Average pools neighboring node features, where each feature in data.x is replaced by the average feature values from the central node and its neighbors.

graclus

A greedy clustering algorithm from the "Weighted Graph Cuts without Eigenvectors: A Multilevel Approach" paper, which iteratively picks an unmarked vertex and matches it with one of its unmarked neighbors (the one that maximizes the edge weight).

voxel_grid

Voxel grid pooling from, e.g., the "Dynamic Edge-Conditioned Filters in Convolutional Networks on Graphs" paper, which overlays a regular grid of user-defined size over a point cloud and clusters all points within the same voxel.

fps

A sampling algorithm from the "PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" paper, which iteratively samples the most distant point with regard to the remaining points.

knn

Finds for each element in y the k nearest points in x.

knn_graph

Computes graph edges to the nearest k points.

radius

Finds for each element in y all points in x within distance r.

radius_graph

Computes graph edges to all points within a given distance.

nearest

Clusters points in x together which are nearest to a given query point in y.

global_add_pool(x: Tensor, batch: Optional[Tensor], size: Optional[int] = None) Tensor[source]

Returns batch-wise graph-level-outputs by adding node features across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

\[\mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n.\]

Functional method of the SumAggregation module.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)

global_mean_pool(x: Tensor, batch: Optional[Tensor], size: Optional[int] = None) Tensor[source]

Returns batch-wise graph-level-outputs by averaging node features across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

\[\mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n.\]

Functional method of the MeanAggregation module.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)

global_max_pool(x: Tensor, batch: Optional[Tensor], size: Optional[int] = None) Tensor[source]

Returns batch-wise graph-level-outputs by taking the channel-wise maximum across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

\[\mathbf{r}_i = \mathrm{max}_{n=1}^{N_i} \, \mathbf{x}_n.\]

Functional method of the MaxAggregation module.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)
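
A minimal sketch of these functional pooling operators (shapes and the batch vector are illustrative assumptions):

import torch
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

x = torch.randn(6, 16)
batch = torch.tensor([0, 0, 0, 1, 1, 1])
out_add = global_add_pool(x, batch)           # shape [2, 16]
out_mean = global_mean_pool(x, batch)         # shape [2, 16]
out_max = global_max_pool(x, batch)           # shape [2, 16]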

class TopKPooling(in_channels: int, ratio: ~typing.Union[int, float] = 0.5, min_score: ~typing.Optional[float] = None, multiplier: float = 1.0, nonlinearity: ~typing.Callable = <built-in method tanh of type object>)[source]

\(\mathrm{top}_k\) pooling operator from the “Graph U-Nets”, “Towards Sparse Hierarchical Graph Classifiers” and “Understanding Attention and Generalization in Graph Neural Networks” papers

if min_score \(\tilde{\alpha}\) is None:

\[ \begin{align}\begin{aligned}\mathbf{y} &= \frac{\mathbf{X}\mathbf{p}}{\| \mathbf{p} \|}\\\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}\end{aligned}\end{align} \]

if min_score \(\tilde{\alpha}\) is a value in [0, 1]:

\[ \begin{align}\begin{aligned}\mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})\\\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},\end{aligned}\end{align} \]

where nodes are dropped based on a learnable projection score \(\mathbf{p}\).

Parameters
  • in_channels (int) – Size of each input sample.

  • ratio (float or int) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\), or the value of \(k\) itself, depending on whether the type of ratio is float or int. This value is ignored if min_score is not None. (default: 0.5)

  • min_score (float, optional) – Minimal node score \(\tilde{\alpha}\) which is used to compute indices of pooled nodes \(\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}\). When this value is not None, the ratio argument is ignored. (default: None)

  • multiplier (float, optional) – Coefficient by which features get multiplied after pooling. This can be useful for large graphs and when min_score is used. (default: 1)

  • nonlinearity (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.tanh)

reset_parameters()[source]
forward(x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, attn: Optional[Tensor] = None) Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor, Tensor][source]
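
A minimal usage sketch (the toy graph below is an illustrative assumption); forward() returns the pooled features and connectivity together with the indices of the kept nodes and their scores:

import torch
from torch_geometric.nn import TopKPooling

pool = TopKPooling(in_channels=64, ratio=0.5)
x = torch.randn(5, 64)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
x, edge_index, edge_attr, batch, perm, score = pool(x, edge_index)
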
class SAGPooling(in_channels: int, ratio: ~typing.Union[float, int] = 0.5, GNN: ~torch.nn.modules.module.Module = <class 'torch_geometric.nn.conv.graph_conv.GraphConv'>, min_score: ~typing.Optional[float] = None, multiplier: float = 1.0, nonlinearity: ~typing.Callable = <built-in method tanh of type object>, **kwargs)[source]

The self-attention pooling operator from the “Self-Attention Graph Pooling” and “Understanding Attention and Generalization in Graph Neural Networks” papers

if min_score \(\tilde{\alpha}\) is None:

\[ \begin{align}\begin{aligned}\mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})\\\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}\end{aligned}\end{align} \]

if min_score \(\tilde{\alpha}\) is a value in [0, 1]:

\[ \begin{align}\begin{aligned}\mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A}))\\\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}.\end{aligned}\end{align} \]

Projection scores are learned based on a graph neural network layer.

Parameters
  • in_channels (int) – Size of each input sample.

  • ratio (float or int) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\), or the value of \(k\) itself, depending on whether the type of ratio is float or int. This value is ignored if min_score is not None. (default: 0.5)

  • GNN (torch.nn.Module, optional) – A graph neural network layer for calculating projection scores (one of torch_geometric.nn.conv.GraphConv, torch_geometric.nn.conv.GCNConv, torch_geometric.nn.conv.GATConv or torch_geometric.nn.conv.SAGEConv). (default: torch_geometric.nn.conv.GraphConv)

  • min_score (float, optional) – Minimal node score \(\tilde{\alpha}\) which is used to compute indices of pooled nodes \(\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}\). When this value is not None, the ratio argument is ignored. (default: None)

  • multiplier (float, optional) – Coefficient by which features get multiplied after pooling. This can be useful for large graphs and when min_score is used. (default: 1)

  • nonlinearity (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.tanh)

  • **kwargs (optional) – Additional parameters for initializing the graph neural network layer.

reset_parameters()[source]
forward(x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, attn: Optional[Tensor] = None) Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor, Tensor][source]
class EdgePooling(in_channels: int, edge_score_method: Optional[Callable] = None, dropout: Optional[float] = 0.0, add_to_edge_score: float = 0.5)[source]

The edge pooling operator from the “Towards Graph Pooling by Edge Contraction” and “Edge Contraction Pooling for Graph Neural Networks” papers.

In short, a score is computed for each edge. Edges are contracted iteratively according to that score unless one of their nodes has already been part of a contracted edge.

To duplicate the configuration from the “Towards Graph Pooling by Edge Contraction” paper, use either EdgePooling.compute_edge_score_softmax() or EdgePooling.compute_edge_score_tanh(), and set add_to_edge_score to 0.0.

To duplicate the configuration from the “Edge Contraction Pooling for Graph Neural Networks” paper, set dropout to 0.2.

Parameters
  • in_channels (int) – Size of each input sample.

  • edge_score_method (function, optional) – The function to apply to compute the edge score from raw edge scores. By default, this is the softmax over all incoming edges for each node. This function takes in a raw_edge_score tensor of shape [num_nodes], an edge_index tensor and the number of nodes num_nodes, and produces a new tensor of the same size as raw_edge_score describing normalized edge scores. Included functions are EdgePooling.compute_edge_score_softmax(), EdgePooling.compute_edge_score_tanh(), and EdgePooling.compute_edge_score_sigmoid(). (default: EdgePooling.compute_edge_score_softmax())

  • dropout (float, optional) – The probability with which to drop edge scores during training. (default: 0.0)

  • add_to_edge_score (float, optional) – This is added to each computed edge score. Adding this greatly helps with unpool stability. (default: 0.5)

reset_parameters()[source]
static compute_edge_score_softmax(raw_edge_score: Tensor, edge_index: Tensor, num_nodes: int) Tensor[source]
static compute_edge_score_tanh(raw_edge_score: Tensor, edge_index: Optional[Tensor] = None, num_nodes: Optional[int] = None) Tensor[source]
static compute_edge_score_sigmoid(raw_edge_score: Tensor, edge_index: Optional[Tensor] = None, num_nodes: Optional[int] = None) Tensor[source]
forward(x: Tensor, edge_index: Tensor, batch: Tensor) Tuple[Tensor, Tensor, Tensor, UnpoolInfo][source]

Forward computation which computes the raw edge score, normalizes it, and merges the edges.

Parameters
  • x (Tensor) – The node features.

  • edge_index (LongTensor) – The edge indices.

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

Return types:
  • x (Tensor) - The pooled node features.

  • edge_index (LongTensor) - The coarsened edge indices.

  • batch (LongTensor) - The coarsened batch vector.

  • unpool_info (UnpoolInfo) - Information that is consumed by EdgePooling.unpool() for unpooling.

unpool(x: Tensor, unpool_info: UnpoolInfo) Tuple[Tensor, Tensor, Tensor][source]

Unpools a previous edge pooling step.

For unpooling, x should have the same shape as the output produced by this layer’s forward() function. It will then produce an unpooled x in addition to edge_index and batch.

Parameters
  • x (Tensor) – The node features.

  • unpool_info (UnpoolInfo) – Information that has been produced by EdgePooling.forward().

Return types:
  • x (Tensor) - The unpooled node features.

  • edge_index (LongTensor) - The new edge indices.

  • batch (LongTensor) - The new batch vector.
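
A minimal sketch of pooling and unpooling (the toy graph below is an illustrative assumption):

import torch
from torch_geometric.nn import EdgePooling

pool = EdgePooling(in_channels=32)
x = torch.randn(4, 32)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
batch = torch.zeros(4, dtype=torch.long)
x, edge_index, batch, unpool_info = pool(x, edge_index, batch)
x, edge_index, batch = pool.unpool(x, unpool_info)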

class ASAPooling(in_channels: int, ratio: Union[float, int] = 0.5, GNN: Optional[Callable] = None, dropout: float = 0.0, negative_slope: float = 0.2, add_self_loops: bool = False, **kwargs)[source]

The Adaptive Structure Aware Pooling operator from the “ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations” paper.

Parameters
  • in_channels (int) – Size of each input sample.

  • ratio (float or int) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\), or the value of \(k\) itself, depending on whether the type of ratio is float or int. (default: 0.5)

  • GNN (torch.nn.Module, optional) – A graph neural network layer for using intra-cluster properties. Especially helpful for graphs with higher degree of neighborhood (one of torch_geometric.nn.conv.GraphConv, torch_geometric.nn.conv.GCNConv or any GNN which supports the edge_weight parameter). (default: None)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • add_self_loops (bool, optional) – If set to True, will add self loops to the new graph connectivity. (default: False)

  • **kwargs (optional) – Additional parameters for initializing the graph neural network layer.

Returns

A tuple of tensors containing

  • x (Tensor): The pooled node embeddings.

  • edge_index (Tensor): The coarsened graph connectivity.

  • edge_weight (Tensor): The edge weights corresponding to the coarsened graph connectivity.

  • batch (Tensor): The pooled batch vector.

  • perm (Tensor): The top-\(k\) node indices of nodes which are kept after pooling.

reset_parameters()[source]
forward(x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, batch: Optional[Tensor] = None) Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor][source]
jittable() ASAPooling[source]
class PANPooling(in_channels: int, ratio: float = 0.5, min_score: ~typing.Optional[float] = None, multiplier: float = 1.0, nonlinearity: ~typing.Callable = <built-in method tanh of type object>)[source]

The path integral based pooling operator from the “Path Integral Based Convolution and Pooling for Graph Neural Networks” paper. PAN pooling performs top-\(k\) pooling where global node importance is measured based on node features and the MET matrix:

\[{\rm score} = \beta_1 \mathbf{X} \cdot \mathbf{p} + \beta_2 {\rm deg}(M)\]
Parameters
  • in_channels (int) – Size of each input sample.

  • ratio (float) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\). This value is ignored if min_score is not None. (default: 0.5)

  • min_score (float, optional) – Minimal node score \(\tilde{\alpha}\) which is used to compute indices of pooled nodes \(\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}\). When this value is not None, the ratio argument is ignored. (default: None)

  • multiplier (float, optional) – Coefficient by which features get multiplied after pooling. This can be useful for large graphs and when min_score is used. (default: 1.0)

  • nonlinearity (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.tanh)

reset_parameters()[source]
forward(x: Tensor, M: SparseTensor, batch: Optional[Tensor] = None) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor][source]
class MemPooling(in_channels: int, out_channels: int, heads: int, num_clusters: int, tau: float = 1.0)[source]

Memory-based pooling layer from the “Memory-Based Graph Networks” paper, which learns a coarsened graph representation based on soft cluster assignments

\[ \begin{align}\begin{aligned}S_{i,j}^{(h)} &= \frac{ (1+{\| \mathbf{x}_i-\mathbf{k}^{(h)}_j \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}{ \sum_{k=1}^K (1 + {\| \mathbf{x}_i-\mathbf{k}^{(h)}_k \|}^2 / \tau)^{ -\frac{1+\tau}{2}}}\\\mathbf{S} &= \textrm{softmax}(\textrm{Conv2d} (\Vert_{h=1}^H \mathbf{S}^{(h)})) \in \mathbb{R}^{N \times K}\\\mathbf{X}^{\prime} &= \mathbf{S}^{\top} \mathbf{X} \mathbf{W} \in \mathbb{R}^{K \times F^{\prime}}\end{aligned}\end{align} \]

Where \(H\) denotes the number of heads, and \(K\) denotes the number of clusters.

Parameters
  • in_channels (int) – Size of each input sample \(F\).

  • out_channels (int) – Size of each output sample \(F^{\prime}\).

  • heads (int) – The number of heads \(H\).

  • num_clusters (int) – The number of clusters \(K\) per head.

  • tau (float, optional) – The temperature \(\tau\). (default: 1.0)

reset_parameters()[source]
static kl_loss(S: Tensor) Tensor[source]

The additional KL divergence-based loss

\[ \begin{align}\begin{aligned}P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^N S_{n,j}}{\sum_{k=1}^K S_{i,k}^2 / \sum_{n=1}^N S_{n,k}}\\\mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert \mathbf{S})\end{aligned}\end{align} \]
forward(x: Tensor, batch: Optional[Tensor] = None, mask: Optional[Tensor] = None) Tuple[Tensor, Tensor][source]
Parameters
  • x (Tensor) – Dense or sparse node feature tensor \(\mathbf{X} \in \mathbb{R}^{N \times F}\) or \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), respectively.

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. This argument should only be used to separate graphs when using sparse node features. (default: None)

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\), which indicates valid nodes for each graph when using dense node features. (default: None)

max_pool(cluster: Tensor, data: Data, transform: Optional[Callable] = None) Data[source]

Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster. All nodes within the same cluster will be represented as one node. Final node features are defined by the maximum features of all nodes within the same cluster, node positions are averaged and edge indices are defined to be the union of the edge indices of all nodes within the same cluster.

Parameters
  • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • data (Data) – Graph data object.

  • transform (callable, optional) – A function/transform that takes in the coarsened and pooled torch_geometric.data.Data object and returns a transformed version. (default: None)

Return type

torch_geometric.data.Data
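
A minimal sketch combining graclus() clustering with max_pool() (the toy graph is an illustrative assumption; graclus() requires the torch-cluster package):

import torch
from torch_geometric.data import Data
from torch_geometric.nn import graclus, max_pool

edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 0, 3, 2, 5, 4]])
data = Data(x=torch.randn(6, 16), edge_index=edge_index)
cluster = graclus(data.edge_index, num_nodes=data.num_nodes)
data = max_pool(cluster, data)                # coarsened graph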

avg_pool(cluster: Tensor, data: Data, transform: Optional[Callable] = None) Data[source]

Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster. Final node features are defined by the average features of all nodes within the same cluster. See torch_geometric.nn.pool.max_pool() for more details.

Parameters
  • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • data (Data) – Graph data object.

  • transform (callable, optional) – A function/transform that takes in the coarsened and pooled torch_geometric.data.Data object and returns a transformed version. (default: None)

Return type

torch_geometric.data.Data

max_pool_x(cluster: Tensor, x: Tensor, batch: Tensor, size: Optional[int] = None) Tuple[Tensor, Optional[Tensor]][source]

Max-Pools node features according to the clustering defined in cluster.

Parameters
  • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – The maximum number of clusters in a single example. This property is useful to obtain a batch-wise dense representation, e.g. for applying FC layers, but should only be used if the size of the maximum number of clusters per example is known in advance. (default: None)

Return type

(Tensor, LongTensor) if size is None, else Tensor

max_pool_neighbor_x(data: Data, flow: Optional[str] = 'source_to_target') Data[source]

Max pools neighboring node features, where each feature in data.x is replaced by the maximum feature value of the central node and its neighbors.

avg_pool_x(cluster: Tensor, x: Tensor, batch: Tensor, size: Optional[int] = None) Tuple[Tensor, Optional[Tensor]][source]

Average pools node features according to the clustering defined in cluster. See torch_geometric.nn.pool.max_pool_x() for more details.

Parameters
  • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – The maximum number of clusters in a single example. (default: None)

Return type

(Tensor, LongTensor) if size is None, else Tensor

avg_pool_neighbor_x(data: Data, flow: Optional[str] = 'source_to_target') Data[source]

Average pools neighboring node features, where each feature in data.x is replaced by the average feature values from the central node and its neighbors.

graclus(edge_index, weight: Optional[Tensor] = None, num_nodes: Optional[int] = None)[source]

A greedy clustering algorithm from the “Weighted Graph Cuts without Eigenvectors: A Multilevel Approach” paper, which iteratively picks an unmarked vertex and matches it with one of its unmarked neighbors (the one that maximizes the edge weight). The GPU algorithm is adapted from the “A GPU Algorithm for Greedy Graph Matching” paper.

Parameters
  • edge_index (LongTensor) – The edge indices.

  • weight (Tensor, optional) – One-dimensional edge weights. (default: None)

  • num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)

Return type

LongTensor

voxel_grid(pos: Tensor, size: Union[float, List[float], Tensor], batch: Optional[Tensor] = None, start: Optional[Union[float, List[float], Tensor]] = None, end: Optional[Union[float, List[float], Tensor]] = None) Tensor[source]

Voxel grid pooling from, e.g., the “Dynamic Edge-Conditioned Filters in Convolutional Networks on Graphs” paper, which overlays a regular grid of user-defined size over a point cloud and clusters all points within the same voxel.

Parameters
  • pos (Tensor) – Node position matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}\).

  • size (float or [float] or Tensor) – Size of a voxel (in each dimension).

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots,B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • start (float or [float] or Tensor, optional) – Start coordinates of the grid (in each dimension). If set to None, will be set to the minimum coordinates found in pos. (default: None)

  • end (float or [float] or Tensor, optional) – End coordinates of the grid (in each dimension). If set to None, will be set to the maximum coordinates found in pos. (default: None)

Return type

LongTensor

fps(x: Tensor, batch: Optional[Tensor] = None, ratio: float = 0.5, random_start: bool = True) Tensor[source]

A sampling algorithm from the “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” paper, which iteratively samples the most distant point with regard to the remaining points.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • ratio (float, optional) – Sampling ratio. (default: 0.5)

  • random_start (bool, optional) – If set to False, use the first node in \(\mathbf{X}\) as starting node. (default: True)

Return type

LongTensor

import torch
from torch_geometric.nn import fps

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
index = fps(x, batch, ratio=0.5)
knn(x: Tensor, y: Tensor, k: int, batch_x: Optional[Tensor] = None, batch_y: Optional[Tensor] = None, cosine: bool = False, num_workers: int = 1) Tensor[source]

Finds for each element in y the k nearest points in x.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • y (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{M \times F}\).

  • k (int) – The number of neighbors.

  • batch_x (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • batch_y (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node to a specific example. (default: None)

  • cosine (boolean, optional) – If True, will use the cosine distance instead of euclidean distance to find nearest neighbors. (default: False)

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch_x or batch_y is not None, or the input lies on the GPU. (default: 1)

Return type

LongTensor

import torch
from torch_geometric.nn import knn

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)
knn_graph(x: Tensor, k: int, batch: Optional[Tensor] = None, loop: bool = False, flow: str = 'source_to_target', cosine: bool = False, num_workers: int = 1) Tensor[source]

Computes graph edges to the nearest k points.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • k (int) – The number of neighbors.

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • loop (bool, optional) – If True, the graph will contain self-loops. (default: False)

  • flow (string, optional) – The flow direction when using in combination with message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

  • cosine (boolean, optional) – If True, will use the cosine distance instead of euclidean distance to find nearest neighbors. (default: False)

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

Return type

LongTensor

import torch
from torch_geometric.nn import knn_graph

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
edge_index = knn_graph(x, k=2, batch=batch, loop=False)
radius(x: Tensor, y: Tensor, r: float, batch_x: Optional[Tensor] = None, batch_y: Optional[Tensor] = None, max_num_neighbors: int = 32, num_workers: int = 1) Tensor[source]

Finds for each element in y all points in x within distance r.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • y (Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).

  • r (float) – The radius.

  • batch_x (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • batch_y (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node to a specific example. (default: None)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to return for each element in y. (default: 32)

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch_x or batch_y is not None, or the input lies on the GPU. (default: 1)

Return type

LongTensor

import torch
from torch_geometric.nn import radius

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
assign_index = radius(x, y, 1.5, batch_x, batch_y)
radius_graph(x: Tensor, r: float, batch: Optional[Tensor] = None, loop: bool = False, max_num_neighbors: int = 32, flow: str = 'source_to_target', num_workers: int = 1) Tensor[source]

Computes graph edges to all points within a given distance.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • r (float) – The radius.

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • loop (bool, optional) – If True, the graph will contain self-loops. (default: False)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to return for each element. (default: 32)

  • flow (string, optional) – The flow direction when using in combination with message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

Return type

LongTensor

import torch
from torch_geometric.nn import radius_graph

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
nearest(x: Tensor, y: Tensor, batch_x: Optional[Tensor] = None, batch_y: Optional[Tensor] = None) Tensor[source]

Clusters points in x together which are nearest to a given query point in y.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • y (Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).

  • batch_x (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • batch_y (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node to a specific example. (default: None)

Return type

LongTensor

import torch
from torch_geometric.nn import nearest

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
cluster = nearest(x, y, batch_x, batch_y)

Unpooling Layers

knn_interpolate

The k-NN interpolation from the "PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" paper.

knn_interpolate(x: Tensor, pos_x: Tensor, pos_y: Tensor, batch_x: Optional[Tensor] = None, batch_y: Optional[Tensor] = None, k: int = 3, num_workers: int = 1)[source]

The k-NN interpolation from the “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” paper. For each point \(y\) with position \(\mathbf{p}(y)\), its interpolated features \(\mathbf{f}(y)\) are given by

\[\mathbf{f}(y) = \frac{\sum_{i=1}^k w(x_i) \mathbf{f}(x_i)}{\sum_{i=1}^k w(x_i)} \textrm{, where } w(x_i) = \frac{1}{d(\mathbf{p}(y), \mathbf{p}(x_i))^2}\]

and \(\{ x_1, \ldots, x_k \}\) denoting the \(k\) nearest points to \(y\).

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • pos_x (Tensor) – Node position matrix \(\in \mathbb{R}^{N \times d}\).

  • pos_y (Tensor) – Upsampled node position matrix \(\in \mathbb{R}^{M \times d}\).

  • batch_x (LongTensor, optional) – Batch vector \(\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node from \(\mathbf{X}\) to a specific example. (default: None)

  • batch_y (LongTensor, optional) – Batch vector \(\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node from \(\mathbf{Y}\) to a specific example. (default: None)

  • k (int, optional) – Number of neighbors. (default: 3)

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch_x or batch_y is not None, or the input lies on the GPU. (default: 1)
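
A minimal sketch of interpolating coarse features onto a finer set of positions (all shapes are illustrative assumptions; the underlying k-NN search requires the torch-cluster package):

import torch
from torch_geometric.nn import knn_interpolate

x = torch.randn(10, 32)                       # features at coarse positions
pos_x = torch.rand(10, 3)                     # coarse positions
pos_y = torch.rand(50, 3)                     # positions to interpolate to
out = knn_interpolate(x, pos_x, pos_y, k=3)   # shape [50, 32]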

Models

MLP

A Multi-Layer Perceptron (MLP) model.

GCN

The Graph Neural Network from the "Semi-supervised Classification with Graph Convolutional Networks" paper, using the GCNConv operator for message passing.

GraphSAGE

The Graph Neural Network from the "Inductive Representation Learning on Large Graphs" paper, using the SAGEConv operator for message passing.

GIN

The Graph Neural Network from the "How Powerful are Graph Neural Networks?" paper, using the GINConv operator for message passing.

GAT

The Graph Neural Network from "Graph Attention Networks" or "How Attentive are Graph Attention Networks?" papers, using the GATConv or GATv2Conv operator for message passing, respectively.

PNA

The Graph Neural Network from the "Principal Neighbourhood Aggregation for Graph Nets" paper, using the PNAConv operator for message passing.

EdgeCNN

The Graph Neural Network from the "Dynamic Graph CNN for Learning on Point Clouds" paper, using the EdgeConv operator for message passing.

JumpingKnowledge

The Jumping Knowledge layer aggregation module from the "Representation Learning on Graphs with Jumping Knowledge Networks" paper based on either concatenation ("cat")

Node2Vec

The Node2Vec model from the "node2vec: Scalable Feature Learning for Networks" paper where random walks of length walk_length are sampled in a given graph, and node embeddings are learned via negative sampling optimization.

DeepGraphInfomax

The Deep Graph Infomax model from the "Deep Graph Infomax" paper based on user-defined encoder and summary model \(\mathcal{E}\) and \(\mathcal{R}\) respectively, and a corruption function \(\mathcal{C}\).

InnerProductDecoder

The inner product decoder from the "Variational Graph Auto-Encoders" paper

GAE

The Graph Auto-Encoder model from the "Variational Graph Auto-Encoders" paper based on user-defined encoder and decoder models.

VGAE

The Variational Graph Auto-Encoder model from the "Variational Graph Auto-Encoders" paper.

ARGA

The Adversarially Regularized Graph Auto-Encoder model from the "Adversarially Regularized Graph Autoencoder for Graph Embedding" paper.

ARGVA

The Adversarially Regularized Variational Graph Auto-Encoder model from the "Adversarially Regularized Graph Autoencoder for Graph Embedding" paper.

SignedGCN

The signed graph convolutional network model from the "Signed Graph Convolutional Network" paper.

RENet

The Recurrent Event Network model from the "Recurrent Event Network for Reasoning over Temporal Knowledge Graphs" paper

GraphUNet

The Graph U-Net model from the "Graph U-Nets" paper which implements a U-Net like architecture with graph pooling and unpooling operations.

SchNet

The continuous-filter convolutional neural network SchNet from the "SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions" paper that uses the interactions blocks of the form

DimeNet

The directional message passing neural network (DimeNet) from the "Directional Message Passing for Molecular Graphs" paper.

DimeNetPlusPlus

The DimeNet++ from the "Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules" paper.

to_captum

Alias for to_captum_model.

to_captum_model

Converts a model to a model that can be used for Captum.ai attribution methods.

to_captum_input

Given x, edge_index and mask_type, converts it to a format to use in Captum.ai attribution methods.

captum_output_to_dicts

Converts the output of Captum.ai attribution methods, which is a tuple of attributions, to two dictionaries with node and edge attribution tensors.

MetaPath2Vec

The MetaPath2Vec model from the "metapath2vec: Scalable Representation Learning for Heterogeneous Networks" paper where random walks based on a given metapath are sampled in a heterogeneous graph, and node embeddings are learned via negative sampling optimization.

DeepGCNLayer

The skip connection operations from the "DeepGCNs: Can GCNs Go as Deep as CNNs?" and "All You Need to Train Deeper GCNs" papers.

TGNMemory

The Temporal Graph Network (TGN) memory model from the "Temporal Graph Networks for Deep Learning on Dynamic Graphs" paper.

LabelPropagation

The label propagation operator from the "Learning from Labeled and Unlabeled Data with Label Propagation" paper

CorrectAndSmooth

The correct and smooth (C&S) post-processing model from the "Combining Label Propagation And Simple Models Out-performs Graph Neural Networks" paper, where soft predictions \(\mathbf{Z}\) (obtained from a simple base predictor) are first corrected based on ground-truth training label information \(\mathbf{Y}\) and residual propagation

AttentiveFP

The Attentive FP model for molecular representation learning from the "Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism" paper, based on graph attention mechanisms.

RECT_L

The RECT model, i.e. its supervised RECT-L part, from the "Network Embedding with Completely-imbalanced Labels" paper.

LINKX

The LINKX model from the "Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods" paper

LightGCN

The LightGCN model from the "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation" paper.

MaskLabel

The label embedding and masking layer from the "Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification" paper.

GroupAddRev

The Grouped Reversible GNN module from the "Graph Neural Networks with 1000 Layers" paper.

class MLP(channel_list: Optional[Union[int, List[int]]] = None, *, in_channels: Optional[int] = None, hidden_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: Optional[int] = None, dropout: Union[float, List[float]] = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = 'batch_norm', norm_kwargs: Optional[Dict[str, Any]] = None, plain_last: bool = True, bias: Union[bool, List[bool]] = True, **kwargs)[source]

A Multi-Layer Perceptron (MLP) model. There exist two ways to instantiate an MLP:

  1. By specifying explicit channel sizes, e.g.,

    mlp = MLP([16, 32, 64, 128])
    

    creates a three-layer MLP with differently sized hidden layers.

  2. By specifying fixed hidden channel sizes over a number of layers, e.g.,

    mlp = MLP(in_channels=16, hidden_channels=32,
              out_channels=128, num_layers=3)
    

    creates a three-layer MLP with equally sized hidden layers.

Parameters
  • channel_list (List[int] or int, optional) – List of input, intermediate and output channels such that len(channel_list) - 1 denotes the number of layers of the MLP. (default: None)

  • in_channels (int, optional) – Size of each input sample. Will override channel_list. (default: None)

  • hidden_channels (int, optional) – Size of each hidden sample. Will override channel_list. (default: None)

  • out_channels (int, optional) – Size of each output sample. Will override channel_list. (default: None)

  • num_layers (int, optional) – The number of layers. Will override channel_list. (default: None)

  • dropout (float or List[float], optional) – Dropout probability of each hidden embedding. If a list is provided, sets the dropout value per layer. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: "batch_norm")

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • plain_last (bool, optional) – If set to False, will apply non-linearity, batch normalization and dropout to the last layer as well. (default: True)

  • bias (bool or List[bool], optional) – If set to False, the module will not learn additive biases. If a list is provided, sets the bias per layer. (default: True)

  • **kwargs (optional) – Additional deprecated arguments of the MLP layer.

property in_channels: int

Size of each input sample.

property out_channels: int

Size of each output sample.

property num_layers: int

The number of layers.

reset_parameters()[source]
forward(x: Tensor, return_emb: Optional[Tensor] = None) Tensor[source]
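
A minimal usage sketch (all sizes and tensors below are illustrative, not part of the API):

import torch
from torch_geometric.nn import MLP

mlp = MLP(in_channels=16, hidden_channels=32, out_channels=8, num_layers=3)

x = torch.randn(100, 16)  # 100 samples with 16 features each
out = mlp(x)              # shape: [100, 8]
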
class GCN(in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, jk: Optional[str] = None, **kwargs)[source]

The Graph Neural Network from the “Semi-supervised Classification with Graph Convolutional Networks” paper, using the GCNConv operator for message passing.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • hidden_channels (int) – Size of each hidden sample.

  • num_layers (int) – Number of message passing layers.

  • out_channels (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_channels. (default: None)

  • dropout (float, optional) – Dropout probability. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: None)

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • jk (str, optional) – The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (None, "last", "cat", "max", "lstm"). (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.GCNConv.

supports_edge_weight = True
supports_edge_attr = False
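
A usage sketch with hypothetical sizes and a random toy graph (for illustration only):

import torch
from torch_geometric.nn import GCN

model = GCN(in_channels=16, hidden_channels=64, num_layers=2, out_channels=7)

x = torch.randn(100, 16)                      # node features
edge_index = torch.randint(0, 100, (2, 500))  # random connectivity, for illustration only
out = model(x, edge_index)                    # shape: [100, 7]

The GraphSAGE, GIN, GAT, PNA and EdgeCNN models below expose the same constructor and forward interface.
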
class GraphSAGE(in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, jk: Optional[str] = None, **kwargs)[source]

The Graph Neural Network from the “Inductive Representation Learning on Large Graphs” paper, using the SAGEConv operator for message passing.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • hidden_channels (int) – Size of each hidden sample.

  • num_layers (int) – Number of message passing layers.

  • out_channels (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_channels. (default: None)

  • dropout (float, optional) – Dropout probability. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: None)

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • jk (str, optional) – The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (None, "last", "cat", "max", "lstm"). (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.SAGEConv.

supports_edge_weight = False
supports_edge_attr = False
class GIN(in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, jk: Optional[str] = None, **kwargs)[source]

The Graph Neural Network from the “How Powerful are Graph Neural Networks?” paper, using the GINConv operator for message passing.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • num_layers (int) – Number of message passing layers.

  • out_channels (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_channels. (default: None)

  • dropout (float, optional) – Dropout probability. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: None)

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • jk (str, optional) – The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (None, "last", "cat", "max", "lstm"). (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.GINConv.

supports_edge_weight = False
supports_edge_attr = False
class GAT(in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, jk: Optional[str] = None, **kwargs)[source]

The Graph Neural Network from “Graph Attention Networks” or “How Attentive are Graph Attention Networks?” papers, using the GATConv or GATv2Conv operator for message passing, respectively.

Parameters
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • hidden_channels (int) – Size of each hidden sample.

  • num_layers (int) – Number of message passing layers.

  • out_channels (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_channels. (default: None)

  • v2 (bool, optional) – If set to True, will make use of GATv2Conv rather than GATConv. (default: False)

  • dropout (float, optional) – Dropout probability. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: None)

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • jk (str, optional) – The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (None, "last", "cat", "max", "lstm"). (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.GATConv or torch_geometric.nn.conv.GATv2Conv.

supports_edge_weight = False
supports_edge_attr = True
class PNA(in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, jk: Optional[str] = None, **kwargs)[source]

The Graph Neural Network from the “Principal Neighbourhood Aggregation for Graph Nets” paper, using the PNAConv operator for message passing.

Parameters
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • hidden_channels (int) – Size of each hidden sample.

  • num_layers (int) – Number of message passing layers.

  • out_channels (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_channels. (default: None)

  • dropout (float, optional) – Dropout probability. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: None)

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • jk (str, optional) – The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (None, "last", "cat", "max", "lstm"). (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.PNAConv.

supports_edge_weight = False
supports_edge_attr = True
class EdgeCNN(in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, jk: Optional[str] = None, **kwargs)[source]

The Graph Neural Network from the “Dynamic Graph CNN for Learning on Point Clouds” paper, using the EdgeConv operator for message passing.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • num_layers (int) – Number of message passing layers.

  • out_channels (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_channels. (default: None)

  • dropout (float, optional) – Dropout probability. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: None)

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • jk (str, optional) – The Jumping Knowledge mode. If specified, the model will additionally apply a final linear transformation to transform node embeddings to the expected output feature dimensionality. (None, "last", "cat", "max", "lstm"). (default: None)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.EdgeConv.

supports_edge_weight = False
supports_edge_attr = False
class JumpingKnowledge(mode: str, channels: Optional[int] = None, num_layers: Optional[int] = None)[source]

The Jumping Knowledge layer aggregation module from the “Representation Learning on Graphs with Jumping Knowledge Networks” paper based on either concatenation ("cat")

\[\mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)}\]

max pooling ("max")

\[\max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right)\]

or weighted summation

\[\sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)}\]

with attention scores \(\alpha_v^{(t)}\) obtained from a bi-directional LSTM ("lstm").

Parameters
  • mode (string) – The aggregation scheme to use ("cat", "max" or "lstm").

  • channels (int, optional) – The number of channels per representation. Needs to be only set for LSTM-style aggregation. (default: None)

  • num_layers (int, optional) – The number of layers to aggregate. Needs to be only set for LSTM-style aggregation. (default: None)

reset_parameters()[source]
forward(xs: List[Tensor]) Tensor[source]

Aggregates representations across different layers.

Parameters

xs (List[Tensor]) – List containing layer-wise representations.
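
A minimal sketch of LSTM-style aggregation over three layers of hypothetical node representations:

import torch
from torch_geometric.nn import JumpingKnowledge

jk = JumpingKnowledge(mode='lstm', channels=64, num_layers=3)

xs = [torch.randn(100, 64) for _ in range(3)]  # layer-wise node representations
out = jk(xs)                                   # shape: [100, 64] ("cat" would yield [100, 192])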

class Node2Vec(edge_index: Tensor, embedding_dim: int, walk_length: int, context_size: int, walks_per_node: int = 1, p: float = 1.0, q: float = 1.0, num_negative_samples: int = 1, num_nodes: Optional[int] = None, sparse: bool = False)[source]

The Node2Vec model from the “node2vec: Scalable Feature Learning for Networks” paper where random walks of length walk_length are sampled in a given graph, and node embeddings are learned via negative sampling optimization.

Note

For an example of using Node2Vec, see examples/node2vec.py.

Parameters
  • edge_index (LongTensor) – The edge indices.

  • embedding_dim (int) – The size of each embedding vector.

  • walk_length (int) – The walk length.

  • context_size (int) – The actual context size which is considered for positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes.

  • walks_per_node (int, optional) – The number of walks to sample for each node. (default: 1)

  • p (float, optional) – Likelihood of immediately revisiting a node in the walk. (default: 1)

  • q (float, optional) – Control parameter to interpolate between breadth-first strategy and depth-first strategy. (default: 1)

  • num_negative_samples (int, optional) – The number of negative samples to use for each positive sample. (default: 1)

  • num_nodes (int, optional) – The number of nodes. (default: None)

  • sparse (bool, optional) – If set to True, gradients w.r.t. the weight matrix will be sparse. (default: False)

reset_parameters()[source]
forward(batch: Optional[Tensor] = None) Tensor[source]

Returns the embeddings for the nodes in batch.

loader(**kwargs) DataLoader[source]
pos_sample(batch: Tensor) Tensor[source]
neg_sample(batch: Tensor) Tensor[source]
sample(batch: Tensor) Tuple[Tensor, Tensor][source]
loss(pos_rw: Tensor, neg_rw: Tensor) Tensor[source]

Computes the loss given positive and negative random walks.

test(train_z: Tensor, train_y: Tensor, test_z: Tensor, test_y: Tensor, solver: str = 'lbfgs', multi_class: str = 'auto', *args, **kwargs) float[source]

Evaluates latent space quality via a logistic regression downstream task.
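
A training-loop sketch in the spirit of examples/node2vec.py, using a random toy graph (all sizes are illustrative):

import torch
from torch_geometric.nn import Node2Vec

edge_index = torch.randint(0, 100, (2, 500))  # toy graph, for illustration only

model = Node2Vec(edge_index, embedding_dim=128, walk_length=20, context_size=10,
                 walks_per_node=10, num_negative_samples=1, sparse=True)
loader = model.loader(batch_size=128, shuffle=True)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

model.train()
for pos_rw, neg_rw in loader:
    optimizer.zero_grad()
    loss = model.loss(pos_rw, neg_rw)
    loss.backward()
    optimizer.step()

z = model()  # embeddings for all nodes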

class DeepGraphInfomax(hidden_channels: int, encoder: Module, summary: Callable, corruption: Callable)[source]

The Deep Graph Infomax model from the “Deep Graph Infomax” paper based on user-defined encoder and summary models \(\mathcal{E}\) and \(\mathcal{R}\), respectively, and a corruption function \(\mathcal{C}\).

Parameters
  • hidden_channels (int) – The latent space dimensionality.

  • encoder (Module) – The encoder module \(\mathcal{E}\).

  • summary (callable) – The readout function \(\mathcal{R}\).

  • corruption (callable) – The corruption function \(\mathcal{C}\).

reset_parameters()[source]
forward(*args, **kwargs) Tuple[Tensor, Tensor, Tensor][source]

Returns the latent space for the input arguments, their corruptions and their summary representation.

discriminate(z: Tensor, summary: Tensor, sigmoid: bool = True) Tensor[source]

Given the patch-summary pair z and summary, computes the probability scores assigned to this patch-summary pair.

Parameters
  • z (Tensor) – The latent space.

  • summary (Tensor) – The summary vector.

  • sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)

loss(pos_z: Tensor, neg_z: Tensor, summary: Tensor) Tensor[source]

Computes the mutual information maximization objective.

test(train_z: Tensor, train_y: Tensor, test_z: Tensor, test_y: Tensor, solver: str = 'lbfgs', multi_class: str = 'auto', *args, **kwargs) float[source]

Evaluates latent space quality via a logistic regression downstream task.
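
A minimal sketch of wiring up an encoder, readout and corruption function; the GCNConv-based encoder and the feature-shuffling corruption are illustrative choices, and all tensor sizes are hypothetical:

import torch
from torch_geometric.nn import DeepGraphInfomax, GCNConv

class Encoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, hidden_channels)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index).relu()

def corruption(x, edge_index):
    return x[torch.randperm(x.size(0))], edge_index  # shuffle node features

model = DeepGraphInfomax(
    hidden_channels=64,
    encoder=Encoder(16, 64),
    summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
    corruption=corruption,
)

x = torch.randn(100, 16)
edge_index = torch.randint(0, 100, (2, 500))
pos_z, neg_z, summary = model(x, edge_index)
loss = model.loss(pos_z, neg_z, summary)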

class InnerProductDecoder[source]

The inner product decoder from the “Variational Graph Auto-Encoders” paper

\[\sigma(\mathbf{Z}\mathbf{Z}^{\top})\]

where \(\mathbf{Z} \in \mathbb{R}^{N \times d}\) denotes the latent space produced by the encoder.

forward(z: Tensor, edge_index: Tensor, sigmoid: bool = True) Tensor[source]

Decodes the latent variables z into edge probabilities for the given node-pairs edge_index.

Parameters
  • z (Tensor) – The latent space \(\mathbf{Z}\).

  • edge_index (LongTensor) – The edge indices.

  • sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)

forward_all(z: Tensor, sigmoid: bool = True) Tensor[source]

Decodes the latent variables z into a probabilistic dense adjacency matrix.

Parameters
  • z (Tensor) – The latent space \(\mathbf{Z}\).

  • sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)

class GAE(encoder: Module, decoder: Optional[Module] = None)[source]

The Graph Auto-Encoder model from the “Variational Graph Auto-Encoders” paper based on user-defined encoder and decoder models.

Parameters
  • encoder (Module) – The encoder module.

  • decoder (Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)

reset_parameters()[source]
encode(*args, **kwargs) Tensor[source]

Runs the encoder and computes node-wise latent variables.

decode(*args, **kwargs) Tensor[source]

Runs the decoder and computes edge probabilities.

recon_loss(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Optional[Tensor] = None) Tensor[source]

Given latent variables z, computes the binary cross entropy loss for positive edges pos_edge_index and negative sampled edges.

Parameters
  • z (Tensor) – The latent space \(\mathbf{Z}\).

  • pos_edge_index (LongTensor) – The positive edges to train against.

  • neg_edge_index (LongTensor, optional) – The negative edges to train against. If not given, uses negative sampling to calculate negative edges. (default: None)

test(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) Tuple[Tensor, Tensor][source]

Given latent variables z, positive edges pos_edge_index and negative edges neg_edge_index, computes area under the ROC curve (AUC) and average precision (AP) scores.

Parameters
  • z (Tensor) – The latent space \(\mathbf{Z}\).

  • pos_edge_index (LongTensor) – The positive edges to evaluate against.

  • neg_edge_index (LongTensor) – The negative edges to evaluate against.
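
A minimal sketch of encoding and computing the reconstruction loss; the GCNConv-based encoder and the toy tensors are illustrative choices, not part of the API:

import torch
from torch_geometric.nn import GAE, GCNConv

class Encoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)

model = GAE(Encoder(16, 32))  # defaults to the InnerProductDecoder

x = torch.randn(100, 16)
edge_index = torch.randint(0, 100, (2, 500))  # toy graph, for illustration only

z = model.encode(x, edge_index)          # node-wise latent variables
loss = model.recon_loss(z, edge_index)   # negative edges are sampled automatically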

class VGAE(encoder: Module, decoder: Optional[Module] = None)[source]

The Variational Graph Auto-Encoder model from the “Variational Graph Auto-Encoders” paper.

Parameters
  • encoder (Module) – The encoder module to compute \(\mu\) and \(\log\sigma^2\).

  • decoder (Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)

reparametrize(mu: Tensor, logstd: Tensor) Tensor[source]
encode(*args, **kwargs) Tensor[source]
kl_loss(mu: Optional[Tensor] = None, logstd: Optional[Tensor] = None) Tensor[source]

Computes the KL loss, either for the passed arguments mu and logstd, or based on latent variables from last encoding.

Parameters
  • mu (Tensor, optional) – The latent space for \(\mu\). If set to None, uses the last computation of \(\mu\). (default: None)

  • logstd (Tensor, optional) – The latent space for \(\log\sigma\). If set to None, uses the last computation of \(\log\sigma^2\). (default: None)

class ARGA(encoder: Module, discriminator: Module, decoder: Optional[Module] = None)[source]

The Adversarially Regularized Graph Auto-Encoder model from the “Adversarially Regularized Graph Autoencoder for Graph Embedding” paper.

Parameters
  • encoder (Module) – The encoder module.

  • discriminator (Module) – The discriminator module.

  • decoder (Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)

reset_parameters()[source]
reg_loss(z: Tensor) Tensor[source]

Computes the regularization loss of the encoder.

Parameters

z (Tensor) – The latent space \(\mathbf{Z}\).

discriminator_loss(z: Tensor) Tensor[source]

Computes the loss of the discriminator.

Parameters

z (Tensor) – The latent space \(\mathbf{Z}\).

class ARGVA(encoder: Module, discriminator: Module, decoder: Optional[Module] = None)[source]

The Adversarially Regularized Variational Graph Auto-Encoder model from the “Adversarially Regularized Graph Autoencoder for Graph Embedding” paper.

Parameters
  • encoder (Module) – The encoder module to compute \(\mu\) and \(\log\sigma^2\).

  • discriminator (Module) – The discriminator module.

  • decoder (Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)

reparametrize(mu: Tensor, logstd: Tensor) Tensor[source]
encode(*args, **kwargs) Tensor[source]
kl_loss(mu: Optional[Tensor] = None, logstd: Optional[Tensor] = None) Tensor[source]
class SignedGCN(in_channels: int, hidden_channels: int, num_layers: int, lamb: float = 5, bias: bool = True)[source]

The signed graph convolutional network model from the “Signed Graph Convolutional Network” paper. Internally, this module uses the torch_geometric.nn.conv.SignedConv operator.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • num_layers (int) – Number of layers.

  • lamb (float, optional) – Balances the contributions of the overall objective. (default: 5)

  • bias (bool, optional) – If set to False, all layers will not learn an additive bias. (default: True)

reset_parameters()[source]
split_edges(edge_index: Tensor, test_ratio: float = 0.2) Tuple[Tensor, Tensor][source]

Splits the edges edge_index into train and test edges.

Parameters
  • edge_index (LongTensor) – The edge indices.

  • test_ratio (float, optional) – The ratio of test edges. (default: 0.2)

create_spectral_features(pos_edge_index: Tensor, neg_edge_index: Tensor, num_nodes: Optional[int] = None) Tensor[source]

Creates in_channels spectral node features based on positive and negative edges.

Parameters
  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

  • num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of pos_edge_index and neg_edge_index. (default: None)

forward(x: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) Tensor[source]

Computes node embeddings z based on positive edges pos_edge_index and negative edges neg_edge_index.

Parameters
  • x (Tensor) – The input node features.

  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

discriminate(z: Tensor, edge_index: Tensor) Tensor[source]

Given node embeddings z, classifies the link relation between node pairs edge_index to be either positive, negative or non-existent.

Parameters
  • z (Tensor) – The node embeddings.

  • edge_index (LongTensor) – The edge indices.

nll_loss(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) Tensor[source]

Computes the discriminator loss based on node embeddings z, positive edges pos_edge_index and negative edges neg_edge_index.

Parameters
  • z (Tensor) – The node embeddings.

  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

pos_embedding_loss(z: Tensor, pos_edge_index: Tensor) Tensor[source]

Computes the triplet loss between positive node pairs and sampled non-node pairs.

Parameters
  • z (Tensor) – The node embeddings.

  • pos_edge_index (LongTensor) – The positive edge indices.

neg_embedding_loss(z: Tensor, neg_edge_index: Tensor) Tensor[source]

Computes the triplet loss between negative node pairs and sampled non-node pairs.

Parameters
  • z (Tensor) – The node embeddings.

  • neg_edge_index (LongTensor) – The negative edge indices.

loss(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) Tensor[source]

Computes the overall objective.

Parameters
  • z (Tensor) – The node embeddings.

  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

test(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) Tuple[float, float][source]

Evaluates node embeddings z on positive and negative test edges by computing AUC and F1 scores.

Parameters
  • z (Tensor) – The node embeddings.

  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

class RENet(num_nodes: int, num_rels: int, hidden_channels: int, seq_len: int, num_layers: int = 1, dropout: float = 0.0, bias: bool = True)[source]

The Recurrent Event Network model from the “Recurrent Event Network for Reasoning over Temporal Knowledge Graphs” paper

\[f_{\mathbf{\Theta}}(\mathbf{e}_s, \mathbf{e}_r, \mathbf{h}^{(t-1)}(s, r))\]

based on a RNN encoder

\[\mathbf{h}^{(t)}(s, r) = \textrm{RNN}(\mathbf{e}_s, \mathbf{e}_r, g(\mathcal{O}^{(t)}_r(s)), \mathbf{h}^{(t-1)}(s, r))\]

where \(\mathbf{e}_s\) and \(\mathbf{e}_r\) denote entity and relation embeddings, and \(\mathcal{O}^{(t)}_r(s)\) represents the set of objects interacted with subject \(s\) under relation \(r\) at timestamp \(t\). This model implements \(g\) as the Mean Aggregator and \(f_{\mathbf{\Theta}}\) as a linear projection.

Parameters
  • num_nodes (int) – The number of nodes in the knowledge graph.

  • num_rels (int) – The number of relations in the knowledge graph.

  • hidden_channels (int) – Hidden size of node and relation embeddings.

  • seq_len (int) – The sequence length of past events.

  • num_layers (int, optional) – The number of recurrent layers. (default: 1)

  • dropout (float) – If non-zero, introduces a dropout layer before the final prediction. (default: 0.)

  • bias (bool, optional) – If set to False, all layers will not learn an additive bias. (default: True)

reset_parameters()[source]
static pre_transform(seq_len: int) Callable[source]

Precomputes history objects

\[\{ \mathcal{O}^{(t-k-1)}_r(s), \ldots, \mathcal{O}^{(t-1)}_r(s) \}\]

of a torch_geometric.datasets.icews.EventDataset with \(k\) denoting the sequence length seq_len.

forward(data: Data) Tuple[Tensor, Tensor][source]

Given a data batch, computes the forward pass.

Parameters

data (torch_geometric.data.Data) – The input data, holding subject sub, relation rel and object obj information with shape [batch_size]. In addition, data needs to hold history information for subjects, given by a vector of node indices h_sub and their relative timestamps h_sub_t and batch assignments h_sub_batch. The same information must be given for objects (h_obj, h_obj_t, h_obj_batch).

test(logits: Tensor, y: Tensor) Tensor[source]

Given ground-truth y, computes Mean Reciprocal Rank (MRR) and Hits at 1/3/10.

class GraphUNet(in_channels: int, hidden_channels: int, out_channels: int, depth: int, pool_ratios: Union[float, List[float]] = 0.5, sum_res: bool = True, act: Callable = <function relu>)[source]

The Graph U-Net model from the “Graph U-Nets” paper which implements a U-Net like architecture with graph pooling and unpooling operations.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • out_channels (int) – Size of each output sample.

  • depth (int) – The depth of the U-Net architecture.

  • pool_ratios (float or [float], optional) – Graph pooling ratio for each depth. (default: 0.5)

  • sum_res (bool, optional) – If set to False, will use concatenation for integration of skip connections instead of summation. (default: True)

  • act (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.nn.functional.relu)

reset_parameters()[source]
forward(x: Tensor, edge_index: Tensor, batch: Optional[Tensor] = None) Tensor[source]
augment_adj(edge_index: Tensor, edge_weight: Tensor, num_nodes: int) Tuple[Tensor, Tensor][source]
class SchNet(hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, interaction_graph: Optional[Callable] = None, max_num_neighbors: int = 32, readout: str = 'add', dipole: bool = False, mean: Optional[float] = None, std: Optional[float] = None, atomref: Optional[Tensor] = None)[source]

The continuous-filter convolutional neural network SchNet from the “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions” paper that uses the interactions blocks of the form

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),\]

here \(h_{\mathbf{\Theta}}\) denotes an MLP and \(\mathbf{e}_{j,i}\) denotes the interatomic distances between atoms.

Note

For an example of using a pretrained SchNet variant, see examples/qm9_pretrained_schnet.py.

Parameters
  • hidden_channels (int, optional) – Hidden embedding size. (default: 128)

  • num_filters (int, optional) – The number of filters to use. (default: 128)

  • num_interactions (int, optional) – The number of interaction blocks. (default: 6)

  • num_gaussians (int, optional) – The number of gaussians \(\mu\). (default: 50)

  • interaction_graph (Callable, optional) – The function used to compute the pairwise interaction graph and interatomic distances. If set to None, will construct a graph based on cutoff and max_num_neighbors properties. If provided, this method takes in pos and batch tensors and should return (edge_index, edge_weight) tensors. (default: None)

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 10.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • readout (string, optional) – Whether to apply "add" or "mean" global aggregation. (default: "add")

  • dipole (bool, optional) – If set to True, will use the magnitude of the dipole moment to make the final prediction, e.g., for target 0 of torch_geometric.datasets.QM9. (default: False)

  • mean (float, optional) – The mean of the property to predict. (default: None)

  • std (float, optional) – The standard deviation of the property to predict. (default: None)

  • atomref (torch.Tensor, optional) – The reference of single-atom properties. Expects a vector of shape (max_atomic_number, ). (default: None)

url = 'http://www.quantum-machine.org/datasets/trained_schnet_models.zip'
reset_parameters()[source]
static from_qm9_pretrained(root: str, dataset: Dataset, target: int) Tuple[SchNet, Dataset, Dataset, Dataset][source]
forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) Tensor[source]
Parameters
  • z (LongTensor) – Atomic number of each atom with shape [num_atoms].

  • pos (Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (LongTensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
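
A minimal forward-pass sketch on a toy molecule; atomic numbers and coordinates are random and only serve as illustration:

import torch
from torch_geometric.nn import SchNet

model = SchNet(hidden_channels=128, num_filters=128, num_interactions=6)

z = torch.randint(1, 10, (20,))  # atomic numbers of 20 atoms
pos = torch.randn(20, 3)         # 3D coordinates
out = model(z, pos)              # per-molecule prediction, shape [1, 1]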

class DimeNet(hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish')[source]

The directional message passing neural network (DimeNet) from the “Directional Message Passing for Molecular Graphs” paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion.

Note

For an example of using a pretrained DimeNet variant, see examples/qm9_pretrained_dimenet.py.

Parameters
  • hidden_channels (int) – Hidden embedding size.

  • out_channels (int) – Size of each output sample.

  • num_blocks (int) – Number of building blocks.

  • num_bilinear (int) – Size of the bilinear layer tensor.

  • num_spherical (int) – Number of spherical harmonics.

  • num_radial (int) – Number of radial basis functions.

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)

  • act (str or Callable, optional) – The activation function. (default: "swish")

url = 'https://github.com/klicperajo/dimenet/raw/master/pretrained/dimenet'
reset_parameters()[source]
classmethod from_qm9_pretrained(root: str, dataset: Dataset, target: int) Tuple[DimeNet, Dataset, Dataset, Dataset][source]
triplets(edge_index: Tensor, num_nodes: int) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor][source]
forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) Tensor[source]
class DimeNetPlusPlus(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish')[source]

The DimeNet++ from the “Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules” paper.

DimeNetPlusPlus is an upgrade to the DimeNet model which is 8x faster and 10% more accurate than the original DimeNet.

Parameters
  • hidden_channels (int) – Hidden embedding size.

  • out_channels (int) – Size of each output sample.

  • num_blocks (int) – Number of building blocks.

  • int_emb_size (int) – Size of embedding in the interaction block.

  • basis_emb_size (int) – Size of basis embedding in the interaction block.

  • out_emb_channels (int) – Size of embedding in the output block.

  • num_spherical (int) – Number of spherical harmonics.

  • num_radial (int) – Number of radial basis functions.

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)

  • act (str or Callable, optional) – The activation function. (default: "swish")

url = 'https://raw.githubusercontent.com/gasteigerjo/dimenet/master/pretrained/dimenet_pp'
classmethod from_qm9_pretrained(root: str, dataset: Dataset, target: int) Tuple[DimeNetPlusPlus, Dataset, Dataset, Dataset][source]
to_captum(model: Module, mask_type: str = 'edge', output_idx: Optional[int] = None, metadata: Optional[Tuple[List[str], List[Tuple[str, str, str]]]] = None) Union[CaptumModel, CaptumHeteroModel][source]

Alias for to_captum_model.

Warning

to_captum is deprecated and will be removed in a future release. Use torch_geometric.nn.to_captum_model instead.

to_captum_model(model: Module, mask_type: str = 'edge', output_idx: Optional[int] = None, metadata: Optional[Tuple[List[str], List[Tuple[str, str, str]]]] = None) Union[CaptumModel, CaptumHeteroModel][source]

Converts a model to a model that can be used for Captum.ai attribution methods.

Sample code for homogeneous graphs:

from captum.attr import IntegratedGradients

from torch_geometric.data import Data
from torch_geometric.nn import GCN
from torch_geometric.nn import to_captum_model, to_captum_input

data = Data(x=..., edge_index=...)
model = GCN(...)
...  # Train the model.

# Explain predictions for node `10`:
mask_type="edge"
output_idx = 10
captum_model = to_captum_model(model, mask_type, output_idx)
inputs, additional_forward_args = to_captum_input(data.x,
                                    data.edge_index, mask_type)

ig = IntegratedGradients(captum_model)
ig_attr = ig.attribute(inputs=inputs,
                       target=int(y[output_idx]),
                       additional_forward_args=additional_forward_args,
                       internal_batch_size=1)

Sample code for heterogeneous graphs:

from captum.attr import IntegratedGradients

from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv
from torch_geometric.nn import (captum_output_to_dicts,
                                to_captum_model, to_captum_input)

data = HeteroData(...)
model = HeteroConv(...)
...  # Train the model.

# Explain predictions for node `10`:
mask_type="edge"
metadata = data.metadata()
output_idx = 10
captum_model = to_captum_model(model, mask_type, output_idx, metadata)
inputs, additional_forward_args = to_captum_input(data.x_dict,
                                    data.edge_index_dict, mask_type)

ig = IntegratedGradients(captum_model)
ig_attr = ig.attribute(inputs=inputs,
                       target=int(y[output_idx]),
                       additional_forward_args=additional_forward_args,
                       internal_batch_size=1)
edge_attr_dict = captum_output_to_dicts(ig_attr, mask_type, metadata)

Note

For an example of using a Captum attribution method within PyG, see examples/captum_explainability.py.

Parameters
  • model (torch.nn.Module) – The model to be explained.

  • mask_type (str, optional) – Denotes the type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge". (default: "edge")

  • output_idx (int, optional) – Index of the output element (node or link index) to be explained. With output_idx set, the forward function will return the output of the model for the element at the index specified. (default: None)

  • metadata (Metadata, optional) – The metadata of the heterogeneous graph. Only required if explaining over a HeteroData object. (default: None)

to_captum_input(x: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], mask_type: str, *args) Tuple[Tuple[Tensor], Tuple[Tensor]][source]

Given x, edge_index and mask_type, converts them to a format suitable for Captum.ai attribution methods. Returns inputs and additional_forward_args required for Captum’s attribute functions. See torch_geometric.nn.to_captum_model for example usage.

Parameters
  • x (Tensor or Dict[NodeType, Tensor]) – The node features. For heterogeneous graphs this is a dictionary holding node features for each node type.

  • edge_index (Tensor or Dict[EdgeType, Tensor]) – The edge indices. For heterogeneous graphs this is a dictionary holding edge indices for each edge type.

  • mask_type (str) – Denotes the type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge".

  • *args – Additional forward arguments of the model being explained which will be added to additional_forward_args. For Data this is arguments other than x and edge_index. For HeteroData this is arguments other than x_dict and edge_index_dict.

captum_output_to_dicts(captum_attrs: Tuple[Tensor], mask_type: str, metadata: Tuple[List[str], List[Tuple[str, str, str]]]) Tuple[Optional[Dict[str, Tensor]], Optional[Dict[Tuple[str, str, str], Tensor]]][source]

Converts the output of Captum.ai attribution methods, which is a tuple of attributions, into two dictionaries holding node and edge attribution tensors. This function is used when explaining HeteroData objects. See torch_geometric.nn.to_captum_model for example usage.

Parameters
  • captum_attrs (tuple[tensor]) – The output of attribution methods.

  • mask_type (str) –

    Denotes the type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge":

    1. "edge": captum_attrs contains only edge attributions. The returned tuple has no node attributions and a edge attribution dictionary with key EdgeType and value edge mask tensor of shape [num_edges].

    2. "node": captum_attrs contains only node attributions. The returned tuple has node attribution dictonary with key NodeType and value node mask tensor of shape [num_nodes, num_features] and no edge attribution.

    3. "node_and_edge": captum_attrs contains only node attributions. The returned tuple contains node attribution dictionary followed by edge attribution dictionary.

  • metadata (Metadata) – The metadata of the heterogeneous graph.

class MetaPath2Vec(edge_index_dict: Dict[Tuple[str, str, str], Tensor], embedding_dim: int, metapath: List[Tuple[str, str, str]], walk_length: int, context_size: int, walks_per_node: int = 1, num_negative_samples: int = 1, num_nodes_dict: Optional[Dict[str, int]] = None, sparse: bool = False)[source]

The MetaPath2Vec model from the “metapath2vec: Scalable Representation Learning for Heterogeneous Networks” paper where random walks based on a given metapath are sampled in a heterogeneous graph, and node embeddings are learned via negative sampling optimization.

Note

For an example of using MetaPath2Vec, see examples/hetero/metapath2vec.py.

Parameters
  • edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Dictionary holding edge indices for each (src_node_type, rel_type, dst_node_type) present in the heterogeneous graph.

  • embedding_dim (int) – The size of each embedding vector.

  • metapath (List[Tuple[str, str, str]]) – The metapath described as a list of (src_node_type, rel_type, dst_node_type) tuples.

  • walk_length (int) – The walk length.

  • context_size (int) – The actual context size which is considered for positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes.

  • walks_per_node (int, optional) – The number of walks to sample for each node. (default: 1)

  • num_negative_samples (int, optional) – The number of negative samples to use for each positive sample. (default: 1)

  • num_nodes_dict (Dict[str, int], optional) – Dictionary holding the number of nodes for each node type. (default: None)

  • sparse (bool, optional) – If set to True, gradients w.r.t. the weight matrix will be sparse. (default: False)

reset_parameters()[source]
forward(node_type: str, batch: Optional[Tensor] = None) Tensor[source]

Returns the embeddings for the nodes in batch of type node_type.

loader(**kwargs)[source]

Returns the data loader that creates both positive and negative random walks on the heterogeneous graph.

Parameters

**kwargs (optional) – Arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.

loss(pos_rw: Tensor, neg_rw: Tensor) Tensor[source]

Computes the loss given positive and negative random walks.

test(train_z: Tensor, train_y: Tensor, test_z: Tensor, test_y: Tensor, solver: str = 'lbfgs', multi_class: str = 'auto', *args, **kwargs) float[source]

Evaluates latent space quality via a logistic regression downstream task.
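
A training-loop sketch over a toy bibliographic graph; the node types, relation names and random connectivity below are purely illustrative:

import torch
from torch_geometric.nn import MetaPath2Vec

edge_index_dict = {
    ('author', 'writes', 'paper'): torch.randint(0, 50, (2, 200)),
    ('paper', 'written_by', 'author'): torch.randint(0, 50, (2, 200)),
}
metapath = [
    ('author', 'writes', 'paper'),
    ('paper', 'written_by', 'author'),
]

model = MetaPath2Vec(edge_index_dict, embedding_dim=128, metapath=metapath,
                     walk_length=50, context_size=7, walks_per_node=5,
                     num_negative_samples=5, sparse=True)
loader = model.loader(batch_size=128, shuffle=True)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

for pos_rw, neg_rw in loader:
    optimizer.zero_grad()
    loss = model.loss(pos_rw, neg_rw)
    loss.backward()
    optimizer.step()

z = model('author')  # embeddings for all nodes of type 'author'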

class DeepGCNLayer(conv: Optional[Module] = None, norm: Optional[Module] = None, act: Optional[Module] = None, block: str = 'res+', dropout: float = 0.0, ckpt_grad: bool = False)[source]

The skip connection operations from the “DeepGCNs: Can GCNs Go as Deep as CNNs?” and “All You Need to Train Deeper GCNs” papers. The implemented skip connections include the pre-activation residual connection ("res+"), the residual connection ("res"), the dense connection ("dense") and no connections ("plain").

  • Res+ ("res+"):

\[\text{Normalization}\to\text{Activation}\to\text{Dropout}\to \text{GraphConv}\to\text{Res}\]
  • Res ("res") / Dense ("dense") / Plain ("plain"):

\[\text{GraphConv}\to\text{Normalization}\to\text{Activation}\to \text{Res/Dense/Plain}\to\text{Dropout}\]

Note

For an example of using GENConv, see examples/ogbn_proteins_deepgcn.py.

Parameters
  • conv (torch.nn.Module, optional) – The GCN operator. (default: None)

  • norm (torch.nn.Module, optional) – The normalization layer. (default: None)

  • act (torch.nn.Module, optional) – The activation layer. (default: None)

  • block (string, optional) – The skip connection operation to use ("res+", "res", "dense" or "plain"). (default: "res+")

  • dropout (float, optional) – The dropout probability. (default: 0.)

  • ckpt_grad (bool, optional) – If set to True, will checkpoint this part of the model. Checkpointing works by trading compute for memory, since intermediate activations do not need to be kept in memory. Set this to True in case you encounter out-of-memory errors while going deep. (default: False)

reset_parameters()[source]
forward(*args, **kwargs) Tensor[source]
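
A sketch of a single pre-activation residual ("res+") block in the style of examples/ogbn_proteins_deepgcn.py; the hidden size and the GENConv configuration are illustrative:

import torch
from torch.nn import LayerNorm, ReLU
from torch_geometric.nn import DeepGCNLayer, GENConv

conv = GENConv(64, 64, aggr='softmax', t=1.0, learn_t=True, num_layers=2, norm='layer')
norm = LayerNorm(64, elementwise_affine=True)
act = ReLU(inplace=True)
layer = DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1)

x = torch.randn(100, 64)
edge_index = torch.randint(0, 100, (2, 500))
out = layer(x, edge_index)  # shape: [100, 64]
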
class TGNMemory(num_nodes: int, raw_msg_dim: int, memory_dim: int, time_dim: int, message_module: Callable, aggregator_module: Callable)[source]

The Temporal Graph Network (TGN) memory model from the “Temporal Graph Networks for Deep Learning on Dynamic Graphs” paper.

Note

For an example of using TGN, see examples/tgn.py.

Parameters
  • num_nodes (int) – The number of nodes to save memories for.

  • raw_msg_dim (int) – The raw message dimensionality.

  • memory_dim (int) – The hidden memory dimensionality.

  • time_dim (int) – The time encoding dimensionality.

  • message_module (torch.nn.Module) – The message function which combines source and destination node memory embeddings, the raw message and the time encoding.

  • aggregator_module (torch.nn.Module) – The message aggregator function which aggregates messages to the same destination into a single representation.

reset_parameters()[source]
reset_state()[source]

Resets the memory to its initial state.

detach()[source]

Detaches the memory from gradient computation.

forward(n_id: Tensor) Tuple[Tensor, Tensor][source]

Returns, for all nodes n_id, their current memory and their last updated timestamp.

update_state(src, dst, t, raw_msg)[source]

Updates the memory with newly encountered interactions (src, dst, t, raw_msg).

train(mode: bool = True)[source]

Sets the module in training mode.

class LabelPropagation(num_layers: int, alpha: float)[source]

The label propagation operator from the “Learning from Labeled and Unlabeled Data with Label Propagation” paper

\[\mathbf{Y}^{\prime} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{Y} + (1 - \alpha) \mathbf{Y},\]

where unlabeled data is inferred by labeled data via propagation.

Parameters
  • num_layers (int) – The number of propagations.

  • alpha (float) – The \(\alpha\) coefficient.

forward(y: Tensor, edge_index: Union[Tensor, SparseTensor], mask: Optional[Tensor] = None, edge_weight: Optional[Tensor] = None, post_step: Callable = <function LabelPropagation.<lambda>>) Tensor[source]
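
A usage sketch on toy data; the labels, connectivity and training mask below are random stand-ins:

import torch
from torch_geometric.nn import LabelPropagation

model = LabelPropagation(num_layers=3, alpha=0.9)

y = torch.randint(0, 7, (100,))               # toy class labels
edge_index = torch.randint(0, 100, (2, 500))  # toy connectivity
train_mask = torch.rand(100) < 0.1            # labeled nodes

out = model(y, edge_index, mask=train_mask)   # propagated soft labels per node
pred = out.argmax(dim=-1)
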
class CorrectAndSmooth(num_correction_layers: int, correction_alpha: float, num_smoothing_layers: int, smoothing_alpha: float, autoscale: bool = True, scale: float = 1.0)[source]

The correct and smooth (C&S) post-processing model from the “Combining Label Propagation And Simple Models Out-performs Graph Neural Networks” paper, where soft predictions \(\mathbf{Z}\) (obtained from a simple base predictor) are first corrected based on ground-truth training label information \(\mathbf{Y}\) and residual propagation

\[\begin{split}\mathbf{e}^{(0)}_i &= \begin{cases} \mathbf{y}_i - \mathbf{z}_i, & \text{if }i \text{ is training node,}\\ \mathbf{0}, & \text{else} \end{cases}\end{split}\]
\[ \begin{align}\begin{aligned}\mathbf{E}^{(\ell)} &= \alpha_1 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{E}^{(\ell - 1)} + (1 - \alpha_1) \mathbf{E}^{(\ell - 1)}\\\mathbf{\hat{Z}} &= \mathbf{Z} + \gamma \cdot \mathbf{E}^{(L_1)},\end{aligned}\end{align} \]

where \(\gamma\) denotes the scaling factor (either fixed or automatically determined), and then smoothed over the graph via label propagation

\[\begin{split}\mathbf{\hat{z}}^{(0)}_i &= \begin{cases} \mathbf{y}_i, & \text{if }i\text{ is training node,}\\ \mathbf{\hat{z}}_i, & \text{else} \end{cases}\end{split}\]
\[\mathbf{\hat{Z}}^{(\ell)} = \alpha_2 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{\hat{Z}}^{(\ell - 1)} + (1 - \alpha_2) \mathbf{\hat{Z}}^{(\ell - 1)}\]

to obtain the final prediction \(\mathbf{\hat{Z}}^{(L_2)}\).

Note

For an example of using the C&S model, see examples/correct_and_smooth.py.

Parameters
  • num_correction_layers (int) – The number of propagations \(L_1\).

  • correction_alpha (float) – The \(\alpha_1\) coefficient.

  • num_smoothing_layers (int) – The number of propagations \(L_2\).

  • smoothing_alpha (float) – The \(\alpha_2\) coefficient.

  • autoscale (bool, optional) – If set to True, will automatically determine the scaling factor \(\gamma\). (default: True)

  • scale (float, optional) – The scaling factor \(\gamma\), in case autoscale = False. (default: 1.0)

correct(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
Parameters
  • y_soft (Tensor) – The soft predictions \(\mathbf{Z}\) obtained from a simple base predictor.

  • y_true (Tensor) – The ground-truth label information \(\mathbf{Y}\) of training nodes.

  • mask (LongTensor or BoolTensor) – A mask or index tensor denoting which nodes were used for training.

  • edge_index (Tensor or SparseTensor) – The edge connectivity.

  • edge_weight (Tensor, optional) – The edge weights. (default: None)

smooth(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
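
A post-processing sketch in the spirit of examples/correct_and_smooth.py; the soft predictions, labels and graph below are random stand-ins for the outputs of a real base model:

import torch
from torch_geometric.nn import CorrectAndSmooth

num_nodes, num_classes = 100, 7
y = torch.randint(0, num_classes, (num_nodes,))                # ground-truth labels
y_soft = torch.randn(num_nodes, num_classes).softmax(dim=-1)   # stand-in for base-model predictions
edge_index = torch.randint(0, num_nodes, (2, 500))             # toy connectivity
train_mask = torch.rand(num_nodes) < 0.5

post = CorrectAndSmooth(num_correction_layers=50, correction_alpha=0.9,
                        num_smoothing_layers=50, smoothing_alpha=0.8)

y_train = y[train_mask]                                        # labels of the training nodes
y_soft = post.correct(y_soft, y_train, train_mask, edge_index)
y_soft = post.smooth(y_soft, y_train, train_mask, edge_index)
pred = y_soft.argmax(dim=-1)
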
class AttentiveFP(in_channels: int, hidden_channels: int, out_channels: int, edge_dim: int, num_layers: int, num_timesteps: int, dropout: float = 0.0)[source]

The Attentive FP model for molecular representation learning from the “Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism” paper, based on graph attention mechanisms.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Hidden node feature dimensionality.

  • out_channels (int) – Size of each output sample.

  • edge_dim (int) – Edge feature dimensionality.

  • num_layers (int) – Number of GNN layers.

  • num_timesteps (int) – Number of iterative refinement steps for global readout.

  • dropout (float, optional) – Dropout probability. (default: 0.0)

reset_parameters() None[source]
forward(x: Tensor, edge_index: Tensor, edge_attr: Tensor, batch: Tensor) Tensor[source]
jittable() AttentiveFP[source]
class RECT_L(in_channels: int, hidden_channels: int, normalize: bool = True, dropout: float = 0.0)[source]

The RECT model, i.e. its supervised RECT-L part, from the “Network Embedding with Completely-imbalanced Labels” paper. In particular, a GCN model is trained that reconstructs semantic class knowledge.

Note

For an example of using RECT, see examples/rect.py.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Intermediate size of each sample.

  • normalize (bool, optional) – Whether to add self-loops and compute symmetric normalization coefficients on the fly. (default: True)

  • dropout (float, optional) – The dropout probability. (default: 0.0)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
embed(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
get_semantic_labels(x: Tensor, y: Tensor, mask: Tensor) Tensor[source]

Replaces the original labels by their class-centers.

class LINKX(num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int, num_edge_layers: int = 1, num_node_layers: int = 1, dropout: float = 0.0)[source]

The LINKX model from the “Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods” paper

\[ \begin{align}\begin{aligned}\mathbf{H}_{\mathbf{A}} &= \textrm{MLP}_{\mathbf{A}}(\mathbf{A})\\\mathbf{H}_{\mathbf{X}} &= \textrm{MLP}_{\mathbf{X}}(\mathbf{X})\\\mathbf{Y} &= \textrm{MLP}_{f} \left( \sigma \left( \mathbf{W} [\mathbf{H}_{\mathbf{A}}, \mathbf{H}_{\mathbf{X}}] + \mathbf{H}_{\mathbf{A}} + \mathbf{H}_{\mathbf{X}} \right) \right)\end{aligned}\end{align} \]

Note

For an example of using LINKX, see examples/linkx.py.

Parameters
  • num_nodes (int) – The number of nodes in the graph.

  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • hidden_channels (int) – Size of each hidden sample.

  • out_channels (int) – Size of each output sample.

  • num_layers (int) – Number of layers of \(\textrm{MLP}_{f}\).

  • num_edge_layers (int) – Number of layers of \(\textrm{MLP}_{\mathbf{A}}\). (default: 1)

  • num_node_layers (int) – Number of layers of \(\textrm{MLP}_{\mathbf{X}}\). (default: 1)

  • dropout (float, optional) – Dropout probability of each hidden embedding. (default: 0.)

reset_parameters()[source]
forward(x: Optional[Tensor], edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
class LightGCN(num_nodes: int, embedding_dim: int, num_layers: int, alpha: Optional[Union[float, Tensor]] = None, **kwargs)[source]

The LightGCN model from the “LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation” paper.

LightGCN learns embeddings by linearly propagating them on the underlying graph, and uses the weighted sum of the embeddings learned at all layers as the final embedding

\[\mathbf{x}_i = \sum_{l=0}^{L} \alpha_l \mathbf{x}^{(l)}_i,\]

where each layer’s embedding is computed as

\[\mathbf{x}^{(l+1)}_i = \sum_{j \in \mathcal{N}(i)} \frac{1}{\sqrt{\deg(i)\deg(j)}}\mathbf{x}^{(l)}_j.\]

Two prediction heads and training objectives are provided: link prediction (via link_pred_loss() and predict_link()) and recommendation (via recommendation_loss() and recommend()).

Note

Embeddings are propagated according to the graph connectivity specified by edge_index while rankings or link probabilities are computed according to the edges specified by edge_label_index.

Parameters
  • num_nodes (int) – The number of nodes in the graph.

  • embedding_dim (int) – The dimensionality of node embeddings.

  • num_layers (int) – The number of LGConv layers.

  • alpha (float or Tensor, optional) – The scalar or vector specifying the re-weighting coefficients for aggregating the final embedding. If set to None, the uniform initialization of 1 / (num_layers + 1) is used. (default: None)

  • **kwargs (optional) – Additional arguments of the underlying LGConv layers.

reset_parameters()[source]
get_embedding(edge_index: Union[Tensor, SparseTensor]) Tensor[source]
forward(edge_index: Union[Tensor, SparseTensor], edge_label_index: Optional[Tensor] = None) Tensor[source]

Computes rankings for pairs of nodes.

Parameters
  • edge_index (Tensor or SparseTensor) – Edge tensor specifying the connectivity of the graph.

  • edge_label_index (Tensor, optional) – Edge tensor specifying the node pairs for which to compute rankings or probabilities. If edge_label_index is set to None, all edges in edge_index will be used instead. (default: None)

predict_link(edge_index: Union[Tensor, SparseTensor], edge_label_index: Optional[Tensor] = None, prob: bool = False) Tensor[source]

Predicts links between nodes specified in edge_label_index.

Parameters

prob (bool) – Whether probabilities should be returned. (default: False)

recommend(edge_index: Union[Tensor, SparseTensor], src_index: Optional[Tensor] = None, dst_index: Optional[Tensor] = None, k: int = 1) Tensor[source]

Get top-\(k\) recommendations for nodes in src_index.

Parameters
  • src_index (Tensor, optional) – Node indices for which recommendations should be generated. If set to None, all nodes will be used. (default: None)

  • dst_index (Tensor, optional) – Node indices which represent the possible recommendation choices. If set to None, all nodes will be used. (default: None)

  • k (int, optional) – Number of recommendations. (default: 1)

link_pred_loss(pred: Tensor, edge_label: Tensor, **kwargs) Tensor[source]

Computes the model loss for a link prediction objective via torch.nn.BCEWithLogitsLoss.

Parameters
  • pred (Tensor) – The predictions.

  • edge_label (Tensor) – The ground-truth edge labels.

  • **kwargs (optional) – Additional arguments of the underlying torch.nn.BCEWithLogitsLoss loss function.

recommendation_loss(pos_edge_rank: Tensor, neg_edge_rank: Tensor, lambda_reg: float = 0.0001, **kwargs) Tensor[source]

Computes the model loss for a ranking objective via the Bayesian Personalized Ranking (BPR) loss.

Note

The i-th entry in the pos_edge_rank vector and the i-th entry in the neg_edge_rank vector must correspond to ranks of positive and negative edges of the same entity (e.g., the same user).

Parameters
  • pos_edge_rank (Tensor) – Positive edge rankings.

  • neg_edge_rank (Tensor) – Negative edge rankings.

  • lambda_reg (float, optional) – The \(L_2\) regularization strength of the Bayesian Personalized Ranking (BPR) loss. (default: 1e-4)

  • **kwargs (optional) – Additional arguments of the underlying torch_geometric.nn.models.lightgcn.BPRLoss loss function.
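A hedged end-to-end sketch using the signatures documented above; the interaction graph, the split into positive and negative rankings, and all sizes are hypothetical:

import torch
from torch_geometric.nn import LightGCN

num_nodes = 100
edge_index = torch.randint(0, num_nodes, (2, 500))  # hypothetical interaction graph

model = LightGCN(num_nodes=num_nodes, embedding_dim=32, num_layers=2)

# Rankings for candidate node pairs (see forward() above).
edge_label_index = torch.randint(0, num_nodes, (2, 20))
rank = model(edge_index, edge_label_index)

# Top-5 recommendations for a few source nodes.
src = torch.arange(3)
recs = model.recommend(edge_index, src_index=src, k=5)

# BPR loss over positive/negative rankings of the same entities.
pos_rank, neg_rank = rank[:10], rank[10:]
loss = model.recommendation_loss(pos_rank, neg_rank)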

class MaskLabel(num_classes: int, out_channels: int, method: str = 'add')[source]

The label embedding and masking layer from the “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification” paper.

Here, node labels y are merged with the initial node features x for the subset of nodes indicated by mask.

Note

For an example of using MaskLabel, see examples/unimp_arxiv.py.

Parameters
  • num_classes (int) – The number of classes.

  • out_channels (int) – Size of each output sample.

  • method (str, optional) – If set to "add", label embeddings are added to the input. If set to "concat", label embeddings are concatenated. In case method="add", then out_channels needs to be identical to the input dimensionality of node features. (default: "add")

reset_parameters()[source]
forward(x: Tensor, y: Tensor, mask: Tensor) Tensor[source]
static ratio_mask(mask: Tensor, ratio: float)[source]

Modifies mask by setting a ratio of its True entries to False. Does not operate in-place.

Parameters
  • mask (torch.Tensor) – The mask to re-mask.

  • ratio (float) – The ratio of entries to keep.
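A minimal sketch of label injection during training, following the forward() and ratio_mask() signatures above (all tensors and sizes are hypothetical):

import torch
from torch_geometric.nn import MaskLabel

num_nodes, num_classes, hidden = 6, 3, 16
x = torch.randn(num_nodes, hidden)
y = torch.randint(0, num_classes, (num_nodes,))
train_mask = torch.tensor([True, True, True, True, False, False])

label_emb = MaskLabel(num_classes, hidden)  # method="add" requires out_channels == feature dim of x

# Only inject labels for a random subset of the training nodes:
mask = MaskLabel.ratio_mask(train_mask, ratio=0.5)
out = label_emb(x, y, mask)  # shape [num_nodes, hidden] for method="add"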

class GroupAddRev(conv: Union[Module, ModuleList], split_dim: int = -1, num_groups: Optional[int] = None, disable: bool = False, num_bwd_passes: int = 1)[source]

The Grouped Reversible GNN module from the “Graph Neural Networks with 1000 Layers” paper. This module enables training of arbitrarily deep GNNs with a memory complexity independent of the number of layers.

It does so by partitioning input node features \(\mathbf{X}\) into \(C\) groups across the feature dimension. Then, a grouped reversible GNN block \(f_{\theta(i)}\) operates on a group of inputs and produces a group of outputs:

\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime}_0 &= \sum_{i=2}^C \mathbf{X}_i\\\mathbf{X}^{\prime}_i &= f_{\theta(i)} ( \mathbf{X}^{\prime}_{i - 1}, \mathbf{A}) + \mathbf{X}_i\end{aligned}\end{align} \]

for all \(i \in \{ 1, \ldots, C \}\).

Note

For an example of using GroupAddRev, see examples/rev_gnn.py.

Parameters
  • conv (torch.nn.Module or torch.nn.ModuleList) – A seed GNN. The input and output feature dimensions need to match.

  • split_dim (int, optional) – The dimension across which to split groups. (default: -1)

  • num_groups (Optional[int], optional) – The number of groups \(C\). (default: None)

  • disable (bool, optional) – If set to True, will disable the usage of InvertibleFunction and will execute the module without memory savings. (default: False)

  • num_bwd_passes (int, optional) – Number of backward passes to retain a link with the output. After the last backward pass the output is discarded and memory is freed. (default: 1)

property num_groups: int

The number of groups \(C\).

reset_parameters()[source]
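A minimal sketch, assuming the seed GNN operates on hidden_channels // num_groups features (since the module splits the feature dimension into \(C\) groups); see examples/rev_gnn.py for the complete recipe:

import torch
from torch_geometric.nn import GroupAddRev, SAGEConv

hidden_channels, num_groups = 64, 2

# Assumption: the seed conv's input and output sizes equal the per-group feature size.
seed = SAGEConv(hidden_channels // num_groups, hidden_channels // num_groups)
layer = GroupAddRev(seed, num_groups=num_groups)

x = torch.randn(8, hidden_channels)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
out = layer(x, edge_index)  # same shape as x; activations are recomputed during backward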

Encodings

class PositionalEncoding(out_channels: int, base_freq: float = 0.0001, granularity: float = 1.0)[source]

The positional encoding scheme from the “Attention Is All You Need” paper

\[ \begin{align}\begin{aligned}PE(x)_{2 \cdot i} &= \sin(x / 10000^{2 \cdot i / d})\\PE(x)_{2 \cdot i + 1} &= \cos(x / 10000^{2 \cdot i / d})\end{aligned}\end{align} \]

where \(x\) is the position and \(i\) is the dimension.

Parameters
  • out_channels (int) – Size \(d\) of each output sample.

  • base_freq (float, optional) – The base frequency of sinusoidal functions. (default: 1e-4)

  • granularity (float, optional) – The granularity of the positions. If set to a smaller value, the encoder will capture more fine-grained changes in positions. (default: 1.0)

forward(x: Tensor) Tensor[source]
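A minimal sketch, assuming the input is a one-dimensional tensor of (possibly continuous) positions such as timestamps:

import torch
from torch_geometric.nn import PositionalEncoding

enc = PositionalEncoding(out_channels=16)

pos = torch.tensor([0.0, 1.0, 2.0, 10.0])  # hypothetical positions
emb = enc(pos)  # shape [4, 16]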

Functional

bro

The Batch Representation Orthogonality penalty from the "Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity" paper

gini

The Gini coefficient from the "Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity" paper

bro(x: Tensor, batch: Tensor, p: Union[int, str] = 2) Tensor[source]

The Batch Representation Orthogonality penalty from the “Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity” paper

Computes a regularization for each graph representation in a minibatch according to:

\[\mathcal{L}_{\textrm{BRO}}^\mathrm{graph} = || \mathbf{HH}^T - \mathbf{I}||_p\]

and returns an average over all graphs in the batch.

Parameters
  • x (torch.Tensor) – The node feature matrix \(\mathbf{H} \in \mathbb{R}^{N \times F}\) of the mini-batch.

  • batch (torch.Tensor) – The batch vector, which assigns each node to its respective graph.

  • p (int or str, optional) – The order of the norm. (default: 2)

Returns

The average BRO penalty in the minibatch.

gini(w: Tensor) Tensor[source]

The Gini coefficient from the “Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity” paper

Computes a regularization penalty for each row of a matrix according to:

\[\mathcal{L}_\textrm{Gini}^i = \sum_j^n \sum_{j'}^n \frac{|w_{ij} - w_{ij'}|}{2 (n^2 - n)\bar{w_i}}\]

and returns an average over all rows.

Parameters

w (torch.Tensor) – A two-dimensional tensor.

Returns

The value of the Gini coefficient for this tensor \(\in [0, 1]\)
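A hedged sketch of using both penalties as auxiliary regularizers; the import path torch_geometric.nn.functional and all tensors here are assumptions for illustration:

import torch
from torch_geometric.nn.functional import bro, gini

# Hypothetical node representations for a mini-batch of two graphs.
h = torch.randn(10, 8)
batch = torch.tensor([0] * 6 + [1] * 4)

reg_bro = bro(h, batch)              # average BRO penalty over the two graphs
reg_gini = gini(torch.randn(4, 8))   # Gini coefficient averaged over rows

task_loss = torch.tensor(0.0)        # placeholder for the task-specific loss
loss = task_loss + 0.1 * reg_bro + 0.1 * reg_gini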

Model Transformations

class Transformer(module: Module, input_map: Optional[Dict[str, str]] = None, debug: bool = False)[source]

A Transformer executes an FX graph node-by-node, applies transformations to each node, and produces a new torch.nn.Module. It exposes a transform() method that returns the transformed Module. Transformer works entirely symbolically.

Methods in the Transformer class can be overridden to customize the behavior of the transformation.

transform()
    +-- Iterate over each node in the graph
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- call_message_passing_module()
        +-- call_global_pooling_module()
        +-- output()
    +-- Erase unused nodes in the graph
    +-- Iterate over each child module
        +-- init_submodule()

In contrast to the torch.fx.Transformer class, the Transformer exposes additional functionality:

  1. It subdivides call_module() into nodes that call a regular torch.nn.Module (call_module()), a MessagePassing module (call_message_passing_module()), or a GlobalPooling module (call_global_pooling_module()).

  2. It allows you to customize or initialize new child modules via init_submodule().

  3. It allows you to infer whether a node returns node-level or edge-level information via is_edge_level().

Parameters
  • module (torch.nn.Module) – The module to be transformed.

  • input_map (Dict[str, str], optional) – A dictionary holding information about the type of input arguments of module.forward. For example, in case arg is a node-level argument, then input_map['arg'] = 'node', and input_map['arg'] = 'edge' otherwise. In case input_map is not further specified, will try to automatically determine the correct type of input arguments. (default: None)

  • debug (bool, optional) – If set to True, will perform transformation in debug mode. (default: False)

placeholder(node: Node, target: Any, name: str)[source]
get_attr(node: Node, target: Any, name: str)[source]
call_message_passing_module(node: Node, target: Any, name: str)[source]
call_global_pooling_module(node: Node, target: Any, name: str)[source]
call_module(node: Node, target: Any, name: str)[source]
call_method(node: Node, target: Any, name: str)[source]
call_function(node: Node, target: Any, name: str)[source]
output(node: Node, target: Any, name: str)[source]
init_submodule(module: Module, target: str) Module[source]
transform() GraphModule[source]

Transforms self.module and returns a transformed torch.fx.GraphModule.

is_node_level(node: Node) bool[source]
is_edge_level(node: Node) bool[source]
is_graph_level(node: Node) bool[source]
has_node_level_arg(node: Node) bool[source]
has_edge_level_arg(node: Node) bool[source]
has_graph_level_arg(node: Node) bool[source]
replace_all_uses_with(to_replace: Node, replace_with: Node)[source]
to_hetero(module: Module, metadata: Tuple[List[str], List[Tuple[str, str, str]]], aggr: str = 'sum', input_map: Optional[Dict[str, str]] = None, debug: bool = False) GraphModule[source]

Converts a homogeneous GNN model into its heterogeneous equivalent in which node representations are learned for each node type in metadata[0], and messages are exchanged between each edge type in metadata[1], as denoted in the “Modeling Relational Data with Graph Convolutional Networks” paper:

import torch
from torch_geometric.nn import SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), 32)
        self.conv2 = SAGEConv((32, 32), 32)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return x

model = GNN()

node_types = ['paper', 'author']
edge_types = [
    ('paper', 'cites', 'paper'),
    ('paper', 'written_by', 'author'),
    ('author', 'writes', 'paper'),
]
metadata = (node_types, edge_types)

model = to_hetero(model, metadata)
model(x_dict, edge_index_dict)

where x_dict and edge_index_dict denote dictionaries that hold node features and edge connectivity information for each node type and edge type, respectively.

The below illustration shows the original computation graph of the homogeneous model on the left, and the newly obtained computation graph of the heterogeneous model on the right:

../_images/to_hetero.svg

Transforming a model via to_hetero().

Here, each MessagePassing instance \(f_{\theta}^{(\ell)}\) is duplicated and stored in a set \(\{ f_{\theta}^{(\ell, r)} : r \in \mathcal{R} \}\) (one instance for each relation in \(\mathcal{R}\)), and message passing in layer \(\ell\) is performed via

\[\mathbf{h}^{(\ell)}_v = \bigoplus_{r \in \mathcal{R}} f_{\theta}^{(\ell, r)} ( \mathbf{h}^{(\ell - 1)}_v, \{ \mathbf{h}^{(\ell - 1)}_w : w \in \mathcal{N}^{(r)}(v) \}),\]

where \(\mathcal{N}^{(r)}(v)\) denotes the neighborhood of \(v \in \mathcal{V}\) under relation \(r \in \mathcal{R}\), and \(\bigoplus\) denotes the aggregation scheme aggr to use for grouping node embeddings generated by different relations ("sum", "mean", "min", "max" or "mul").

Parameters
  • module (torch.nn.Module) – The homogeneous model to transform.

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

  • aggr (string, optional) – The aggregation scheme to use for grouping node embeddings generated by different relations. ("sum", "mean", "min", "max", "mul"). (default: "sum")

  • input_map (Dict[str, str], optional) – A dictionary holding information about the type of input arguments of module.forward. For example, in case arg is a node-level argument, then input_map['arg'] = 'node', and input_map['arg'] = 'edge' otherwise. In case input_map is not further specified, will try to automatically determine the correct type of input arguments. (default: None)

  • debug (bool, optional) – If set to True, will perform transformation in debug mode. (default: False)

to_hetero_with_bases(module: Module, metadata: Tuple[List[str], List[Tuple[str, str, str]]], num_bases: int, in_channels: Optional[Dict[str, int]] = None, input_map: Optional[Dict[str, str]] = None, debug: bool = False) GraphModule[source]

Converts a homogeneous GNN model into its heterogeneous equivalent via the basis-decomposition technique introduced in the “Modeling Relational Data with Graph Convolutional Networks” paper. For this, the heterogeneous graph is mapped to a typed homogeneous graph, in which its feature representations are aligned and grouped into a single representation. All GNN layers inside the model will then perform message passing via basis-decomposition regularization. This transformation is especially useful for highly multi-relational data, since the number of parameters no longer depends on the number of relations of the input graph:

import torch
from torch_geometric.nn import SAGEConv, to_hetero_with_bases

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv((16, 16), 32)
        self.conv2 = SAGEConv((32, 32), 32)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return x

model = GNN()

node_types = ['paper', 'author']
edge_types = [
    ('paper', 'cites', 'paper'),
    ('paper', 'written_by', 'author'),
    ('author', 'writes', 'paper'),
]
metadata = (node_types, edge_types)

model = to_hetero_with_bases(model, metadata, num_bases=3,
                             in_channels={'x': 16})
model(x_dict, edge_index_dict)

where x_dict and edge_index_dict denote dictionaries that hold node features and edge connectivity information for each node type and edge type, respectively. In case in_channels is given for a specific input argument, its heterogeneous feature information is first aligned to the given dimensionality.

The below illustration shows the original computation graph of the homogeneous model on the left, and the newly obtained computation graph of the regularized heterogeneous model on the right:

../_images/to_hetero_with_bases.svg

Transforming a model via to_hetero_with_bases().

Here, each MessagePassing instance \(f_{\theta}^{(\ell)}\) is duplicated num_bases times and stored in a set \(\{ f_{\theta}^{(\ell, b)} : b \in \{ 1, \ldots, B \} \}\) (one instance for each basis in num_bases), and message passing in layer \(\ell\) is performed via

\[\mathbf{h}^{(\ell)}_v = \sum_{r \in \mathcal{R}} \sum_{b=1}^B f_{\theta}^{(\ell, b)} ( \mathbf{h}^{(\ell - 1)}_v, \{ a^{(\ell)}_{r, b} \cdot \mathbf{h}^{(\ell - 1)}_w : w \in \mathcal{N}^{(r)}(v) \}),\]

where \(\mathcal{N}^{(r)}(v)\) denotes the neighborhood of \(v \in \mathcal{V}\) under relation \(r \in \mathcal{R}\). Notably, only the trainable basis coefficients \(a^{(\ell)}_{r, b}\) depend on the relations in \(\mathcal{R}\).

Parameters
  • module (torch.nn.Module) – The homogeneous model to transform.

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

  • num_bases (int) – The number of bases to use.

  • in_channels (Dict[str, int], optional) – A dictionary holding information about the desired input feature dimensionality of input arguments of module.forward. In case in_channels is given for a specific input argument, its heterogeneous feature information is first aligned to the given dimensionality. This allows handling of node and edge features with varying feature dimensionality across different types. (default: None)

  • input_map (Dict[str, str], optional) – A dictionary holding information about the type of input arguments of module.forward. For example, in case arg is a node-level argument, then input_map['arg'] = 'node', and input_map['arg'] = 'edge' otherwise. In case input_map is not further specified, will try to automatically determine the correct type of input arguments. (default: None)

  • debug (bool, optional) – If set to True, will perform transformation in debug mode. (default: False)

Dense Convolutional Layers

DenseGCNConv

See torch_geometric.nn.conv.GCNConv.

DenseGINConv

See torch_geometric.nn.conv.GINConv.

DenseGraphConv

See torch_geometric.nn.conv.GraphConv.

DenseSAGEConv

See torch_geometric.nn.conv.SAGEConv.

class DenseGCNConv(in_channels: int, out_channels: int, improved: bool = False, bias: bool = True)[source]

See torch_geometric.nn.conv.GCNConv.

reset_parameters()[source]
forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None, add_loop: bool = True) Tensor[source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

  • add_loop (bool, optional) – If set to False, the layer will not automatically add self-loops to the adjacency matrices. (default: True)

Return type

Tensor

class DenseGINConv(nn: Module, eps: float = 0.0, train_eps: bool = False)[source]

See torch_geometric.nn.conv.GINConv.

reset_parameters()[source]
forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None, add_loop: bool = True) Tensor[source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

  • add_loop (bool, optional) – If set to False, the layer will not automatically add self-loops to the adjacency matrices. (default: True)

class DenseGraphConv(in_channels: int, out_channels: int, aggr: str = 'add', bias: bool = True)[source]

See torch_geometric.nn.conv.GraphConv.

reset_parameters()[source]
forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) Tensor[source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

class DenseSAGEConv(in_channels: int, out_channels: int, normalize: bool = False, bias: bool = True)[source]

See torch_geometric.nn.conv.SAGEConv.

Note

DenseSAGEConv expects to work on binary adjacency matrices. If you want to make use of weighted dense adjacency matrices, please use torch_geometric.nn.dense.DenseGraphConv instead.

reset_parameters()[source]
forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) Tensor[source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)
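The dense layers above share the same batched interface of node features \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) and adjacency tensors \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). A minimal sketch with DenseGCNConv and random data (all sizes are hypothetical):

import torch
from torch_geometric.nn import DenseGCNConv

B, N, F_in, F_out = 2, 5, 8, 16
x = torch.randn(B, N, F_in)
adj = torch.randint(0, 2, (B, N, N)).float()
mask = torch.ones(B, N, dtype=torch.bool)

conv = DenseGCNConv(F_in, F_out)
out = conv(x, adj, mask)  # shape [B, N, F_out]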

Dense Pooling Layers

dense_diff_pool

The differentiable pooling operator from the "Hierarchical Graph Representation Learning with Differentiable Pooling" paper

dense_mincut_pool

The MinCut pooling operator from the "Spectral Clustering in Graph Neural Networks for Graph Pooling" paper

DMoNPooling

The spectral modularity pooling operator from the "Graph Clustering with Graph Neural Networks" paper

class dense_diff_pool(x: Tensor, adj: Tensor, s: Tensor, mask: Optional[Tensor] = None, normalize: bool = True)[source]

The differentiable pooling operator from the “Hierarchical Graph Representation Learning with Differentiable Pooling” paper

\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]

based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns the pooled node feature matrix, the coarsened adjacency matrix and two auxiliary objectives: (1) The link prediction loss

\[\mathcal{L}_{LP} = {\| \mathbf{A} - \mathrm{softmax}(\mathbf{S}) {\mathrm{softmax}(\mathbf{S})}^{\top} \|}_F,\]

and (2) the entropy regularization

\[\mathcal{L}_E = \frac{1}{N} \sum_{n=1}^N H(\mathbf{S}_n).\]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).

  • s (Tensor) – Assignment tensor \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) with number of clusters \(C\). The softmax does not have to be applied beforehand, since it is executed within this method.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

  • normalize (bool, optional) – If set to False, the link prediction loss is not divided by adj.numel(). (default: True)

Return type

(Tensor, Tensor, Tensor, Tensor)
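A minimal sketch with random tensors, matching the shapes and return type documented above (the assignment scores would normally come from a learnable GNN):

import torch
from torch_geometric.nn import dense_diff_pool

B, N, C, F = 2, 10, 3, 16
x = torch.randn(B, N, F)
adj = torch.rand(B, N, N)
s = torch.randn(B, N, C)  # raw assignment scores; the softmax is applied internally

x_pooled, adj_pooled, link_loss, ent_loss = dense_diff_pool(x, adj, s)
# x_pooled: [B, C, F], adj_pooled: [B, C, C]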

class dense_mincut_pool(x: Tensor, adj: Tensor, s: Tensor, mask: Optional[Tensor] = None, temp: float = 1.0)[source]

The MinCut pooling operator from the “Spectral Clustering in Graph Neural Networks for Graph Pooling” paper

\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]

based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns the pooled node feature matrix, the coarsened and symmetrically normalized adjacency matrix and two auxiliary objectives: (1) The MinCut loss

\[\mathcal{L}_c = - \frac{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{A} \mathbf{S})} {\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{D} \mathbf{S})}\]

where \(\mathbf{D}\) is the degree matrix, and (2) the orthogonality loss

\[\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F.\]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Symmetrically normalized adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).

  • s (Tensor) – Assignment tensor \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) with number of clusters \(C\). The softmax does not have to be applied beforehand, since it is executed within this method.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

  • temp (float) – Temperature parameter for softmax function. (default: 1.0)

Return type

(Tensor, Tensor, Tensor, Tensor)

class DMoNPooling(channels: Union[int, List[int]], k: int, dropout: float = 0.0)[source]

The spectral modularity pooling operator from the “Graph Clustering with Graph Neural Networks” paper

\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]

based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns the learned cluster assignment matrix, the pooled node feature matrix, the coarsened symmetrically normalized adjacency matrix, and three auxiliary objectives: (1) The spectral loss

\[\mathcal{L}_s = - \frac{1}{2m} \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}\]

where \(\mathbf{B}\) is the modularity matrix, (2) the orthogonality loss

\[\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F\]

where \(C\) is the number of clusters, and (3) the cluster loss

\[\mathcal{L}_c = \frac{\sqrt{C}}{n} {\left\|\sum_i\mathbf{C_i}^{\top}\right\|}_F - 1.\]

Note

For an example of using DMoNPooling, see examples/proteins_dmon_pool.py.

Parameters
  • channels (int or List[int]) – Size of each input sample. If given as a list, will construct an MLP based on the given feature sizes.

  • k (int) – The number of clusters.

  • dropout (float, optional) – Dropout probability. (default: 0.0)

reset_parameters()[source]
forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor][source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\). Note that the cluster assignment matrix \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) is being created within this method.

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

Return type

(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
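A minimal sketch with random tensors; the six return values follow the order described above (cluster assignments, pooled features, coarsened adjacency, and the three auxiliary losses), and all sizes are hypothetical:

import torch
from torch_geometric.nn import DMoNPooling

B, N, F, k = 2, 10, 16, 4
x = torch.randn(B, N, F)
adj = torch.rand(B, N, N)

pool = DMoNPooling(channels=F, k=k)
s, x_pooled, adj_pooled, spectral_loss, ortho_loss, cluster_loss = pool(x, adj)
# s: [B, N, k], x_pooled: [B, k, F], adj_pooled: [B, k, k]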

DataParallel Layers

class DataParallel(module, device_ids=None, output_device=None, follow_batch=[], exclude_keys=[])[source]

Implements data parallelism at the module level.

This container parallelizes the application of the given module by splitting a list of torch_geometric.data.Data objects and copying them as torch_geometric.data.Batch objects to each device. In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.

The batch size should be larger than the number of GPUs used.

The parallelized module must have its parameters and buffers on device_ids[0].

Note

You need to use the torch_geometric.loader.DataListLoader for this module.

Parameters
  • module (Module) – Module to be parallelized.

  • device_ids (list of int or torch.device) – CUDA devices. (default: all devices)

  • output_device (int or torch.device) – Device location of output. (default: device_ids[0])

  • follow_batch (list or tuple, optional) – Creates assignment batch vectors for each key in the list. (default: [])

  • exclude_keys (list or tuple, optional) – Will exclude each key in the list. (default: [])

forward(data_list)[source]
scatter(data_list, device_ids)[source]
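A hedged sketch of the intended training-loop usage; dataset and GNN are hypothetical placeholders for a PyG dataset and a user-defined torch.nn.Module, and at least one CUDA device is assumed:

import torch
from torch_geometric.loader import DataListLoader
from torch_geometric.nn import DataParallel

# `dataset` and `GNN` are placeholders defined elsewhere.
loader = DataListLoader(dataset, batch_size=32, shuffle=True)
model = DataParallel(GNN()).to('cuda:0')  # parameters must live on device_ids[0]

for data_list in loader:  # a list of Data objects, scattered across the available GPUs
    out = model(data_list)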