# torch_geometric.nn

class Sequential(input_args: str, modules)[source]

An extension of the torch.nn.Sequential container for defining a sequential GNN model. Since GNN operators take in multiple input arguments, torch_geometric.nn.Sequential expects both global input arguments and function header definitions of individual operators. If a function header is omitted, an intermediate module operates on the output of its preceding module:

from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv

model = Sequential('x, edge_index', [
    (GCNConv(in_channels, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, out_channels),
])


where 'x, edge_index' defines the input arguments of model, and 'x, edge_index -> x' defines the function header, i.e. input arguments and return types, of GCNConv.

In particular, this also allows one to create more sophisticated models, such as models utilizing JumpingKnowledge:

from torch.nn import Linear, ReLU, Dropout
from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
from torch_geometric.nn import global_mean_pool

model = Sequential('x, edge_index, batch', [
    (Dropout(p=0.5), 'x -> x'),
    (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    ReLU(inplace=True),
    (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
    (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
    (global_mean_pool, 'x, batch -> x'),
    Linear(2 * 64, dataset.num_classes),
])

Parameters
• input_args (str) – The input arguments of the model.

• modules ([(str, Callable) or Callable]) – A list of modules (with optional function header definitions). Alternatively, an OrderedDict of modules (and function header definitions) can be passed.
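A minimal sketch of the OrderedDict variant mentioned above; the dictionary keys are assumed to become the names of the submodules, and the channel sizes are arbitrary:

from collections import OrderedDict

from torch.nn import ReLU
from torch_geometric.nn import GCNConv, Sequential

# Hypothetical channel sizes; each value is either a module or a
# (module, function header) tuple, as in the list-based examples above.
model = Sequential('x, edge_index', OrderedDict([
    ('conv1', (GCNConv(16, 64), 'x, edge_index -> x')),
    ('act1', ReLU(inplace=True)),
    ('conv2', (GCNConv(64, 7), 'x, edge_index -> x')),
]))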

class Linear(in_channels: int, out_channels: int, bias: bool = True, weight_initializer=None, bias_initializer=None)[source]

Applies a linear transformation to the incoming data

$\mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}$

similar to torch.nn.Linear. It supports lazy initialization and customizable weight and bias initialization.

Parameters
Shapes:
• input: features $$(*, F_{in})$$

• output: features $$(*, F_{out})$$

reset_parameters()[source]
forward(x: Tensor) [source]
Parameters

x (Tensor) – The features.
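A minimal sketch of the lazy initialization mentioned above: with in_channels=-1, the weight is only materialized on the first forward call (the input feature size of 16 is an arbitrary example):

import torch
from torch_geometric.nn import Linear

lin = Linear(-1, 32)          # in_channels inferred lazily
x = torch.randn(100, 16)      # 100 rows with 16 features each (arbitrary)
out = lin(x)                  # weights are materialized on this first call
print(out.shape)              # torch.Size([100, 32])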

class HeteroLinear(in_channels: int, out_channels: int, num_types: int, is_sorted: bool = False, **kwargs)[source]

Applies separate linear transformations to the incoming data according to types

$\mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}$

for type $$\kappa$$. It supports lazy initialization and customizable weight and bias initialization.

Parameters
• in_channels (int) – Size of each input sample. Will be initialized lazily in case it is given as -1.

• out_channels (int) – Size of each output sample.

• num_types (int) – The number of types.

• is_sorted (bool, optional) – If set to True, assumes that type_vec is sorted. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: False)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.Linear.

Shapes:
• input: features $$(*, F_{in})$$, type vector $$(*)$$

• output: features $$(*, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, type_vec: Tensor) [source]
Parameters
• x (Tensor) – The input features.

• type_vec (LongTensor) – A vector that maps each entry to a type.
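A minimal usage sketch (sizes arbitrary); type_vec assigns one of num_types types to each row, and rows of type $$\kappa$$ are transformed by the $$\kappa$$-th weight matrix:

import torch
from torch_geometric.nn import HeteroLinear

lin = HeteroLinear(16, 32, num_types=3)
x = torch.randn(5, 16)                      # 5 input rows (arbitrary)
type_vec = torch.tensor([0, 1, 2, 0, 1])    # type of each row
out = lin(x, type_vec)
print(out.shape)                            # torch.Size([5, 32])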

## Convolutional Layers

• MessagePassing – Base class for creating message passing layers.

• GCNConv – The graph convolutional operator from the "Semi-supervised Classification with Graph Convolutional Networks" paper.

• ChebConv – The Chebyshev spectral graph convolutional operator from the "Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering" paper.

• SAGEConv – The GraphSAGE operator from the "Inductive Representation Learning on Large Graphs" paper.

• GraphConv – The graph neural network operator from the "Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks" paper.

• GravNetConv – The GravNet operator from the "Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks" paper, where the graph is dynamically constructed using nearest neighbors.

• GatedGraphConv – The gated graph convolution operator from the "Gated Graph Sequence Neural Networks" paper.

• ResGatedGraphConv – The residual gated graph convolutional operator from the "Residual Gated Graph ConvNets" paper.

• GATConv – The graph attentional operator from the "Graph Attention Networks" paper.

• FusedGATConv – The fused graph attention operator from the "Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective" paper.

• GATv2Conv – The GATv2 operator from the "How Attentive are Graph Attention Networks?" paper, which fixes the static attention problem of the standard GATConv layer.

• TransformerConv – The graph transformer operator from the "Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification" paper.

• AGNNConv – The graph attentional propagation layer from the "Attention-based Graph Neural Network for Semi-Supervised Learning" paper.

• TAGConv – The topology adaptive graph convolutional networks operator from the "Topology Adaptive Graph Convolutional Networks" paper.

• GINConv – The graph isomorphism operator from the "How Powerful are Graph Neural Networks?" paper.

• GINEConv – The modified GINConv operator from the "Strategies for Pre-training Graph Neural Networks" paper.

• ARMAConv – The ARMA graph convolutional operator from the "Graph Neural Networks with Convolutional ARMA Filters" paper.

• SGConv – The simple graph convolutional operator from the "Simplifying Graph Convolutional Networks" paper.

• SSGConv – The simple spectral graph convolutional operator from the "Simple Spectral Graph Convolution" paper.

• APPNP – The approximate personalized propagation of neural predictions layer from the "Predict then Propagate: Graph Neural Networks meet Personalized PageRank" paper.

• MFConv – The graph neural network operator from the "Convolutional Networks on Graphs for Learning Molecular Fingerprints" paper.

• RGCNConv – The relational graph convolutional operator from the "Modeling Relational Data with Graph Convolutional Networks" paper.

• FastRGCNConv – See RGCNConv.

• RGATConv – The relational graph attentional operator from the "Relational Graph Attention Networks" paper.

• SignedConv – The signed graph convolutional operator from the "Signed Graph Convolutional Network" paper.

• DNAConv – The dynamic neighborhood aggregation operator from the "Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks" paper.

• PointNetConv – The PointNet set layer from the "PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation" and "PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" papers.

• PointConv – Alias of PointNetConv.

• GMMConv – The Gaussian mixture model convolutional operator from the "Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs" paper.

• SplineConv – The spline-based convolutional operator from the "SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels" paper.

• NNConv – The continuous kernel-based convolutional operator from the "Neural Message Passing for Quantum Chemistry" paper.

• ECConv – Alias of NNConv.

• CGConv – The crystal graph convolutional operator from the "Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties" paper.

• EdgeConv – The edge convolutional operator from the "Dynamic Graph CNN for Learning on Point Clouds" paper.

• DynamicEdgeConv – The dynamic edge convolutional operator from the "Dynamic Graph CNN for Learning on Point Clouds" paper (see torch_geometric.nn.conv.EdgeConv), where the graph is dynamically constructed using nearest neighbors in the feature space.

• XConv – The convolutional operator on $$\mathcal{X}$$-transformed points from the "PointCNN: Convolution On X-Transformed Points" paper.

• PPFConv – The PPFNet operator from the "PPFNet: Global Context Aware Local Features for Robust 3D Point Matching" paper.

• FeaStConv – The (translation-invariant) feature-steered convolutional operator from the "FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis" paper.

• PointTransformerConv – The Point Transformer layer from the "Point Transformer" paper.

• HypergraphConv – The hypergraph convolutional operator from the "Hypergraph Convolution and Hypergraph Attention" paper.

• LEConv – The local extremum graph neural network operator from the "ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations" paper, which finds the importance of nodes with respect to their neighbors using the difference operator.

• PNAConv – The Principal Neighbourhood Aggregation graph convolution operator from the "Principal Neighbourhood Aggregation for Graph Nets" paper.

• ClusterGCNConv – The ClusterGCN graph convolutional operator from the "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" paper.

• GENConv – The GENeralized Graph Convolution (GENConv) from the "DeeperGCN: All You Need to Train Deeper GCNs" paper.

• GCN2Conv – The graph convolutional operator with initial residual connections and identity mapping (GCNII) from the "Simple and Deep Graph Convolutional Networks" paper.

• PANConv – The path integral based convolutional operator from the "Path Integral Based Convolution and Pooling for Graph Neural Networks" paper.

• WLConv – The Weisfeiler Lehman operator from the "A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction" paper, which iteratively refines node colorings.

• WLConvContinuous – The Weisfeiler Lehman operator from the "Wasserstein Weisfeiler-Lehman Graph Kernels" paper.

• FiLMConv – The FiLM graph convolutional operator from the "GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation" paper.

• SuperGATConv – The self-supervised graph attentional operator from the "How to Find Your Friendly Neighborhood: Graph Attention Design with Self-Supervision" paper.

• FAConv – The Frequency Adaptive Graph Convolution operator from the "Beyond Low-Frequency Information in Graph Convolutional Networks" paper.

• EGConv – The Efficient Graph Convolution from the "Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions" paper.

• PDNConv – The pathfinder discovery network convolutional operator from the "Pathfinder Discovery Networks for Neural Message Passing" paper.

• GeneralConv – A general GNN layer adapted from the "Design Space for Graph Neural Networks" paper.

• HGTConv – The Heterogeneous Graph Transformer (HGT) operator from the "Heterogeneous Graph Transformer" paper.

• HEATConv – The heterogeneous edge-enhanced graph attentional operator from the "Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction" paper, which enhances GATConv.

• HeteroConv – A generic wrapper for computing graph convolution on heterogeneous graphs.

• HANConv – The Heterogeneous Graph Attention operator from the "Heterogeneous Graph Attention Network" paper.

• LGConv – The Light Graph Convolution (LGC) operator from the "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation" paper.
class MessagePassing(aggr='add', *, aggr_kwargs: Optional[Dict[str, Any]] = None, flow: str = 'source_to_target', node_dim: int = -2, decomposed_layers: int = 1, **kwargs)[source]

Base class for creating message passing layers of the form

$\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),$

where $$\square$$ denotes a differentiable, permutation-invariant function, e.g., sum, mean, min, max or mul, and $$\gamma_{\mathbf{\Theta}}$$ and $$\phi_{\mathbf{\Theta}}$$ denote differentiable functions such as MLPs. See the accompanying tutorial on creating message passing networks for more details.

Parameters
• aggr (string or list or Aggregation, optional) – The aggregation scheme to use, e.g., "add", "sum", "mean", "min", "max" or "mul". In addition, can be any Aggregation module (or any string that automatically resolves to it). If given as a list, will make use of multiple aggregations in which different outputs will get concatenated in the last dimension. If set to None, the MessagePassing instantiation is expected to implement its own aggregation logic via aggregate(). (default: "add")

• aggr_kwargs (Dict[str, Any], optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)

• flow (string, optional) – The flow direction of message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

• node_dim (int, optional) – The axis along which to propagate. (default: -2)

• decomposed_layers (int, optional) – The number of feature decomposition layers, as introduced in the “Optimizing Memory Efficiency of Graph Neural Networks on Edge Computing Platforms” paper. Feature decomposition reduces the peak memory usage by slicing the feature dimensions into separated feature decomposition layers during GNN aggregation. This method can accelerate GNN execution on CPU-based platforms (e.g., 2-3x speedup on the Reddit dataset) for common GNN models such as GCN, GraphSAGE, GIN, etc. However, this method is not applicable to all GNN operators available, in particular for operators in which message computation can not easily be decomposed, e.g. in attention-based GNNs. The selection of the optimal value of decomposed_layers depends both on the specific graph dataset and available hardware resources. A value of 2 is suitable in most cases. Although the peak memory usage is directly associated with the granularity of feature decomposition, the same is not necessarily true for execution speedups. (default: 1)
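A minimal sketch of a custom layer built on this base class (the class name MeanConv and the channel sizes are made-up examples); forward() delegates to propagate(), message() defines $$\phi_{\mathbf{\Theta}}$$, and the aggr argument handles the aggregation step:

import torch
from torch_geometric.nn import MessagePassing

class MeanConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')              # use mean aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x: [N, in_channels], edge_index: [2, E]
        return self.propagate(edge_index, x=self.lin(x))

    def message(self, x_j):
        # x_j holds the features of the source node of each edge
        return x_j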

propagate(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs)[source]

The initial call to start propagating messages.

Parameters
• edge_index (Tensor or SparseTensor) – A torch.LongTensor, a torch_sparse.SparseTensor or a torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. edge_index holds the indices of a general (sparse) assignment matrix of shape [N, M]. If edge_index is of type torch.LongTensor, its shape must be defined as [2, num_messages], where messages from nodes in edge_index[0] are sent to nodes in edge_index[1] (in case flow="source_to_target"). If edge_index is of type torch_sparse.SparseTensor or torch.sparse.Tensor, its sparse indices (row, col) should relate to row = edge_index[1] and col = edge_index[0]. The major difference between both formats is that we need to input the transposed sparse adjacency matrix into propagate().

• size (tuple, optional) – The size (N, M) of the assignment matrix in case edge_index is a LongTensor. If set to None, the size will be automatically inferred and assumed to be quadratic. This argument is ignored in case edge_index is a torch_sparse.SparseTensor or a torch.sparse.Tensor. (default: None)

• **kwargs – Any additional data which is needed to construct and aggregate messages, and to update node embeddings.

edge_updater(edge_index: Union[Tensor, SparseTensor], **kwargs)[source]

The initial call to compute or update features for each edge in the graph.

Parameters
• edge_index (Tensor or SparseTensor) – A torch.LongTensor, a torch_sparse.SparseTensor or a torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. See propagate() for more information.

• **kwargs – Any additional data which is needed to compute or update features for each edge in the graph.

message(x_j: Tensor) [source]

Constructs messages from node $$j$$ to node $$i$$ in analogy to $$\phi_{\mathbf{\Theta}}$$ for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes $$i$$ and $$j$$ by appending _i or _j to the variable name, e.g., x_i and x_j.

aggregate(inputs: Tensor, index: Tensor, ptr=None, dim_size=None) [source]

Aggregates messages from neighbors as $$\square_{j \in \mathcal{N}(i)}$$.

Takes in the output of message computation as first argument and any argument which was initially passed to propagate().

By default, this function will delegate its call to the underlying Aggregation module to reduce messages as specified in __init__() by the aggr argument.

message_and_aggregate(adj_t) [source]

Fuses computations of message() and aggregate() into a single function. If applicable, this saves both time and memory since messages do not explicitly need to be materialized. This function will only get called in case it is implemented and propagation takes place based on a torch_sparse.SparseTensor or a torch.sparse.Tensor.

update(inputs: Tensor) [source]

Updates node embeddings in analogy to $$\gamma_{\mathbf{\Theta}}$$ for each node $$i \in \mathcal{V}$$. Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().

edge_update() [source]

Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes $$i$$ and $$j$$ by appending _i or _j to the variable name, e.g., x_i and x_j.

register_propagate_forward_pre_hook(hook: Callable) → RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before propagate() is invoked. It should have the following signature:

hook(module, inputs) -> None or modified input


The hook can modify the input. Input keyword arguments are passed to the hook as a dictionary in inputs[-1].

Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().
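For example, a sketch of a pre-hook that inspects the keyword arguments passed to propagate() (conv stands for any existing MessagePassing layer):

def pre_hook(module, inputs):
    # inputs[-1] is the dictionary of keyword arguments passed to propagate()
    print('propagate() called with:', list(inputs[-1].keys()))

handle = conv.register_propagate_forward_pre_hook(pre_hook)
# ... run forward passes ...
handle.remove()    # detach the hook again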

register_propagate_forward_hook(hook: Callable) → RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after propagate() has computed an output. It should have the following signature:

hook(module, inputs, output) -> None or modified output


The hook can modify the output. Input keyword arguments are passed to the hook as a dictionary in inputs[-1].

Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().

register_message_forward_pre_hook(hook: Callable) → RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before message() is invoked. See register_propagate_forward_pre_hook() for more information.

register_message_forward_hook(hook: Callable) → RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after message() has computed an output. See register_propagate_forward_hook() for more information.

register_aggregate_forward_pre_hook(hook: Callable) → RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before aggregate() is invoked. See register_propagate_forward_pre_hook() for more information.

register_aggregate_forward_hook(hook: Callable) → RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after aggregate() has computed an output. See register_propagate_forward_hook() for more information.

register_message_and_aggregate_forward_pre_hook(hook: Callable) → RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before message_and_aggregate() is invoked. See register_propagate_forward_pre_hook() for more information.

register_message_and_aggregate_forward_hook(hook: Callable) → RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after message_and_aggregate() has computed an output. See register_propagate_forward_hook() for more information.

register_edge_update_forward_pre_hook(hook: Callable) → RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before edge_update() is invoked. See register_propagate_forward_pre_hook() for more information.

register_edge_update_forward_hook(hook: Callable) → RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after edge_update() has computed an output. See register_propagate_forward_hook() for more information.

jittable(typing=None) [source]

Analyzes the MessagePassing instance and produces a new jittable module.

Parameters

typing (string, optional) – If given, will generate a concrete instance with forward() types based on typing, e.g.: "(Tensor, Optional[Tensor]) -> Tensor".
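A short sketch of the intended usage, combining jittable() with TorchScript (channel sizes arbitrary):

import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(16, 32).jittable()    # analyze and clone into a jittable module
scripted = torch.jit.script(conv)    # the clone can now be scripted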

class GCNConv(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)[source]

The graph convolutional operator from the “Semi-supervised Classification with Graph Convolutional Networks” paper

$\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},$

where $$\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$$ denotes the adjacency matrix with inserted self-loops and $$\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}$$ its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Its node-wise formulation is given by:

$\mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j$

with $$\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}$$, where $$e_{j,i}$$ denotes the edge weight from source node j to target node i (default: 1.0)

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight=None) [source]
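A minimal usage sketch (node count, feature sizes and edge weights are arbitrary examples):

import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(16, 32)
x = torch.randn(4, 16)                                # 4 nodes, 16 features
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])             # 4 directed edges
edge_weight = torch.tensor([1.0, 1.0, 2.0, 2.0])      # optional edge weights
out = conv(x, edge_index, edge_weight)                # shape: [4, 32]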
class ChebConv(in_channels: int, out_channels: int, K: int, normalization='sym', bias: bool = True, **kwargs)[source]

The Chebyshev spectral graph convolutional operator from the “Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering” paper

$\mathbf{X}^{\prime} = \sum_{k=1}^{K} \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)}$

where $$\mathbf{Z}^{(k)}$$ is computed recursively by

\begin{align}\begin{aligned}\mathbf{Z}^{(1)} &= \mathbf{X}\\\mathbf{Z}^{(2)} &= \mathbf{\hat{L}} \cdot \mathbf{X}\\\mathbf{Z}^{(k)} &= 2 \cdot \mathbf{\hat{L}} \cdot \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)}\end{aligned}\end{align}

and $$\mathbf{\hat{L}}$$ denotes the scaled and normalized Laplacian $$\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}$$.

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• K (int) – Chebyshev filter size $$K$$.

• normalization (str, optional) –

The normalization scheme for the graph Laplacian (default: "sym"):

1. None: No normalization $$\mathbf{L} = \mathbf{D} - \mathbf{A}$$

2. "sym": Symmetric normalization $$\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}$$

3. "rw": Random-walk normalization $$\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}$$

lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional), batch vector $$(|\mathcal{V}|)$$ (optional), maximum lambda value $$(|\mathcal{G}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Tensor, edge_weight=None, batch=None, lambda_max=None) [source]
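A minimal sketch of pre-computing lambda_max with the LaplacianLambdaMax transform mentioned above (graph and sizes arbitrary; the transform's normalization is matched to the layer's default "sym"):

import torch
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.nn import ChebConv

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
data = Data(x=torch.randn(3, 16), edge_index=edge_index)
data = T.LaplacianLambdaMax(normalization='sym')(data)   # stores data.lambda_max

conv = ChebConv(16, 32, K=3)
out = conv(data.x, data.edge_index, lambda_max=data.lambda_max)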
class SAGEConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr='mean', normalize: bool = False, root_weight: bool = True, project: bool = False, bias: bool = True, **kwargs)[source]

The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper

$\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j$

If project = True, then $$\mathbf{x}_j$$ will first get projected via

$\mathbf{x}_j \leftarrow \sigma ( \mathbf{W}_3 \mathbf{x}_j + \mathbf{b})$

as described in Eq. (3) of the paper.

Parameters
Shapes:
• inputs: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• outputs: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V_t}|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None) [source]
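A minimal usage sketch (sizes and the random graph are arbitrary); with project=True, neighbor features are first passed through the learned projection of Eq. (3):

import torch
from torch_geometric.nn import SAGEConv

conv = SAGEConv(16, 32, project=True)
x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))   # random edges, for shape checking only
out = conv(x, edge_index)                    # shape: [10, 32]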
class GraphConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]

The graph neural network operator from the “Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks” paper

$\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j$

where $$e_{j,i}$$ denotes the edge weight from source node j to target node i (default: 1)

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor], edge_weight=None, size: Optional[Tuple[int, int]] = None) [source]
class GravNetConv(in_channels: int, out_channels: int, space_dimensions: int, propagate_dimensions: int, k: int, num_workers=None, **kwargs)[source]

The GravNet operator from the “Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks” paper, where the graph is dynamically constructed using nearest neighbors. The neighbors are constructed in a learnable low-dimensional projection of the feature space. A second projection of the input feature space is then propagated from the neighbors to each vertex using distance weights that are derived by applying a Gaussian function to the distances.

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – The number of output channels.

• space_dimensions (int) – The dimensionality of the space used to construct the neighbors; referred to as $$S$$ in the paper.

• propagate_dimensions (int) – The number of features to be propagated between the vertices; referred to as $$F_{\textrm{LR}}$$ in the paper.

• k (int) – The number of nearest neighbors.

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))$$ if bipartite, batch vector $$(|\mathcal{V}|)$$ or $$((|\mathcal{V}_s|), (|\mathcal{V}_t|))$$ if bipartite (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x, batch=None) [source]
class GatedGraphConv(out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]

The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper

\begin{align}\begin{aligned}\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}\\\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}\\\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)})\end{aligned}\end{align}

up to representation $$\mathbf{h}_i^{(L)}$$. The number of input channels of $$\mathbf{x}_i$$ needs to be less than or equal to out_channels. $$e_{j,i}$$ denotes the edge weight from source node j to target node i (default: 1)

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight=None) [source]
class ResGatedGraphConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, act=Sigmoid(), root_weight: bool = True, bias: bool = True, **kwargs)[source]

The residual gated graph convolutional operator from the “Residual Gated Graph ConvNets” paper

$\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \eta_{i,j} \odot \mathbf{W}_2 \mathbf{x}_j$

where the gate $$\eta_{i,j}$$ is defined as

$\eta_{i,j} = \sigma(\mathbf{W}_3 \mathbf{x}_i + \mathbf{W}_4 \mathbf{x}_j)$

with $$\sigma$$ denoting the sigmoid function.

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• act (callable, optional) – Gating function $$\sigma$$. (default: torch.nn.Sigmoid())

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• inputs: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• outputs: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V_t}|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor]) [source]
class GATConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, edge_dim=None, fill_value='mean', bias: bool = True, **kwargs)[source]

The graph attentional operator from the “Graph Attention Networks” paper

$\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},$

where the attention coefficients $$\alpha_{i,j}$$ are computed as

$\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] \right)\right)}.$

If the graph has multi-dimensional edge features $$\mathbf{e}_{i,j}$$, the attention coefficients $$\alpha_{i,j}$$ are computed as

$\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,j}]\right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,k}]\right)\right)}.$
Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

• negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

• dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

• fill_value (float or Tensor or str, optional) – The way to generate edge features of self-loops (in case edge_dim != None). If given as float or torch.Tensor, edge features of self-loops will be directly given by fill_value. If given as str, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation. ("add", "mean", "min", "max", "mul"). (default: "mean")

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, H * F_{out})$$ or $$(|\mathcal{V}_t|, H * F_{out})$$ if bipartite. If return_attention_weights=True, then $$((|\mathcal{V}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))$$ or $$((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))$$ if bipartite

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor], edge_attr=None, size: Optional[Tuple[int, int]] = None, return_attention_weights=None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)
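A usage sketch for return_attention_weights (sizes and the random graph are arbitrary); with heads=4 and concat=True, head outputs are concatenated:

import torch
from torch_geometric.nn import GATConv

conv = GATConv(16, 32, heads=4)
x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))

out, (att_edge_index, alpha) = conv(x, edge_index, return_attention_weights=True)
# out:   [10, 4 * 32]   (heads concatenated)
# alpha: attention coefficients per edge (including added self-loops) and per head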

edge_update(alpha_j: Tensor, alpha_i, edge_attr, index: Tensor, ptr, size_i) [source]

Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes $$i$$ and $$j$$ by appending _i or _j to the variable name, e.g., x_i and x_j.

class FusedGATConv(*args, **kwargs)[source]

The fused graph attention operator from the “Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective” paper.

FusedGATConv is an optimized version of GATConv that fuses message passing computation for accelerated execution and a lower memory footprint.

Note

This implementation is based on the dgNN package; see its documentation for installation instructions.

static to_graph_format(edge_index: Tensor, size: Optional[Tuple[int, int]] = None) [source]
Parameters
forward(x: Tensor, csr, csc, perm: Tensor) [source]
Parameters
• x (Tensor) – The node features.

• csr ((Tensor, Tensor)) – A tuple containing the CSR representation of a graph, given as a tuple of (rowptr, col).

• csc ((Tensor, Tensor)) – A tuple containing the CSC representation of a graph, given as a tuple of (row, colptr).

• perm (Tensor) – Permutation tensor to map the CSR representation to the CSC representation.

Note

Use the torch_geometric.nn.conv.FusedGATConv.to_graph_format() method to obtain the (csr, csc, perm) graph format from an existing edge_index representation.
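A hedged sketch of the intended workflow, assuming to_graph_format() returns the (csr, csc, perm) triple described in the note above and that FusedGATConv accepts the same channel/head arguments as GATConv; the dgNN backend requires a CUDA device:

import torch
from torch_geometric.nn import FusedGATConv

x = torch.randn(10, 16).cuda()
edge_index = torch.randint(0, 10, (2, 40)).cuda()

# Convert the COO edge_index into the fused graph format once up front
# (assumed to return the (csr, csc, perm) triple):
csr, csc, perm = FusedGATConv.to_graph_format(edge_index, size=(10, 10))

conv = FusedGATConv(16, 32, heads=4).cuda()
out = conv(x, csr, csc, perm)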

class GATv2Conv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, edge_dim=None, fill_value='mean', bias: bool = True, share_weights: bool = False, **kwargs)[source]

The GATv2 operator from the “How Attentive are Graph Attention Networks?” paper, which fixes the static attention problem of the standard GATConv layer. Since the linear layers in the standard GAT are applied right after each other, the ranking of attended nodes is unconditioned on the query node. In contrast, in GATv2, every node can attend to any other node.

$\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},$

where the attention coefficients $$\alpha_{i,j}$$ are computed as

$\alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_k] \right)\right)}.$

If the graph has multi-dimensional edge features $$\mathbf{e}_{i,j}$$, the attention coefficients $$\alpha_{i,j}$$ are computed as

$\alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_j \, \Vert \, \mathbf{e}_{i,j}] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_k \, \Vert \, \mathbf{e}_{i,k}] \right)\right)}.$
Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

• negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

• dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

• fill_value (float or Tensor or str, optional) – The way to generate edge features of self-loops (in case edge_dim != None). If given as float or torch.Tensor, edge features of self-loops will be directly given by fill_value. If given as str, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation. ("add", "mean", "min", "max", "mul"). (default: "mean")

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• share_weights (bool, optional) – If set to True, the same matrix will be applied to the source and the target node of every edge. (default: False)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, H * F_{out})$$ or $$(|\mathcal{V}_t|, H * F_{out})$$ if bipartite. If return_attention_weights=True, then $$((|\mathcal{V}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))$$ or $$((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))$$ if bipartite

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor], edge_attr=None, return_attention_weights=None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)

class TransformerConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, beta: bool = False, dropout: float = 0.0, edge_dim=None, bias: bool = True, root_weight: bool = True, **kwargs)[source]

The graph transformer operator from the “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification” paper

$\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},$

where the attention coefficients $$\alpha_{i,j}$$ are computed via multi-head dot product attention:

$\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)} {\sqrt{d}} \right)$
Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

• beta (bool, optional) –

If set, will combine aggregation and skip information via

$\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}$

with $$\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top} [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i - \mathbf{m}_i ])$$ (default: False)

• dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

• edge_dim (int, optional) –

Edge feature dimensionality (in case there are any). Edge features are added to the keys after linear transformation, that is, prior to computing the attention dot product. They are also added to final values after the same linear transformation. The model is:

$\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left( \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij} \right),$

where the attention coefficients $$\alpha_{i,j}$$ are now computed via:

$\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})} {\sqrt{d}} \right)$

(default: None)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• root_weight (bool, optional) – If set to False, the layer will not add the transformed root node features to the output and the option beta is set to False. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor], edge_attr=None, return_attention_weights=None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)

class AGNNConv(requires_grad: bool = True, add_self_loops: bool = True, **kwargs)[source]

The graph attentional propagation layer from the “Attention-based Graph Neural Network for Semi-Supervised Learning” paper

$\mathbf{X}^{\prime} = \mathbf{P} \mathbf{X},$

where the propagation matrix $$\mathbf{P}$$ is computed as

$P_{i,j} = \frac{\exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_j))} {\sum_{k \in \mathcal{N}(i)\cup \{ i \}} \exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_k))}$

with trainable parameter $$\beta$$.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F)$$, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F)$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) [source]
class TAGConv(in_channels: int, out_channels: int, K: int = 3, bias: bool = True, normalize: bool = True, **kwargs)[source]

The topology adaptive graph convolutional networks operator from the “Topology Adaptive Graph Convolutional Networks” paper

$\mathbf{X}^{\prime} = \sum_{k=0}^K \left( \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \right)^k \mathbf{X} \mathbf{W}_{k},$

where $$\mathbf{A}$$ denotes the adjacency matrix and $$D_{ii} = \sum_{j=0} A_{ij}$$ its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Parameters
Shapes:
• input: node_features $$(|\mathcal{V}|, F_{in})$$, edge_index $$(2, |\mathcal{E}|)$$, edge_weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight=None) [source]
class GINConv(nn: Callable, eps: float = 0.0, train_eps: bool = False, **kwargs)[source]

The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper

$\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)$

or

$\mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),$

where $$h_{\mathbf{\Theta}}$$ denotes a neural network, i.e. an MLP.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None) [source]
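A minimal usage sketch; $$h_{\mathbf{\Theta}}$$ is given as a small MLP (channel sizes and the random graph are arbitrary):

import torch
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import GINConv

mlp = Sequential(Linear(16, 32), ReLU(), Linear(32, 32))
conv = GINConv(mlp, train_eps=True)

x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))
out = conv(x, edge_index)    # shape: [10, 32]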
class GINEConv(nn: Module, eps: float = 0.0, train_eps: bool = False, edge_dim=None, **kwargs)[source]

The modified GINConv operator from the “Strategies for Pre-training Graph Neural Networks” paper

$\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU} ( \mathbf{x}_j + \mathbf{e}_{j,i} ) \right)$

that is able to incorporate edge features $$\mathbf{e}_{j,i}$$ into the aggregation procedure.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor], edge_attr=None, size: Optional[Tuple[int, int]] = None) [source]
class ARMAConv(in_channels: int, out_channels: int, num_stacks: int = 1, num_layers: int = 1, shared_weights: bool = False, act=ReLU(), dropout: float = 0.0, bias: bool = True, **kwargs)[source]

The ARMA graph convolutional operator from the “Graph Neural Networks with Convolutional ARMA Filters” paper

$\mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K \mathbf{X}_k^{(T)},$

with $$\mathbf{X}_k^{(T)}$$ being recursively defined by

$\mathbf{X}_k^{(t+1)} = \sigma \left( \mathbf{\hat{L}} \mathbf{X}_k^{(t)} \mathbf{W} + \mathbf{X}^{(0)} \mathbf{V} \right),$

where $$\mathbf{\hat{L}} = \mathbf{I} - \mathbf{L} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}$$ denotes the modified Laplacian, with $$\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}$$ the normalized graph Laplacian.

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample $$\mathbf{x}^{(t+1)}$$.

• num_stacks (int, optional) – Number of parallel stacks $$K$$. (default: 1).

• num_layers (int, optional) – Number of layers $$T$$. (default: 1)

• act (callable, optional) – Activation function $$\sigma$$. (default: torch.nn.ReLU())

• shared_weights (bool, optional) – If set to True, the layers in each stack will share the same parameters. (default: False)

• dropout (float, optional) – Dropout probability of the skip connection. (default: 0.)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight=None) [source]
class SGConv(in_channels: int, out_channels: int, K: int = 1, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The simple graph convolutional operator from the “Simplifying Graph Convolutional Networks” paper

$\mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta},$

where $$\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$$ denotes the adjacency matrix with inserted self-loops and $$\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}$$ its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• K (int, optional) – Number of hops $$K$$. (default: 1)

• cached (bool, optional) – If set to True, the layer will cache the computation of $${\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X}$$ on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight=None) [source]
class SSGConv(in_channels: int, out_channels: int, alpha: float, K: int = 1, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The simple spectral graph convolutional operator from the “Simple Spectral Graph Convolution” paper

$\mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K\left((1-\alpha) {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^k \mathbf{X}+\alpha \mathbf{X}\right) \mathbf{\Theta},$

where $$\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$$ denotes the adjacency matrix with inserted self-loops and $$\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}$$ its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor. SSGConv improves upon SGConv by introducing the alpha parameter to address the oversmoothing issue.

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• alpha (float) – Teleport probability $$\alpha \in [0, 1]$$.

• K (int, optional) – Number of hops $$K$$. (default: 1)

• cached (bool, optional) – If set to True, the layer will cache the computation of $$\frac{1}{K} \sum_{k=1}^K\left((1-\alpha) {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^k \mathbf{X}+ \alpha \mathbf{X}\right)$$ on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight=None) [source]
class APPNP(K: int, alpha: float, dropout: float = 0.0, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs)[source]

The approximate personalized propagation of neural predictions layer from the “Predict then Propagate: Graph Neural Networks meet Personalized PageRank” paper

\begin{align}\begin{aligned}\mathbf{X}^{(0)} &= \mathbf{X}\\\mathbf{X}^{(k)} &= (1 - \alpha) \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X}^{(k-1)} + \alpha \mathbf{X}^{(0)}\\\mathbf{X}^{\prime} &= \mathbf{X}^{(K)},\end{aligned}\end{align}

where $$\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$$ denotes the adjacency matrix with inserted self-loops and $$\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}$$ its diagonal degree matrix. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F)$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F)$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight=None) [source]
class MFConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, max_degree: int = 10, bias=True, **kwargs)[source]

The graph neural network operator from the “Convolutional Networks on Graphs for Learning Molecular Fingerprints” paper

$\mathbf{x}^{\prime}_i = \mathbf{W}^{(\deg(i))}_1 \mathbf{x}_i + \mathbf{W}^{(\deg(i))}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j$

which trains a distinct weight matrix for each possible vertex degree.

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• max_degree (int, optional) – The maximum node degree to consider when updating weights (default: 10)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• inputs: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• outputs: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V_t}|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x, edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None) [source]
class RGCNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int, num_bases=None, num_blocks=None, aggr: str = 'mean', root_weight: bool = True, is_sorted: bool = False, bias: bool = True, **kwargs)[source]

The relational graph convolutional operator from the “Modeling Relational Data with Graph Convolutional Networks” paper

$\mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,$

where $$\mathcal{R}$$ denotes the set of relations, i.e. edge types. Edge type needs to be a one-dimensional torch.long tensor which stores a relation identifier $$\in \{ 0, \ldots, |\mathcal{R}| - 1\}$$ for each edge.

Note

This implementation is as memory-efficient as possible by iterating over each individual relation type. Therefore, it may result in low GPU utilization in case the graph has a large number of relations. As an alternative approach, FastRGCNConv does not iterate over each individual type, but may consume a large amount of memory to compensate. We advise checking out both implementations to see which one fits your needs.

Parameters
• in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities. In case no input features are given, this argument should correspond to the number of nodes in your graph.

• out_channels (int) – Size of each output sample.

• num_relations (int) – Number of relations.

• num_bases (int, optional) – If set, this layer will use the basis-decomposition regularization scheme where num_bases denotes the number of bases to use. (default: None)

• num_blocks (int, optional) – If set, this layer will use the block-diagonal-decomposition regularization scheme where num_blocks denotes the number of blocks to use. (default: None)

• aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "mean")

• root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

• is_sorted (bool, optional) – If set to True, assumes that edge_index is sorted by edge_type. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: False)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_type: = None)[source]
Parameters
• x – The input node features. Can be either a [num_nodes, in_channels] node feature matrix, or an optional one-dimensional node index tensor (in which case input features are treated as trainable node embeddings). Furthermore, x can be of type tuple denoting source and destination node features.

• edge_index (LongTensor or SparseTensor) – The edge indices.

• edge_type – The one-dimensional relation type/index for each edge in edge_index. Should be only None in case edge_index is of type torch_sparse.tensor.SparseTensor. (default: None)
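
A minimal usage sketch (node counts, feature sizes and relation assignments are illustrative):

import torch
from torch_geometric.nn import RGCNConv

x = torch.randn(5, 16)                           # node features
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])
edge_type = torch.tensor([0, 1, 0, 1])           # relation identifier per edge

conv = RGCNConv(16, 32, num_relations=2, num_bases=2)
out = conv(x, edge_index, edge_type)             # shape [5, 32]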

class FastRGCNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int, num_bases: = None, num_blocks: = None, aggr: str = 'mean', root_weight: bool = True, is_sorted: bool = False, bias: bool = True, **kwargs)[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_type: = None)[source]
class RGATConv(in_channels: int, out_channels: int, num_relations: int, num_bases: = None, num_blocks: = None, mod: = None, attention_mechanism: str = 'across-relation', attention_mode: str = 'additive-self-attention', heads: int = 1, dim: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, edge_dim: = None, bias: bool = True, **kwargs)[source]

The relational graph attentional operator from the “Relational Graph Attention Networks” paper. Here, attention logits $$\mathbf{a}^{(r)}_{i,j}$$ are computed for each relation type $$r$$ with the help of both query and key kernels, i.e.

$\mathbf{q}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot \mathbf{Q}^{(r)} \quad \textrm{and} \quad \mathbf{k}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot \mathbf{K}^{(r)}.$

Two schemes have been proposed to compute attention logits $$\mathbf{a}^{(r)}_{i,j}$$ for each relation type $$r$$: the additive attention

$\mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + \mathbf{k}^{(r)}_j)$

or the multiplicative attention

$\mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j.$

If the graph has multi-dimensional edge features $$\mathbf{e}^{(r)}_{i,j}$$, the attention logits $$\mathbf{a}^{(r)}_{i,j}$$ for each relation type $$r$$ are computed as

$\mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + \mathbf{k}^{(r)}_j + \mathbf{W}_2^{(r)}\mathbf{e}^{(r)}_{i,j})$

or

$\mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j \cdot \mathbf{W}_2^{(r)} \mathbf{e}^{(r)}_{i,j},$

respectively. The attention coefficients $$\alpha^{(r)}_{i,j}$$ for each relation type $$r$$ are then obtained via two different attention mechanisms: The within-relation attention mechanism

$\alpha^{(r)}_{i,j} = \frac{\exp(\mathbf{a}^{(r)}_{i,j})} {\sum_{k \in \mathcal{N}_r(i)} \exp(\mathbf{a}^{(r)}_{i,k})}$

or the across-relation attention mechanism

$\alpha^{(r)}_{i,j} = \frac{\exp(\mathbf{a}^{(r)}_{i,j})} {\sum_{r^{\prime} \in \mathcal{R}} \sum_{k \in \mathcal{N}_{r^{\prime}}(i)} \exp(\mathbf{a}^{(r^{\prime})}_{i,k})}$

where $$\mathcal{R}$$ denotes the set of relations, i.e. edge types. Edge type needs to be a one-dimensional torch.long tensor which stores a relation identifier $$\in \{ 0, \ldots, |\mathcal{R}| - 1\}$$ for each edge.

To enhance the discriminative power of attention-based GNNs, this layer further implements four different cardinality preservation options as proposed in the “Improving Attention Mechanism in Graph Neural Networks via Cardinality Preservation” paper:

\begin{align}\begin{aligned}\text{additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + \mathcal{W} \odot \sum_{j \in \mathcal{N}_r(i)} \mathbf{x}^{(r)}_j\\\text{scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= \psi(|\mathcal{N}_r(i)|) \odot \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j\\\text{f-additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= \sum_{j \in \mathcal{N}_r(i)} (\alpha^{(r)}_{i,j} + 1) \cdot \mathbf{x}^{(r)}_j\\\text{f-scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= |\mathcal{N}_r(i)| \odot \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j\end{aligned}\end{align}
• If attention_mode="additive-self-attention" and concat=True, the layer outputs heads * out_channels features for each node.

• If attention_mode="multiplicative-self-attention" and concat=True, the layer outputs heads * dim * out_channels features for each node.

• If attention_mode="additive-self-attention" and concat=False, the layer outputs out_channels features for each node.

• If attention_mode="multiplicative-self-attention" and concat=False, the layer outputs dim * out_channels features for each node.

Please make sure to set the in_channels argument of the next layer accordingly if more than one instance of this layer is used.

Note

For an example of using RGATConv, see examples/rgat.py.

Parameters
• in_channels (int) – Size of each input sample.

• out_channels (int) – Size of each output sample.

• num_relations (int) – Number of relations.

• num_bases (int, optional) – If set, this layer will use the basis-decomposition regularization scheme where num_bases denotes the number of bases to use. (default: None)

• num_blocks (int, optional) – If set, this layer will use the block-diagonal-decomposition regularization scheme where num_blocks denotes the number of blocks to use. (default: None)

• mod (str, optional) – The cardinality preservation option to use. ("additive", "scaled", "f-additive", "f-scaled", None). (default: None)

• attention_mechanism (str, optional) – The attention mechanism to use ("within-relation", "across-relation"). (default: "across-relation")

• attention_mode (str, optional) – The mode to calculate attention logits. ("additive-self-attention", "multiplicative-self-attention"). (default: "additive-self-attention")

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• dim (int) – Number of dimensions for query and key kernels. (default: 1)

• concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

• negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

• dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

• edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_type: = None, edge_attr: = None, size: Optional[Tuple[int, int]] = None, return_attention_weights=None)[source]
Parameters
• x (Tensor) – The input node features. Can be either a [num_nodes, in_channels] node feature matrix, or an optional one-dimensional node index tensor (in which case input features are treated as trainable node embeddings).

• edge_index (LongTensor or SparseTensor) – The edge indices.

• edge_type – The one-dimensional relation type/index for each edge in edge_index. Should be only None in case edge_index is of type torch_sparse.tensor.SparseTensor. (default: None)

• edge_attr (Tensor, optional) – Edge feature matrix. (default: None)

• return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)

class SignedConv(in_channels: int, out_channels: int, first_aggr: bool, bias: bool = True, **kwargs)[source]

The signed graph convolutional operator from the “Signed Graph Convolutional Network” paper

\begin{align}\begin{aligned}\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w , \mathbf{x}_v \right]\\\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})} \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w , \mathbf{x}_v \right]\end{aligned}\end{align}

if first_aggr is set to True, and

\begin{align}\begin{aligned}\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})}, \mathbf{x}_v^{(\textrm{pos})} \right]\\\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})}, \mathbf{x}_v^{(\textrm{neg})} \right]\end{aligned}\end{align}

otherwise. In case first_aggr is False, the layer expects x to be a tensor where x[:, :in_channels] denotes the positive node features $$\mathbf{X}^{(\textrm{pos})}$$ and x[:, in_channels:] denotes the negative node features $$\mathbf{X}^{(\textrm{neg})}$$.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))$$ if bipartite, positive edge indices $$(2, |\mathcal{E}^{(+)}|)$$, negative edge indices $$(2, |\mathcal{E}^{(-)}|)$$

• outputs: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V_t}|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , pos_edge_index: Union[Tensor, SparseTensor], neg_edge_index: Union[Tensor, SparseTensor])[source]
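
A minimal sketch of stacking two SignedConv layers (tensor shapes are illustrative). Note that each layer concatenates positive and negative representations, so subsequent layers receive 2 * out_channels input features:

import torch
from torch_geometric.nn import SignedConv

x = torch.randn(6, 16)
pos_edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # positive edges
neg_edge_index = torch.tensor([[3, 4], [4, 5]])        # negative edges

conv1 = SignedConv(16, 32, first_aggr=True)
h = conv1(x, pos_edge_index, neg_edge_index)           # shape [6, 2 * 32]

conv2 = SignedConv(32, 32, first_aggr=False)           # expects concatenated pos/neg features
h = conv2(h, pos_edge_index, neg_edge_index)           # shape [6, 2 * 32]
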
class DNAConv(channels: int, heads: int = 1, groups: int = 1, dropout: float = 0.0, cached: bool = False, normalize: bool = True, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The dynamic neighborhood aggregation operator from the “Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks” paper

$\mathbf{x}_v^{(t)} = h_{\mathbf{\Theta}}^{(t)} \left( \mathbf{x}_{v \leftarrow v}^{(t)}, \left\{ \mathbf{x}_{v \leftarrow w}^{(t)} : w \in \mathcal{N}(v) \right\} \right)$

$\mathbf{x}_{v \leftarrow w}^{(t)} = \textrm{Attention} \left( \mathbf{x}^{(t-1)}_v \, \mathbf{\Theta}_Q^{(t)}, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_K^{(t)}, \, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_V^{(t)} \right)$

with $$\mathbf{\Theta}_Q^{(t)}, \mathbf{\Theta}_K^{(t)}, \mathbf{\Theta}_V^{(t)}$$ denoting (grouped) projection matrices for query, key and value information, respectively. $$h^{(t)}_{\mathbf{\Theta}}$$ is implemented as a non-trainable version of torch_geometric.nn.conv.GCNConv.

Note

In contrast to other layers, this operator expects node features of shape [num_nodes, num_layers, channels].

Parameters
• channels (int) – Size of each input/output sample.

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• groups (int, optional) – Number of groups to use for all linear projections. (default: 1)

• dropout (float, optional) – Dropout probability of attention coefficients. (default: 0.)

• cached (bool, optional) – If set to True, the layer will cache the computation of $$\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}$$ on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

• normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, L, F)$$ where $$L$$ is the number of layers, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F)$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: = None) [source]
Parameters

x – The input node features of shape [num_nodes, num_layers, channels].
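
A minimal usage sketch (sizes are illustrative); the second dimension of x holds the representations of all previous layers:

import torch
from torch_geometric.nn import DNAConv

num_nodes, num_layers, channels = 4, 3, 32
x = torch.randn(num_nodes, num_layers, channels)  # representations from all previous layers
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

conv = DNAConv(channels, heads=4, groups=8)       # channels should be divisible by heads and groups
out = conv(x, edge_index)                         # shape [4, 32]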

class PointNetConv(local_nn: = None, global_nn: = None, add_self_loops: bool = True, **kwargs)[source]

The PointNet set layer from the “PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation” and “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” papers

$\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \mathbf{p}_j - \mathbf{p}_i) \right),$

where $$\gamma_{\mathbf{\Theta}}$$ and $$h_{\mathbf{\Theta}}$$ denote neural networks, i.e. MLPs, and $$\mathbf{P} \in \mathbb{R}^{N \times D}$$ defines the position of each point.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, positions $$(|\mathcal{V}|, 3)$$ or $$((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , pos: , edge_index: Union[Tensor, SparseTensor]) [source]
PointConv

alias of PointNetConv
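
A minimal usage sketch (shapes and MLP sizes are illustrative); local_nn operates on the concatenation of neighbor features and relative positions, i.e. on in_channels + 3 inputs for 3D point clouds:

import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import PointNetConv

x = torch.randn(10, 16)                          # point features
pos = torch.rand(10, 3)                          # 3D point positions
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

local_nn = Sequential(Linear(16 + 3, 32), ReLU(), Linear(32, 32))
conv = PointNetConv(local_nn=local_nn)
out = conv(x, pos, edge_index)                   # shape [10, 32]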

class GMMConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, dim: int, kernel_size: int, separate_gaussians: bool = False, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The gaussian mixture model convolutional operator from the “Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs” paper

$\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \frac{1}{K} \sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{i,j}) \odot \mathbf{\Theta}_k \mathbf{x}_j,$

where

$\mathbf{w}_k(\mathbf{e}) = \exp \left( -\frac{1}{2} {\left( \mathbf{e} - \mathbf{\mu}_k \right)}^{\top} \Sigma_k^{-1} \left( \mathbf{e} - \mathbf{\mu}_k \right) \right)$

denotes a weighting function based on trainable mean vector $$\mathbf{\mu}_k$$ and diagonal covariance matrix $$\mathbf{\Sigma}_k$$.

Note

The edge attribute $$\mathbf{e}_{ij}$$ is usually given by $$\mathbf{e}_{ij} = \mathbf{p}_j - \mathbf{p}_i$$, where $$\mathbf{p}_i$$ denotes the position of node $$i$$ (see torch_geometric.transforms.Cartesian).

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• dim (int) – Pseudo-coordinate dimensionality.

• kernel_size (int) – Number of kernels $$K$$.

• separate_gaussians (bool, optional) – If set to True, will learn separate GMMs for every pair of input and output channel, inspired by traditional CNNs. (default: False)

• aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "mean")

• root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_attr: = None, size: Optional[Tuple[int, int]] = None)[source]
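
A minimal usage sketch (shapes are illustrative); edge_attr holds dim-dimensional pseudo-coordinates, e.g. as produced by torch_geometric.transforms.Cartesian:

import torch
from torch_geometric.nn import GMMConv

x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = torch.rand(edge_index.size(1), 2)    # 2-dimensional pseudo-coordinates per edge

conv = GMMConv(16, 32, dim=2, kernel_size=5)
out = conv(x, edge_index, edge_attr)             # shape [6, 32]
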
class SplineConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, dim: int, kernel_size: Union[int, List[int]], is_open_spline: bool = True, degree: int = 1, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The spline-based convolutional operator from the “SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels” paper

$\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),$

where $$h_{\mathbf{\Theta}}$$ denotes a kernel function defined over the weighted B-Spline tensor product basis.

Note

Pseudo-coordinates must lie in the fixed interval $$[0, 1]$$ for this method to work as intended.

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• dim (int) – Pseudo-coordinate dimensionality.

• kernel_size (int or [int]) – Size of the convolving kernel.

• is_open_spline (bool or [bool], optional) – If set to False, the operator will use a closed B-spline basis in this dimension. (default: True)

• degree (int, optional) – B-spline basis degrees. (default: 1)

• aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "mean")

• root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_attr: = None, size: Optional[Tuple[int, int]] = None) [source]
class NNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, nn: Callable, aggr: str = 'add', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The continuous kernel-based convolutional operator from the “Neural Message Passing for Quantum Chemistry” paper. This convolution is also known as the edge-conditioned convolution from the “Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs” paper (see torch_geometric.nn.conv.ECConv for an alias):

$\mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),$

where $$h_{\mathbf{\Theta}}$$ denotes a neural network, i.e. an MLP.

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• nn (torch.nn.Module) – A neural network $$h_{\mathbf{\Theta}}$$ that maps edge features edge_attr of shape [-1, num_edge_features] to shape [-1, in_channels * out_channels], e.g., defined by torch.nn.Sequential.

• aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

• root_weight (bool, optional) – If set to False, the layer will not add the transformed root node features to the output. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_attr: = None, size: Optional[Tuple[int, int]] = None) [source]
ECConv

alias of NNConv
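
A minimal usage sketch (shapes and the hidden size of the edge network are illustrative); the edge network must map each edge feature vector to a flattened in_channels * out_channels weight matrix:

import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import NNConv

x = torch.randn(5, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = torch.randn(edge_index.size(1), 4)   # 4-dimensional edge features

edge_nn = Sequential(Linear(4, 32), ReLU(), Linear(32, 8 * 16))  # maps edge features to 8 * 16 weights
conv = NNConv(8, 16, nn=edge_nn, aggr='mean')
out = conv(x, edge_index, edge_attr)             # shape [5, 16]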

class CGConv(channels: Union[int, Tuple[int, int]], dim: int = 0, aggr: str = 'add', batch_norm: bool = False, bias: bool = True, **kwargs)[source]

The crystal graph convolutional operator from the “Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties” paper

$\mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right)$

where $$\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]$$ denotes the concatenation of central node features, neighboring node features and edge features. In addition, $$\sigma$$ and $$g$$ denote the sigmoid and softplus functions, respectively.

Parameters
• channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

• dim (int, optional) – Edge feature dimensionality. (default: 0)

• aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "add")

• batch_norm (bool, optional) – If set to True, will make use of batch normalization. (default: False)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F)$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F)$$ or $$(|\mathcal{V_t}|, F_{t})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_attr: = None) [source]
class EdgeConv(nn: Callable, aggr: str = 'max', **kwargs)[source]

The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper

$\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i),$

where $$h_{\mathbf{\Theta}}$$ denotes a neural network, i.e. an MLP.

Parameters
• nn (torch.nn.Module) – A neural network $$h_{\mathbf{\Theta}}$$ that maps pair-wise concatenated node features x of shape [-1, 2 * in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

• aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "max")

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor]) [source]
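
A minimal usage sketch (shapes are illustrative); the network $$h_{\mathbf{\Theta}}$$ receives the concatenation $$[\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i]$$ of size 2 * in_channels:

import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import EdgeConv

x = torch.randn(10, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

edge_nn = Sequential(Linear(2 * 16, 32), ReLU(), Linear(32, 32))
conv = EdgeConv(nn=edge_nn, aggr='max')
out = conv(x, edge_index)                        # shape [10, 32]
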
class DynamicEdgeConv(nn: Callable, k: int, aggr: str = 'max', num_workers: int = 1, **kwargs)[source]

The dynamic edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper (see torch_geometric.nn.conv.EdgeConv), where the graph is dynamically constructed using nearest neighbors in the feature space.

Parameters
• nn (torch.nn.Module) – A neural network $$h_{\mathbf{\Theta}}$$ that maps pair-wise concatenated node features x of shape [-1, 2 * in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

• k (int) – Number of nearest neighbors.

• aggr (string) – The aggregation operator to use ("add", "mean", "max"). (default: "max")

• num_workers (int) – Number of workers to use for k-NN computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))$$ if bipartite, batch vector $$(|\mathcal{V}|)$$ or $$((|\mathcal{V}|), (|\mathcal{V}|))$$ if bipartite (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , batch: = None) [source]
class XConv(in_channels: int, out_channels: int, dim: int, kernel_size: int, hidden_channels: = None, dilation: int = 1, bias: bool = True, num_workers: int = 1)[source]

The convolutional operator on $$\mathcal{X}$$-transformed points from the “PointCNN: Convolution On X-Transformed Points” paper

$\mathbf{x}^{\prime}_i = \mathrm{Conv}\left(\mathbf{K}, \gamma_{\mathbf{\Theta}}(\mathbf{P}_i - \mathbf{p}_i) \times \left( h_\mathbf{\Theta}(\mathbf{P}_i - \mathbf{p}_i) \, \Vert \, \mathbf{x}_i \right) \right),$

where $$\mathbf{K}$$ and $$\mathbf{P}_i$$ denote the trainable filter and neighboring point positions of $$\mathbf{x}_i$$, respectively. $$\gamma_{\mathbf{\Theta}}$$ and $$h_{\mathbf{\Theta}}$$ describe neural networks, i.e. MLPs, where $$h_{\mathbf{\Theta}}$$ individually lifts each point into a higher-dimensional space, and $$\gamma_{\mathbf{\Theta}}$$ computes the $$\mathcal{X}$$-transformation matrix based on all points in a neighborhood.

Parameters
• in_channels (int) – Size of each input sample.

• out_channels (int) – Size of each output sample.

• dim (int) – Point cloud dimensionality.

• kernel_size (int) – Size of the convolving kernel, i.e. number of neighbors including self-loops.

• hidden_channels (int, optional) – Output size of $$h_{\mathbf{\Theta}}$$, i.e. dimensionality of lifted points. If set to None, will be automatically set to in_channels / 4. (default: None)

• dilation (int, optional) – The factor by which the neighborhood is extended, from which kernel_size neighbors are then uniformly sampled. Can be interpreted as the dilation rate of classical convolutional operators. (default: 1)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• num_workers (int) – Number of workers to use for k-NN computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, positions $$(|\mathcal{V}|, D)$$, batch vector $$(|\mathcal{V}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, pos: Tensor, batch: = None)[source]
class PPFConv(local_nn: = None, global_nn: = None, add_self_loops: bool = True, **kwargs)[source]

The PPFNet operator from the “PPFNet: Global Context Aware Local Features for Robust 3D Point Matching” paper

$\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \| \mathbf{d_{j,i}} \|, \angle(\mathbf{n}_i, \mathbf{d_{j,i}}), \angle(\mathbf{n}_j, \mathbf{d_{j,i}}), \angle(\mathbf{n}_i, \mathbf{n}_j) \right)$

where $$\gamma_{\mathbf{\Theta}}$$ and $$h_{\mathbf{\Theta}}$$ denote neural networks, i.e. MLPs, which take in node features and torch_geometric.transforms.PointPairFeatures.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, positions $$(|\mathcal{V}|, 3)$$ or $$((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))$$ if bipartite, point normals $$(|\mathcal{V}|, 3)$$ or $$((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , pos: , normal: , edge_index: Union[Tensor, SparseTensor]) [source]
class FeaStConv(in_channels: int, out_channels: int, heads: int = 1, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The (translation-invariant) feature-steered convolutional operator from the “FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis” paper

$\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \sum_{h=1}^H q_h(\mathbf{x}_i, \mathbf{x}_j) \mathbf{W}_h \mathbf{x}_j$

with $$q_h(\mathbf{x}_i, \mathbf{x}_j) = \mathrm{softmax}_j (\mathbf{u}_h^{\top} (\mathbf{x}_j - \mathbf{x}_i) + c_h)$$, where $$H$$ denotes the number of attention heads, and $$\mathbf{W}_h$$, $$\mathbf{u}_h$$ and $$c_h$$ are trainable parameters.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V_t}|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor]) [source]
class PointTransformerConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, pos_nn: = None, attn_nn: = None, add_self_loops: bool = True, **kwargs)[source]

The Point Transformer layer from the “Point Transformer” paper

$\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3 \mathbf{x}_j + \delta_{ij} \right),$

where the attention coefficients $$\alpha_{i,j}$$ and positional embedding $$\delta_{ij}$$ are computed as

$\alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta} (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + \delta_{i,j}) \right)$

and

$\delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j),$

with $$\gamma_\mathbf{\Theta}$$ and $$h_\mathbf{\Theta}$$ denoting neural networks, i.e. MLPs, and $$\mathbf{P} \in \mathbb{R}^{N \times D}$$ defines the position of each point.

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• pos_nn (torch.nn.Module, optional) – A neural network $$h_\mathbf{\Theta}$$ which maps relative spatial coordinates pos_j - pos_i of shape [-1, 3] to shape [-1, out_channels]. Will default to a torch.nn.Linear transformation if not further specified. (default: None)

• attn_nn (torch.nn.Module, optional) – A neural network $$\gamma_\mathbf{\Theta}$$ which maps transformed node features of shape [-1, out_channels] to shape [-1, out_channels]. (default: None)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, positions $$(|\mathcal{V}|, 3)$$ or $$((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , pos: , edge_index: Union[Tensor, SparseTensor]) [source]
class HypergraphConv(in_channels, out_channels, use_attention=False, heads=1, concat=True, negative_slope=0.2, dropout=0, bias=True, **kwargs)[source]

The hypergraph convolutional operator from the “Hypergraph Convolution and Hypergraph Attention” paper

$\mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W} \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta}$

where $$\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}$$ is the incidence matrix, $$\mathbf{W} \in \mathbb{R}^M$$ is the diagonal hyperedge weight matrix, and $$\mathbf{D}$$ and $$\mathbf{B}$$ are the corresponding degree matrices.

For example, in the hypergraph scenario $$\mathcal{G} = (\mathcal{V}, \mathcal{E})$$ with $$\mathcal{V} = \{ 0, 1, 2, 3 \}$$ and $$\mathcal{E} = \{ \{ 0, 1, 2 \}, \{ 1, 2, 3 \} \}$$, the hyperedge_index is represented as:

hyperedge_index = torch.tensor([
[0, 1, 2, 1, 2, 3],
[0, 0, 0, 1, 1, 1],
])

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• use_attention (bool, optional) – If set to True, attention will be added to this layer. (default: False)

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

• negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

• dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, hyperedge indices $$(|\mathcal{V}|, |\mathcal{E}|)$$, hyperedge weights $$(|\mathcal{E}|)$$ (optional), hyperedge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, hyperedge_index: Tensor, hyperedge_weight: = None, hyperedge_attr: = None) [source]
Parameters
• x (Tensor) – Node feature matrix $$\mathbf{X} \in \mathbb{R}^{N \times F}$$.

• hyperedge_index (LongTensor) – The hyperedge indices, i.e. the sparse incidence matrix $$\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}$$ mapping from nodes to edges.

• hyperedge_weight (Tensor, optional) – Hyperedge weights $$\mathbf{W} \in \mathbb{R}^M$$. (default: None)

• hyperedge_attr (Tensor, optional) – Hyperedge feature matrix in $$\mathbb{R}^{M \times F}$$. These features only need to get passed in case use_attention=True. (default: None)
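
A minimal usage sketch re-using the hyperedge_index example above (feature sizes are illustrative):

import torch
from torch_geometric.nn import HypergraphConv

x = torch.randn(4, 16)                           # 4 nodes
hyperedge_index = torch.tensor([                 # two hyperedges: {0, 1, 2} and {1, 2, 3}
    [0, 1, 2, 1, 2, 3],
    [0, 0, 0, 1, 1, 1],
])

conv = HypergraphConv(16, 32)
out = conv(x, hyperedge_index)                   # shape [4, 32]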

class LEConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, bias: bool = True, **kwargs)[source]

The local extremum graph neural network operator from the “ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations” paper, which finds the importance of nodes with respect to their neighbors using the difference operator:

$\mathbf{x}^{\prime}_i = \mathbf{x}_i \cdot \mathbf{\Theta}_1 + \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot (\mathbf{\Theta}_2 \mathbf{x}_i - \mathbf{\Theta}_3 \mathbf{x}_j)$

where $$e_{j,i}$$ denotes the edge weight from source node j to target node i (default: 1).

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_weight: = None) [source]
class PNAConv(in_channels: int, out_channels: int, aggregators: List[str], scalers: List[str], deg: Tensor, edge_dim: = None, towers: int = 1, pre_layers: int = 1, post_layers: int = 1, divide_input: bool = False, act: = 'relu', act_kwargs: Optional[Dict[str, Any]] = None, train_norm: bool = False, **kwargs)[source]

The Principal Neighbourhood Aggregation graph convolution operator from the “Principal Neighbourhood Aggregation for Graph Nets” paper

$\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus} h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right) \right)$

with

$\begin{split}\bigoplus = \underbrace{\begin{bmatrix} 1 \\ S(\mathbf{D}, \alpha=1) \\ S(\mathbf{D}, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}},\end{split}$

where $$\gamma_{\mathbf{\Theta}}$$ and $$h_{\mathbf{\Theta}}$$ denote MLPs.

Note

For an example of using PNAConv, see examples/pna.py.

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• aggregators (list of str) – Set of aggregation function identifiers, namely "sum", "mean", "min", "max", "var" and "std".

• scalers (list of str) – Set of scaling function identifiers, namely "identity", "amplification", "attenuation", "linear" and "inverse_linear".

• deg (Tensor) – Histogram of in-degrees of nodes in the training set, used by scalers to normalize.

• edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default None)

• towers (int, optional) – Number of towers (default: 1).

• pre_layers (int, optional) – Number of transformation layers before aggregation (default: 1).

• post_layers (int, optional) – Number of transformation layers after aggregation (default: 1).

• divide_input (bool, optional) – Whether the input features should be split between towers or not (default: False).

• act (str or Callable, optional) – Pre- and post-layer activation function to use. (default: "relu")

• act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

• train_norm (bool, optional) – Whether normalization parameters are trainable. (default: False)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_attr: = None) [source]
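
A minimal usage sketch (graph and degree histogram are illustrative); in practice, deg is computed once over the in-degrees of the training set, e.g. with torch_geometric.utils.degree:

import torch
from torch_geometric.nn import PNAConv

x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

deg = torch.tensor([1, 4])                       # deg[d] = number of training nodes with in-degree d

conv = PNAConv(16, 32, aggregators=['mean', 'min', 'max', 'std'],
               scalers=['identity', 'amplification', 'attenuation'], deg=deg)
out = conv(x, edge_index)                        # shape [5, 32]
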
class ClusterGCNConv(in_channels: int, out_channels: int, diag_lambda: float = 0.0, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The ClusterGCN graph convolutional operator from the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper

$\mathbf{X}^{\prime} = \left( \mathbf{\hat{A}} + \lambda \cdot \textrm{diag}(\mathbf{\hat{A}}) \right) \mathbf{X} \mathbf{W}_1 + \mathbf{X} \mathbf{W}_2$

where $$\mathbf{\hat{A}} = {(\mathbf{D} + \mathbf{I})}^{-1}(\mathbf{A} + \mathbf{I})$$.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) [source]
class GENConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: = 'softmax', t: float = 1.0, learn_t: bool = False, p: float = 1.0, learn_p: bool = False, msg_norm: bool = False, learn_msg_scale: bool = False, norm: str = 'batch', num_layers: int = 2, expansion: int = 2, eps: float = 1e-07, bias: bool = False, edge_dim: = None, **kwargs)[source]

The GENeralized Graph Convolution (GENConv) from the “DeeperGCN: All You Need to Train Deeper GCNs” paper. Supports SoftMax & PowerMean aggregation. The message construction is:

$\mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_i + \mathrm{AGG} \left( \left\{ \mathrm{ReLU} \left( \mathbf{x}_j + \mathbf{e_{ji}} \right) +\epsilon : j \in \mathcal{N}(i) \right\} \right) \right)$

Note

For an example of using GENConv, see examples/ogbn_proteins_deepgcn.py.

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• aggr (string or Aggregation, optional) – The aggregation scheme to use. Any aggregation of torch_geometric.nn.aggr can be used, e.g., "softmax", "powermean", "add", "mean" or "max". (default: "softmax")

• t (float, optional) – Initial inverse temperature for softmax aggregation. (default: 1.0)

• learn_t (bool, optional) – If set to True, will learn the value t for softmax aggregation dynamically. (default: False)

• p (float, optional) – Initial power for power mean aggregation. (default: 1.0)

• learn_p (bool, optional) – If set to True, will learn the value p for power mean aggregation dynamically. (default: False)

• msg_norm (bool, optional) – If set to True, will use message normalization. (default: False)

• learn_msg_scale (bool, optional) – If set to True, will learn the scaling factor of message normalization. (default: False)

• norm (str, optional) – Norm layer of the MLP layers ("batch", "layer", "instance"). (default: "batch")

• num_layers (int, optional) – The number of MLP layers. (default: 2)

• expansion (int, optional) – The expansion factor of hidden channels in MLP layers. (default: 2)

• eps (float, optional) – The epsilon value of the message construction function. (default: 1e-7)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• edge_dim (int, optional) – Edge feature dimensionality. If set to None, edge feature dimensionality is expected to match out_channels. Otherwise, edge features are linearly transformed to match the node feature dimensionality out_channels. (default: None)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.GenMessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge attributes $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_attr: = None, size: Optional[Tuple[int, int]] = None) [source]
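
A minimal usage sketch (sizes and hyper-parameters are illustrative); since edge_dim is given here, edge features are linearly projected before being added to the messages:

import torch
from torch_geometric.nn import GENConv

x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = torch.randn(edge_index.size(1), 8)   # 8-dimensional edge features

conv = GENConv(16, 32, aggr='softmax', learn_t=True, edge_dim=8)
out = conv(x, edge_index, edge_attr)             # shape [5, 32]
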
class GCN2Conv(channels: int, alpha: float, theta: = None, layer: = None, shared_weights: bool = True, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs)[source]

The graph convolutional operator with initial residual connections and identity mapping (GCNII) from the “Simple and Deep Graph Convolutional Networks” paper

$\mathbf{X}^{\prime} = \left( (1 - \alpha) \mathbf{\hat{P}}\mathbf{X} + \alpha \mathbf{X^{(0)}}\right) \left( (1 - \beta) \mathbf{I} + \beta \mathbf{\Theta} \right)$

with $$\mathbf{\hat{P}} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}$$, where $$\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$$ denotes the adjacency matrix with inserted self-loops and $$\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}$$ its diagonal degree matrix, and $$\mathbf{X}^{(0)}$$ being the initial feature representation. Here, $$\alpha$$ models the strength of the initial residual connection, while $$\beta$$ models the strength of the identity mapping. The adjacency matrix can include other values than 1 representing edge weights via the optional edge_weight tensor.

Parameters
• channels (int) – Size of each input and output sample.

• alpha (float) – The strength of the initial residual connection $$\alpha$$.

• theta (float, optional) – The hyperparameter $$\theta$$ to compute the strength of the identity mapping $$\beta = \log \left( \frac{\theta}{\ell} + 1 \right)$$. (default: None)

• layer (int, optional) – The layer $$\ell$$ in which this module is executed. (default: None)

• shared_weights (bool, optional) – If set to False, will use different weight matrices for the smoothed representation and the initial residual (“GCNII*”). (default: True)

• cached (bool, optional) – If set to True, the layer will cache the computation of $$\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}$$ on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

• normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F)$$, initial node features $$(|\mathcal{V}|, F)$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F)$$

reset_parameters()[source]
forward(x: Tensor, x_0: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: = None) [source]
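
A minimal usage sketch (sizes are illustrative); both the current representation x and the initial representation x_0 (typically obtained from an input linear layer) are passed at every layer:

import torch
from torch_geometric.nn import GCN2Conv

x_0 = torch.randn(4, 64)                         # initial representation X^(0)
x = x_0
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

conv = GCN2Conv(channels=64, alpha=0.1, theta=0.5, layer=1)
x = conv(x, x_0, edge_index)                     # shape [4, 64]
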
class PANConv(in_channels: int, out_channels: int, filter_size: int, **kwargs)[source]

The path integral based convolutional operator from the “Path Integral Based Convolution and Pooling for Graph Neural Networks” paper

$\mathbf{X}^{\prime} = \mathbf{M} \mathbf{X} \mathbf{W}$

where $$\mathbf{M}$$ denotes the normalized and learned maximal entropy transition (MET) matrix that includes neighbors up to filter_size hops:

$\mathbf{M} = \mathbf{Z}^{-1/2} \sum_{n=0}^L e^{-\frac{E(n)}{T}} \mathbf{A}^n \mathbf{Z}^{-1/2}$
Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• filter_size (int) – The filter size $$L$$.

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$,

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) Tuple[Tensor, SparseTensor][source]
panentropy(adj_t: SparseTensor, dtype: = None) SparseTensor[source]
class WLConv[source]

The Weisfeiler Lehman operator from the “A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction” paper, which iteratively refines node colorings:

$\mathbf{x}^{\prime}_i = \textrm{hash} \left( \mathbf{x}_i, \{ \mathbf{x}_j \colon j \in \mathcal{N}(i) \} \right)$
Shapes:
• input: node coloring $$(|\mathcal{V}|, F_{in})$$ (one-hot encodings) or $$(|\mathcal{V}|)$$ (integer-based), edge indices $$(2, |\mathcal{E}|)$$

• output: node coloring $$(|\mathcal{V}|)$$ (integer-based)

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) [source]
histogram(x: Tensor, batch: = None, norm: bool = False) [source]

Given a node coloring x, computes the color histograms of the respective graphs (separated by batch).
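
A minimal usage sketch (graph and number of refinement iterations are illustrative); repeated application refines the coloring, and histogram() turns the final coloring into a graph-level feature vector:

import torch
from torch_geometric.nn import WLConv

x = torch.tensor([0, 0, 1, 1])                   # initial integer-based node coloring
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])

wl = WLConv()
for _ in range(3):                               # three rounds of color refinement
    x = wl(x, edge_index)

hist = wl.histogram(x, norm=True)                # color histogram of shape [num_graphs, num_colors]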

class WLConvContinuous(**kwargs)[source]

The Weisfeiler Lehman operator from the “Wasserstein Weisfeiler-Lehman Graph Kernels” paper. Refinement is done through a degree-scaled mean aggregation and works on nodes with continuous attributes:

$\mathbf{x}^{\prime}_i = \frac{1}{2}\big(\mathbf{x}_i + \frac{1}{\textrm{deg}(i)} \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j \big)$

where $$e_{j,i}$$ denotes the edge weight from source node j to target node i (default: 1).

Parameters

**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F)$$ or $$((|\mathcal{V_s}|, F), (|\mathcal{V_t}|, F))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F)$$ or $$(|\mathcal{V}_t|, F)$$ if bipartite

forward(x: , edge_index: Tensor, edge_weight: = None, size: Optional[Tuple[int, int]] = None) [source]
class FiLMConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int = 1, nn: = None, act: = ReLU(), aggr: str = 'mean', **kwargs)[source]

The FiLM graph convolutional operator from the “GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation” paper

$\mathbf{x}^{\prime}_i = \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}(i)} \sigma \left( \boldsymbol{\gamma}_{r,i} \odot \mathbf{W}_r \mathbf{x}_j + \boldsymbol{\beta}_{r,i} \right)$

where $$\boldsymbol{\beta}_{r,i}, \boldsymbol{\gamma}_{r,i} = g(\mathbf{x}_i)$$ with $$g$$ being a single linear layer by default. Self-loops are automatically added to the input graph and represented as its own relation type.

Note

For an example of using FiLM, see examples/gcn.py.

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• num_relations (int, optional) – Number of relations. (default: 1)

• nn (torch.nn.Module, optional) – The neural network $$g$$ that maps node features x_i of shape [-1, in_channels] to shape [-1, 2 * out_channels]. If set to None, $$g$$ will be implemented as a single linear layer. (default: None)

• act (callable, optional) – Activation function $$\sigma$$. (default: torch.nn.ReLU())

• aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "mean")

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge types $$(|\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V_t}|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_type: = None) [source]
class SuperGATConv(in_channels: int, out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, bias: bool = True, attention_type: str = 'MX', neg_sample_ratio: float = 0.5, edge_sample_ratio: float = 1.0, is_undirected: bool = False, **kwargs)[source]

The self-supervised graph attentional operator from the “How to Find Your Friendly Neighborhood: Graph Attention Design with Self-Supervision” paper

$\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},$

where the two types of attention $$\alpha_{i,j}^{\mathrm{MX\ or\ SD}}$$ are computed as:

\begin{align}\begin{aligned}\alpha_{i,j}^{\mathrm{MX\ or\ SD}} &= \frac{ \exp\left(\mathrm{LeakyReLU}\left( e_{i,j}^{\mathrm{MX\ or\ SD}} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( e_{i,k}^{\mathrm{MX\ or\ SD}} \right)\right)}\\e_{i,j}^{\mathrm{MX}} &= \mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \cdot \sigma \left( \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j \right)\\e_{i,j}^{\mathrm{SD}} &= \frac{ \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j }{ \sqrt{d} }\end{aligned}\end{align}

The self-supervised task is a link prediction using the attention values as input to predict the likelihood $$\phi_{i,j}^{\mathrm{MX\ or\ SD}}$$ that an edge exists between nodes:

\begin{align}\begin{aligned}\phi_{i,j}^{\mathrm{MX}} &= \sigma \left( \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j \right)\\\phi_{i,j}^{\mathrm{SD}} &= \sigma \left( \frac{ \left( \mathbf{\Theta}\mathbf{x}_i \right)^{\top} \mathbf{\Theta}\mathbf{x}_j }{ \sqrt{d} } \right)\end{aligned}\end{align}

Note

For an example of using SuperGAT, see examples/super_gat.py.

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

• negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

• dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• attention_type (string, optional) – Type of attention to use. ('MX', 'SD'). (default: 'MX')

• neg_sample_ratio (float, optional) – The ratio of the number of sampled negative edges to the number of positive edges. (default: 0.5)

• edge_sample_ratio (float, optional) – The ratio of samples to use for training among the number of training edges. (default: 1.0)

• is_undirected (bool, optional) – Whether the input graph is undirected. If not given, will be automatically computed with the input graph when negative sampling is performed. (default: False)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, negative edge indices $$(2, |\mathcal{E}^{(-)}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, H * F_{out})$$

att_x: Optional[Tensor]
att_y: Optional[Tensor]
reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], neg_edge_index: = None, batch: = None) [source]
Parameters

neg_edge_index (Tensor, optional) – The negative edges to train against. If not given, uses negative sampling to calculate negative edges. (default: None)

negative_sampling(edge_index: Tensor, num_nodes: int, batch: = None) [source]
positive_sampling(edge_index: Tensor) [source]
get_attention(edge_index_i: Tensor, x_i: Tensor, x_j: Tensor, num_nodes: , return_logits: bool = False) [source]
get_attention_loss() [source]

Compute the self-supervised graph attention loss.
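
A minimal training-time sketch (sizes and the weighting factor of the attention loss are illustrative, not prescribed by the layer):

import torch
from torch_geometric.nn import SuperGATConv

x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])

conv = SuperGATConv(16, 8, heads=4, attention_type='MX')
out = conv(x, edge_index)                            # shape [6, 4 * 8]; negative edges are sampled internally

task_loss = out.sum()                                # placeholder for the actual downstream loss
loss = task_loss + 4.0 * conv.get_attention_loss()   # 4.0 is an illustrative weighting factor
loss.backward()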

class FAConv(channels: int, eps: float = 0.1, dropout: float = 0.0, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, **kwargs)[source]

The Frequency Adaptive Graph Convolution operator from the “Beyond Low-Frequency Information in Graph Convolutional Networks” paper

$\mathbf{x}^{\prime}_i= \epsilon \cdot \mathbf{x}^{(0)}_i + \sum_{j \in \mathcal{N}(i)} \frac{\alpha_{i,j}}{\sqrt{d_i d_j}} \mathbf{x}_{j}$

where $$\mathbf{x}^{(0)}_i$$ and $$d_i$$ denote the initial feature representation and node degree of node $$i$$, respectively. The attention coefficients $$\alpha_{i,j}$$ are computed as

$\mathbf{\alpha}_{i,j} = \textrm{tanh}(\mathbf{a}^{\top}[\mathbf{x}_i, \mathbf{x}_j])$

based on the trainable parameter vector $$\mathbf{a}$$.

Parameters
• channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• eps (float, optional) – $$\epsilon$$-value. (default: 0.1)

• dropout (float, optional) – Dropout probability of the normalized coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0).

• cached (bool, optional) – If set to True, the layer will cache the computation of $$\sqrt{d_i d_j}$$ on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• normalize (bool, optional) – Whether to add self-loops (if add_self_loops is True) and compute symmetric normalization coefficients on the fly. If set to False, edge_weight needs to be provided in the layer’s forward() method. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F)$$, initial node features $$(|\mathcal{V}|, F)$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F)$$ or $$((|\mathcal{V}|, F), ((2, |\mathcal{E}|), (|\mathcal{E}|)))$$ if return_attention_weights=True

reset_parameters()[source]
forward(x: Tensor, x_0: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: = None, return_attention_weights=None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)
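
A minimal usage sketch with random data (shapes are illustrative); note that the initial representation x_0 is passed to every layer:

import torch
from torch_geometric.nn import FAConv

x_0 = torch.randn(50, 32)                    # initial node features
edge_index = torch.randint(0, 50, (2, 200))

conv = FAConv(channels=32, eps=0.1)
x = conv(x_0, x_0, edge_index)               # first layer: x equals x_0
x = conv(x, x_0, edge_index)                 # deeper layers keep reusing x_0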

class EGConv(in_channels: int, out_channels: int, aggregators: List[str] = ['symnorm'], num_heads: int = 8, num_bases: int = 4, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The Efficient Graph Convolution from the “Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions” paper.

Its node-wise formulation is given by:

$\mathbf{x}_i^{\prime} = {\LARGE ||}_{h=1}^H \sum_{\oplus \in \mathcal{A}} \sum_{b = 1}^B w_{i, h, \oplus, b} \; \underset{j \in \mathcal{N}(i) \cup \{i\}}{\bigoplus} \mathbf{W}_b \mathbf{x}_{j}$

with $$\mathbf{W}_b$$ denoting a basis weight, $$\oplus$$ denoting an aggregator, and $$w$$ denoting per-vertex weighting coefficients across different heads, bases and aggregators.

EGC retains $$\mathcal{O}(|\mathcal{V}|)$$ memory usage, making it a sensible alternative to GCNConv, SAGEConv or GINConv.

Note

For an example of using EGConv, see examples/egc.py.

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• aggregators (List[str], optional) – Aggregators to be used. Supported aggregators are "sum", "mean", "symnorm", "max", "min", "std", "var". Multiple aggregators can be used to improve the performance. (default: ["symnorm"])

• num_heads (int, optional) – Number of heads $$H$$ to use. Must have out_channels % num_heads == 0. It is recommended to set num_heads >= num_bases. (default: 8)

• num_bases (int, optional) – Number of basis weights $$B$$ to use. (default: 4)

• cached (bool, optional) – If set to True, the layer will cache the computation of the edge index with added self loops on first execution, along with caching the calculation of the symmetric normalized edge weights if the "symnorm" aggregator is being used. This parameter should only be set to True in transductive learning scenarios. (default: False)

• add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor]) [source]
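
A minimal usage sketch with random data (channel sizes are illustrative; out_channels must be divisible by num_heads):

import torch
from torch_geometric.nn import EGConv

x = torch.randn(50, 32)
edge_index = torch.randint(0, 50, (2, 200))

conv = EGConv(in_channels=32, out_channels=64,
              aggregators=['symnorm', 'max', 'mean'],
              num_heads=8, num_bases=4)
out = conv(x, edge_index)  # [50, 64]
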
class PDNConv(in_channels: int, out_channels: int, edge_dim: int, hidden_channels: int, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)[source]

The pathfinder discovery network convolutional operator from the “Pathfinder Discovery Networks for Neural Message Passing” paper

$\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{i\}} f_{\Theta}(\mathbf{e}_{(j,i)}) \cdot f_{\Omega}(\mathbf{x}_{j})$

where $$\mathbf{e}_{(j,i)}$$ denotes the edge feature vector from source node $$j$$ to target node $$i$$, and $$\mathbf{x}_{j}$$ denotes the node feature vector of node $$j$$.

Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_attr: = None) [source]
class GeneralConv(in_channels: Union[int, Tuple[int, int]], out_channels: , in_edge_channels: = None, aggr: str = 'add', skip_linear: bool = False, directed_msg: bool = True, heads: int = 1, attention: bool = False, attention_type: str = 'additive', l2_normalize: bool = False, bias: bool = True, **kwargs)[source]

A general GNN layer adapted from the “Design Space for Graph Neural Networks” paper.

Parameters
• in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

• out_channels (int) – Size of each output sample.

• in_edge_channels (int, optional) – Size of each input edge. (default: None)

• aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

• skip_linear (bool, optional) – Whether to apply a linear transformation in the skip connection. (default: False)

• directed_msg (bool, optional) – If set to True, message passing is directed; otherwise, it is bi-directed. (default: True)

• heads (int, optional) – Number of message passing ensembles. If heads > 1, the GNN layer will output an ensemble of multiple messages. If attention is used (attention=True), this corresponds to multi-head attention. (default: 1)

• attention (bool, optional) – Whether to add attention to message computation. (default: False)

• attention_type (str, optional) – Type of attention: "additive", "dot_product". (default: "additive")

• l2_normalize (bool, optional) – If set to True, output features will be $$\ell_2$$-normalized, i.e., $$\frac{\mathbf{x}^{\prime}_i} {\| \mathbf{x}^{\prime}_i \|_2}$$. (default: False)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or $$((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))$$ if bipartite, edge indices $$(2, |\mathcal{E}|)$$, edge attributes $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$ or $$(|\mathcal{V}_t|, F_{out})$$ if bipartite

reset_parameters()[source]
forward(x: , edge_index: Union[Tensor, SparseTensor], edge_attr: = None, size: Optional[Tuple[int, int]] = None) [source]
message_basic(x_i: Tensor, x_j: Tensor, edge_attr: )[source]
class HGTConv(in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Tuple[List[str], List[Tuple[str, str, str]]], heads: int = 1, group: str = 'sum', **kwargs)[source]

The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.

Note

For an example of using HGT, see examples/hetero/hgt_dblp.py.

Parameters
• in_channels (int or Dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• group (string, optional) – The aggregation scheme to use for grouping node embeddings generated by different relations. ("sum", "mean", "min", "max"). (default: "sum")

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x_dict: , edge_index_dict: Union[Dict[Tuple[str, str, str], Tensor], Dict[Tuple[str, str, str], SparseTensor]]) [source]
Parameters
• x_dict (Dict[str, Tensor]) – A dictionary holding input node features for each individual node type.

• edge_index_dict (Dict[str, Union[Tensor, SparseTensor]]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.LongTensor of shape [2, num_edges] or a torch_sparse.SparseTensor.

Return type

Dict[str, Optional[Tensor]] - The output node embeddings for each node type. In case a node type does not receive any message, its output will be set to None.
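
For illustration, a minimal sketch with random features and a single (purely illustrative) relation type:

import torch
from torch_geometric.nn import HGTConv

metadata = (['paper', 'author'], [('author', 'writes', 'paper')])

x_dict = {
    'paper': torch.randn(6, 16),
    'author': torch.randn(4, 16),
}
edge_index_dict = {
    ('author', 'writes', 'paper'): torch.tensor([[0, 1, 2, 3],
                                                 [0, 2, 4, 5]]),
}

conv = HGTConv(in_channels=-1, out_channels=32, metadata=metadata, heads=2)
out_dict = conv(x_dict, edge_index_dict)
# out_dict['paper'] has shape [6, 32]; out_dict['author'] is None here,
# since no relation points to 'author' (see the return type above).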

class HEATConv(in_channels: int, out_channels: int, num_node_types: int, num_edge_types: int, edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, root_weight: bool = True, bias: bool = True, **kwargs)[source]

The heterogeneous edge-enhanced graph attentional operator from the “Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction” paper, which enhances GATConv by:

1. type-specific transformations of nodes of different types

2. edge type and edge feature incorporation, in which edges are assumed to have different types but contain the same kind of attributes

Parameters
• in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• num_node_types (int) – The number of node types.

• num_edge_types (int) – The number of edge types.

• edge_type_emb_dim (int) – The embedding size of edge types.

• edge_dim (int) – Edge feature dimensionality.

• edge_attr_emb_dim (int) – The embedding size of edge features.

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

• negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

• dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

• root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

• bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$, edge indices $$(2, |\mathcal{E}|)$$, node types $$(|\mathcal{V}|)$$, edge types $$(|\mathcal{E}|)$$, edge features $$(|\mathcal{E}|, D)$$ (optional)

• output: node features $$(|\mathcal{V}|, F_{out})$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], node_type: Tensor, edge_type: Tensor, edge_attr: = None) [source]
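
A minimal usage sketch with random data (all sizes below are illustrative):

import torch
from torch_geometric.nn import HEATConv

x = torch.randn(20, 16)
edge_index = torch.randint(0, 20, (2, 60))
node_type = torch.randint(0, 3, (20,))   # 3 node types
edge_type = torch.randint(0, 2, (60,))   # 2 edge types
edge_attr = torch.randn(60, 5)           # 5-dimensional edge features

conv = HEATConv(in_channels=16, out_channels=8, num_node_types=3,
                num_edge_types=2, edge_type_emb_dim=4, edge_dim=5,
                edge_attr_emb_dim=6, heads=2)
out = conv(x, edge_index, node_type, edge_type, edge_attr)  # [20, 2 * 8]
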
class HeteroConv(convs: Dict[Tuple[str, str, str], Module], aggr: = 'sum')[source]

A generic wrapper for computing graph convolution on heterogeneous graphs. This layer will pass messages from source nodes to target nodes based on the bipartite GNN layer given for a specific edge type. If multiple relations point to the same destination, their results will be aggregated according to aggr. In comparison to torch_geometric.nn.to_hetero(), this layer is especially useful if you want to apply different message passing modules for different edge types.

hetero_conv = HeteroConv({
('paper', 'cites', 'paper'): GCNConv(-1, 64),
('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
}, aggr='sum')

out_dict = hetero_conv(x_dict, edge_index_dict)

print(list(out_dict.keys()))
>>> ['paper', 'author']

Parameters
• convs (Dict[Tuple[str, str, str], Module]) – A dictionary holding a bipartite MessagePassing layer for each individual edge type.

• aggr (string, optional) – The aggregation scheme to use for grouping node embeddings generated by different relations. ("sum", "mean", "min", "max", None). (default: "sum")

reset_parameters()[source]
forward(x_dict: , edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor, SparseTensor]], *args_dict, **kwargs_dict) [source]
Parameters
• x_dict (Dict[str, Tensor]) – A dictionary holding node feature information for each individual node type.

• edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – A dictionary holding graph connectivity information for each individual edge type.

• *args_dict (optional) – Additional forward arguments of individual torch_geometric.nn.conv.MessagePassing layers.

• **kwargs_dict (optional) – Additional forward arguments of individual torch_geometric.nn.conv.MessagePassing layers. For example, if a specific GNN layer at edge type edge_type expects edge attributes edge_attr as a forward argument, then you can pass them to forward() via edge_attr_dict = { edge_type: edge_attr }.

class HANConv(in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Tuple[List[str], List[Tuple[str, str, str]]], heads: int = 1, negative_slope=0.2, dropout: float = 0.0, **kwargs)[source]

The Heterogeneous Graph Attention Operator from the “Heterogeneous Graph Attention Network” paper.

Note

For an example of using HANConv, see examples/hetero/han_imdb.py.

Parameters
• in_channels (int or Dict[str, int]) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.

• out_channels (int) – Size of each output sample.

• metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

• heads (int, optional) – Number of multi-head-attentions. (default: 1)

• negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

• dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

• **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

reset_parameters()[source]
forward(x_dict: , edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor, SparseTensor]], return_semantic_attention_weights: bool = False) Union[Dict[str, Optional[Tensor]], Tuple[Dict[str, Optional[Tensor]], Dict[str, Optional[Tensor]]]][source]
Parameters
• x_dict (Dict[str, Tensor]) – A dictionary holding input node features for each individual node type.

• edge_index_dict (Dict[str, Union[Tensor, SparseTensor]]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.LongTensor of shape [2, num_edges] or a torch_sparse.SparseTensor.

• return_semantic_attention_weights (bool, optional) – If set to True, will additionally return the semantic-level attention weights for each destination node type. (default: False)

class LGConv(normalize: bool = True, **kwargs)[source]

The Light Graph Convolution (LGC) operator from the “LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation” paper

$\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \frac{e_{j,i}}{\sqrt{\deg(i)\deg(j)}} \mathbf{x}_j$
Parameters
Shapes:
• input: node features $$(|\mathcal{V}|, F)$$, edge indices $$(2, |\mathcal{E}|)$$, edge weights $$(|\mathcal{E}|)$$ (optional)

• output: node features $$(|\mathcal{V}|, F)$$

reset_parameters()[source]
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: = None) [source]
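
Since LGConv is parameter-free, LightGCN-style models typically stack several propagation steps and average their outputs. A minimal sketch with random data (the number of layers is illustrative):

import torch
from torch_geometric.nn import LGConv

x = torch.randn(100, 64)
edge_index = torch.randint(0, 100, (2, 400))

conv = LGConv()
xs = [x]
for _ in range(3):
    xs.append(conv(xs[-1], edge_index))
out = torch.stack(xs, dim=0).mean(dim=0)  # [100, 64]
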
class MetaLayer(edge_model: = None, node_model: = None, global_model: = None)[source]

A meta layer for building any kind of graph network, inspired by the “Relational Inductive Biases, Deep Learning, and Graph Networks” paper.

A graph network takes a graph as input and returns an updated graph as output (with the same connectivity). The input graph has node features x, edge features edge_attr as well as global-level features u. The output graph has the same structure, but updated features.

Edge features, node features as well as global features are updated by calling the modules edge_model, node_model and global_model, respectively.

To allow for batch-wise graph processing, all callable functions take an additional argument batch, which determines the assignment of edges or nodes to their specific graphs.

Parameters
• edge_model (Module, optional) – A callable which updates a graph’s edge features based on its source and target node features, its current edge features and its global features. (default: None)

• node_model (Module, optional) – A callable which updates a graph’s node features based on its current node features, its graph connectivity, its edge features and its global features. (default: None)

• global_model (Module, optional) – A callable which updates a graph’s global features based on its node features, its graph connectivity, its edge features and its current global features. (default: None)

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer

class EdgeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, src, dest, edge_attr, u, batch):
        # src, dest: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dest, edge_attr, u[batch]], 1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
        self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out, u[batch]], dim=1)
        return self.node_mlp_2(out)

class GlobalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        out = torch.cat([u, scatter_mean(x, batch, dim=0)], dim=1)
        return self.global_mlp(out)

op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)

reset_parameters()[source]
forward(x: Tensor, edge_index: Tensor, edge_attr: = None, u: = None, batch: = None) [source]

## Aggregation Operators

Aggregation functions play an important role in the message passing framework and the readout functions of Graph Neural Networks. Specifically, many works in the literature (Hamilton et al. (2017), Xu et al. (2018), Corso et al. (2020), Li et al. (2020), Tailor et al. (2021)) demonstrate that the choice of aggregation functions contributes significantly to the representational power and performance of the model. For example, mean aggregation captures the distribution (or proportions) of elements, max aggregation proves to be advantageous to identify representative elements, and sum aggregation enables the learning of structural graph properties (Xu et al. (2018)). Recent works also show that using multiple aggregations (Corso et al. (2020), Tailor et al. (2021)) and learnable aggregations (Li et al. (2020)) can potentially provide substantial improvements. Another line of research studies optimization-based and implicitly-defined aggregations (Bartunov et al. (2022)).

To facilitate further experimentation and unify the concepts of aggregation within GNNs across both MessagePassing and global readouts, we have made the concept of Aggregation a first-class principle in PyG. As of now, PyG provides support for various aggregations — from simple ones (e.g., mean, max, sum), to advanced ones (e.g., median, var, std), learnable ones (e.g., SoftmaxAggregation, PowerMeanAggregation), and exotic ones (e.g., LSTMAggregation, SortAggregation, EquilibriumAggregation):

from torch_geometric.nn import aggr

# Simple aggregations:
mean_aggr = aggr.MeanAggregation()
max_aggr = aggr.MaxAggregation()

median_aggr = aggr.MedianAggregation()

# Learnable aggregations:
softmax_aggr = aggr.SoftmaxAggregation(learn=True)
powermean_aggr = aggr.PowerMeanAggregation(learn=True)

# Exotic aggregations:
lstm_aggr = aggr.LSTMAggregation(in_channels=..., out_channels=...)
sort_aggr = aggr.SortAggregation(k=4)


We can then easily apply these aggregations over a batch of sets of potentially varying size. For this, an index vector defines the mapping from input elements to their location in the output:

# Feature matrix holding 1000 elements with 64 features each:
x = torch.randn(1000, 64)

# Randomly assign elements to 100 sets:
index = torch.randint(0, 100, (1000, ))

output = mean_aggr(x, index)  #  Output shape: [100, 64]


Notably, all aggregations share the same set of forward arguments, as described in detail in torch_geometric.nn.aggr.Aggregation.

Each of the provided aggregations can be used within MessagePassing as well as for hierarchical/global pooling to obtain graph-level representations:

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import aggr

class MyConv(MessagePassing):
    def __init__(self, ...):
        # Use a learnable softmax neighborhood aggregation:
        super().__init__(aggr=aggr.SoftmaxAggregation(learn=True))

    def forward(self, x, edge_index):
        ...

class MyGNN(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()

        self.conv = MyConv(...)
        # Use a global sort aggregation:
        self.global_pool = aggr.SortAggregation(k=4)
        self.classifier = torch.nn.Linear(...)

    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index).relu()
        x = self.global_pool(x, batch)
        x = self.classifier(x)
        return x


In addition, the aggregation package of PyG introduces two new concepts: First, aggregations can be resolved from pure strings via a lookup table, following the design principles of the class-resolver library, e.g., by simply passing in "median" to the MessagePassing module. This will automatically resolve to the MedianAggregation class:

class MyConv(MessagePassing):
    def __init__(self, ...):
        super().__init__(aggr="median")


Secondly, multiple aggregations can be combined and stacked via the MultiAggregation module in order to enhance the representational power of GNNs (Corso et al. (2020), Tailor et al. (2021)):

class MyConv(MessagePassing):
    def __init__(self, ...):
        # Combines a set of aggregations and concatenates their results,
        # i.e. its output will be [num_nodes, 3 * out_channels] here.
        # Note that the interface also supports automatic resolution.
        super().__init__(aggr=aggr.MultiAggregation(
            ['mean', 'std', aggr.SoftmaxAggregation(learn=True)]))


Importantly, MultiAggregation provides various options to combine the outputs of its underlying aggregations (e.g., using concatenation, summation, attention, …) via its mode argument. The default mode performs concatenation ("cat"). For combining via attention, we need to additionally specify the in_channels, out_channels, and num_heads via mode_kwargs:

multi_aggr = aggr.MultiAggregation(
    ['mean', 'std'], mode='attn',
    # Channel sizes and the number of heads below are illustrative:
    mode_kwargs=dict(in_channels=64, out_channels=64, num_heads=4))


If aggregations are given as a list, they will be automatically resolved to a MultiAggregation, e.g., aggr=['mean', 'std', 'median'].
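
For example, a minimal sketch (the layer name is hypothetical):

from torch_geometric.nn import MessagePassing

class ListAggrConv(MessagePassing):
    def __init__(self):
        # The list is resolved internally to a MultiAggregation:
        super().__init__(aggr=['mean', 'std', 'median'])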

Finally, we added full support for customization of aggregations into the SAGEConv layer — simply override its aggr argument and utilize the power of aggregation within your GNN.
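
For example (channel sizes are illustrative):

from torch_geometric.nn import SAGEConv, aggr

conv_max = SAGEConv(in_channels=64, out_channels=64, aggr='max')
conv_soft = SAGEConv(in_channels=64, out_channels=64,
                     aggr=aggr.SoftmaxAggregation(learn=True))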

• Aggregation – An abstract base class for implementing custom aggregations.
• MultiAggregation – Performs aggregations with one or more aggregators and combines aggregated results, as described in the "Principal Neighbourhood Aggregation for Graph Nets" and "Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions" papers.
• SumAggregation – An aggregation operator that sums up features across a set of elements.
• MeanAggregation – An aggregation operator that averages features across a set of elements.
• MaxAggregation – An aggregation operator that takes the feature-wise maximum across a set of elements.
• MinAggregation – An aggregation operator that takes the feature-wise minimum across a set of elements.
• MulAggregation – An aggregation operator that multiplies features across a set of elements.
• VarAggregation – An aggregation operator that takes the feature-wise variance across a set of elements.
• StdAggregation – An aggregation operator that takes the feature-wise standard deviation across a set of elements.
• SoftmaxAggregation – The softmax aggregation operator based on a temperature term, as described in the "DeeperGCN: All You Need to Train Deeper GCNs" paper.
• PowerMeanAggregation – The powermean aggregation operator based on a power term, as described in the "DeeperGCN: All You Need to Train Deeper GCNs" paper.
• MedianAggregation – An aggregation operator that returns the feature-wise median of a set.
• QuantileAggregation – An aggregation operator that returns the feature-wise $$q$$-th quantile of a set $$\mathcal{X}$$.
• LSTMAggregation – Performs LSTM-style aggregation in which the elements to aggregate are interpreted as a sequence, as described in the "Inductive Representation Learning on Large Graphs" paper.
• Set2Set – The Set2Set aggregation operator based on iterative content-based attention, as described in the "Order Matters: Sequence to sequence for Sets" paper.
• DegreeScalerAggregation – Combines one or more aggregators and transforms its output with one or more scalers, as introduced in the "Principal Neighbourhood Aggregation for Graph Nets" paper.
• SortAggregation – The pooling operator from the "An End-to-End Deep Learning Architecture for Graph Classification" paper, where node features are sorted in descending order based on their last feature channel.
• GraphMultisetTransformer – The Graph Multiset Transformer pooling operator from the "Accurate Learning of Graph Representations with Graph Multiset Pooling" paper.
• AttentionalAggregation – The soft attention aggregation layer from the "Graph Matching Networks for Learning the Similarity of Graph Structured Objects" paper.
• EquilibriumAggregation – The equilibrium aggregation layer from the "Equilibrium Aggregation: Encoding Sets via Optimization" paper.
class Aggregation[source]

An abstract base class for implementing custom aggregations.

Aggregation can be performed either via an index vector, which defines the mapping from input elements to their location in the output, or via a "compressed" ptr vector (described further below).

Notably, index does not have to be sorted:

# Feature matrix holding 10 elements with 64 features each:
x = torch.randn(10, 64)

# Assign each element to one of three sets:
index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

output = aggr(x, index)  #  Output shape: [3, 64]


Alternatively, aggregation can be achieved via a “compressed” index vector called ptr. Here, elements within the same set need to be grouped together in the input, and ptr defines their boundaries:

# Feature matrix holding 10 elements with 64 features each:
x = torch.randn(10, 64)

# Define the boundary indices for three sets:
ptr = torch.tensor([0, 4, 7, 10])

output = aggr(x, ptr=ptr)  #  Output shape: [3, 64]


Note that at least one of index or ptr must be defined.

Shapes:
• input: node features $$(|\mathcal{V}|, F_{in})$$ or edge features $$(|\mathcal{E}|, F_{in})$$, index vector $$(|\mathcal{V}|)$$ or $$(|\mathcal{E}|)$$,

• output: graph features $$(|\mathcal{G}|, F_{out})$$ or node features $$(|\mathcal{V}|, F_{out})$$

forward(x: Tensor, index: = None, ptr: = None, dim_size: = None, dim: int = -2) [source]
Parameters
• x (torch.Tensor) – The source tensor.

• index (torch.LongTensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

• ptr (torch.LongTensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

• dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

• dim (int, optional) – The dimension in which to aggregate. (default: -2)
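
For example, a small sketch using SumAggregation (shapes are illustrative); dim_size fixes the number of output rows even when some sets receive no elements:

import torch
from torch_geometric.nn import aggr

x = torch.randn(5, 8)
index = torch.tensor([0, 0, 2, 2, 2])              # sets 1 and 3 are empty
out = aggr.SumAggregation()(x, index, dim_size=4)  # [4, 8]; empty rows are zero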

static set_validate_args(value: bool)[source]

Sets whether validation is enabled or disabled.

The default behavior mimics Python’s assert statement: validation is on by default, but is disabled if Python is run in optimized mode (via python -O). Validation may be expensive, so you may want to disable it once a model is working.

Parameters

value (bool) – Whether to enable validation.
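
For example:

from torch_geometric.nn import aggr

# Disable (potentially expensive) input validation once a model is known to
# be shape-correct:
aggr.Aggregation.set_validate_args(False)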

class MultiAggregation(aggrs: , aggrs_kwargs: Optional[List[Dict[str, Any]]] = None, mode: = 'cat', mode_kwargs: Optional[Dict[str, Any]] = None)[source]

Performs aggregations with one or more aggregators and combines aggregated results, as described in the “Principal Neighbourhood Aggregation for Graph Nets” and “Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions” papers.

Parameters
• aggrs (list) – The list of aggregation schemes to use.

• aggrs_kwargs (dict, optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)

• mode (string, optional) – The combine mode to use for combining aggregated results from multiple aggregations ("cat", "proj", "sum", "mean", "max", "min", "logsumexp", "std", "var", "attn"). (default: "cat")

• mode_kwargs (dict, optional) – Arguments passed to the combine mode. When "proj" or "attn" is used as the combine mode, in_channels (int or tuple) and out_channels (int) need to be specified, corresponding to the size of each input sample to combine from the respective aggregation outputs and the size of each output sample after combination. When "attn" mode is used, num_heads (int) additionally needs to be specified as the number of parallel attention heads. (default: None)
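
For example, a sketch of the "proj" combine mode (channel sizes are illustrative):

from torch_geometric.nn import aggr

# Combine mean and max via a learned linear projection instead of concatenation:
proj_aggr = aggr.MultiAggregation(
    ['mean', 'max'], mode='proj',
    mode_kwargs=dict(in_channels=64, out_channels=64))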

class SumAggregation[source]

An aggregation operator that sums up features across a set of elements

$\mathrm{sum}(\mathcal{X}) = \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.$
class MeanAggregation[source]

An aggregation operator that averages features across a set of elements

$\mathrm{mean}(\mathcal{X}) = \frac{1}{|\mathcal{X}|} \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.$
class MaxAggregation[source]

An aggregation operator that takes the feature-wise maximum across a set of elements

$\mathrm{max}(\mathcal{X}) = \max_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.$
class MinAggregation[source]

An aggregation operator that takes the feature-wise minimum across a set of elements

$\mathrm{min}(\mathcal{X}) = \min_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.$
class MulAggregation[source]

An aggregation operator that multiplies features across a set of elements

$\mathrm{mul}(\mathcal{X}) = \prod_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i.$

class VarAggregation(semi_grad: bool = False)[source]

An aggregation operator that takes the feature-wise variance across a set of elements

$\mathrm{var}(\mathcal{X}) = \mathrm{mean}(\{ \mathbf{x}_i^2 : \mathbf{x}_i \in \mathcal{X} \}) - \mathrm{mean}(\mathcal{X})^2.$
Parameters

semi_grad (bool, optional) – If set to True, will turn off gradient calculation during $$E[X^2]$$ computation. Therefore, only semi-gradients are used during backpropagation. Useful for saving memory and accelerating backward computation. (default: False)

class StdAggregation(semi_grad: bool = False)[source]

An aggregation operator that takes the feature-wise standard deviation across a set of elements

$\mathrm{std}(\mathcal{X}) = \sqrt{\mathrm{var}(\mathcal{X})}.$
Parameters

semi_grad (bool, optional) – If set to True, will turn off gradient calculation during $$E[X^2]$$ computation. Therefore, only semi-gradients are used during backpropagation. Useful for saving memory and accelerating backward computation. (default: False)