torch_geometric.nn

Convolutional Layers

class MessagePassing(aggr: Optional[str] = 'add', flow: str = 'source_to_target', node_dim: int = -2)[source]

Base class for creating message passing layers of the form

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]

where \(\square\) denotes a differentiable, permutation-invariant function, e.g., sum, mean or max, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs. See the accompanying tutorial on creating message passing networks.

Parameters
  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max" or None). (default: "add")

  • flow (string, optional) – The flow direction of message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

  • node_dim (int, optional) – The axis along which to propagate. (default: -2)

aggregate(inputs: torch.Tensor, index: torch.Tensor, ptr: Optional[torch.Tensor] = None, dim_size: Optional[int] = None) → torch.Tensor[source]

Aggregates messages from neighbors as \(\square_{j \in \mathcal{N}(i)}\).

Takes the output of message computation as its first argument, along with any argument that was initially passed to propagate().

By default, this function delegates to scatter functions that support the "add", "mean" and "max" operations, as selected by the aggr argument in __init__().
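As an illustration of the default behavior, the following NumPy sketch mimics the "add", "mean" and "max" reductions over an index vector, where index[k] is the target node of message k. The helper scatter_aggregate is hypothetical and is not the library's internal implementation; in particular, filling nodes that receive no messages with zeros is an assumed convention.

```python
import numpy as np


def scatter_aggregate(inputs, index, dim_size, aggr="add"):
    """Hypothetical helper: reduce rows of `inputs` into `dim_size` slots given by `index`."""
    out_shape = (dim_size,) + inputs.shape[1:]
    if aggr == "add":
        out = np.zeros(out_shape)
        np.add.at(out, index, inputs)  # unbuffered add, so repeated indices accumulate
        return out
    if aggr == "mean":
        out = np.zeros(out_shape)
        np.add.at(out, index, inputs)
        count = np.bincount(index, minlength=dim_size).reshape((-1,) + (1,) * (inputs.ndim - 1))
        return out / np.maximum(count, 1)  # nodes without messages stay at 0
    if aggr == "max":
        out = np.full(out_shape, -np.inf)
        np.maximum.at(out, index, inputs)
        out[np.isinf(out)] = 0.0  # assumed convention: empty slots become 0
        return out
    raise ValueError(f"unknown aggr: {aggr}")


# Three messages: two are sent to node 0, one to node 2; node 1 receives nothing.
msgs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
index = np.array([0, 0, 2])
print(scatter_aggregate(msgs, index, 3, "add"))   # rows [4, 6], [0, 0], [5, 6]
print(scatter_aggregate(msgs, index, 3, "mean"))  # rows [2, 3], [0, 0], [5, 6]
print(scatter_aggregate(msgs, index, 3, "max"))   # rows [3, 4], [0, 0], [5, 6]
```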

jittable(typing: Optional[str] = None)[source]

Analyzes the MessagePassing instance and produces a new jittable module.

Parameters

typing (string, optional) – If given, will generate a concrete instance with forward() types based on typing, e.g., "(Tensor, Optional[Tensor]) -> Tensor".

message(x_j: torch.Tensor) → torch.Tensor[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.

message_and_aggregate(adj_t: torch_sparse.tensor.SparseTensor) → torch.Tensor[source]

Fuses computations of message() and aggregate() into a single function. If applicable, this saves both time and memory since messages do not need to be explicitly materialized. This function is only called if it is implemented and propagation takes place based on a torch_sparse.SparseTensor.

propagate(edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs)[source]

The initial call to start propagating messages.

Parameters
  • edge_index (Tensor or SparseTensor) – A torch.LongTensor or a torch_sparse.SparseTensor that defines the underlying graph connectivity/message passing flow. edge_index holds the indices of a general (sparse) assignment matrix of shape [N, M]. If edge_index is of type torch.LongTensor, its shape must be defined as [2, num_messages], where messages from nodes in edge_index[0] are sent to nodes in edge_index[1] (in case flow="source_to_target"). If edge_index is of type torch_sparse.SparseTensor, its sparse indices (row, col) should relate to row = edge_index[1] and col = edge_index[0]. The main difference between the two formats is that a SparseTensor passed to propagate() must hold the transposed sparse adjacency matrix.

  • size (tuple, optional) – The size (N, M) of the assignment matrix in case edge_index is a LongTensor. If set to None, the size will be automatically inferred and assumed to be square. This argument is ignored in case edge_index is a torch_sparse.SparseTensor. (default: None)

  • **kwargs – Any additional data which is needed to construct and aggregate messages, and to update node embeddings.
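The two connectivity formats can be compared with a small NumPy sketch, in which dense arrays stand in for torch.LongTensor and torch_sparse.SparseTensor purely for illustration. With flow="source_to_target", the message from edge_index[0][k] is delivered to edge_index[1][k]; the matching matrix is the transposed adjacency adj_t with adj_t[dst, src] = 1, so adj_t @ x reproduces sum aggregation.

```python
import numpy as np

# Directed edges 0->1, 0->2, 1->2 in the [2, num_messages] layout.
edge_index = np.array([[0, 0, 1],   # sources j
                       [1, 2, 2]])  # targets i
num_nodes = 3
x = np.array([[1.0], [10.0], [100.0]])

# Message passing view: gather x_j by source index, then sum into targets.
src, dst = edge_index
out_scatter = np.zeros_like(x)
np.add.at(out_scatter, dst, x[src])

# Sparse-matrix view: row = edge_index[1], col = edge_index[0], i.e. the transposed adjacency.
adj_t = np.zeros((num_nodes, num_nodes))
adj_t[dst, src] = 1.0
out_matmul = adj_t @ x

# Node 0 receives nothing, node 1 gets 1, node 2 gets 1 + 10 = 11.
assert np.allclose(out_scatter, out_matmul)
print(out_scatter.ravel())
```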

update(inputs: torch.Tensor) → torch.Tensor[source]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes the output of aggregation as its first argument, along with any argument that was initially passed to propagate().

class AGNNConv(requires_grad: bool = True, add_self_loops: bool = True, **kwargs)[source]

The graph attentional propagation layer from the “Attention-based Graph Neural Network for Semi-Supervised Learning” paper

\[\mathbf{X}^{\prime} = \mathbf{P} \mathbf{X},\]

where the propagation matrix \(\mathbf{P}\) is computed as

\[P_{i,j} = \frac{\exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_j))} {\sum_{k \in \mathcal{N}(i)\cup \{ i \}} \exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_k))}\]

with trainable parameter \(\beta\).
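A dense NumPy sketch of how \(\mathbf{P}\) could be formed: cosine similarities scaled by \(\beta\), masked to \(\mathcal{N}(i) \cup \{ i \}\) and normalized row-wise with a softmax. Here \(\beta\) is a fixed number rather than a learned parameter, and agnn_propagation is a hypothetical name; the library operates on sparse edge lists instead of dense matrices.

```python
import numpy as np


def agnn_propagation(x, edge_index, beta=1.0):
    """Dense sketch of P[i, j] = softmax over j in N(i) ∪ {i} of beta * cos(x_i, x_j)."""
    n = x.shape[0]
    mask = np.eye(n, dtype=bool)               # self-loops: i attends to itself
    mask[edge_index[1], edge_index[0]] = True  # edge j -> i makes j a neighbor of i
    x_norm = x / np.linalg.norm(x, axis=1, keepdims=True)
    scores = beta * (x_norm @ x_norm.T)        # scores[i, j] = beta * cos(x_i, x_j)
    scores = np.where(mask, scores, -np.inf)   # drop non-neighbors
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(scores)
    return p / p.sum(axis=1, keepdims=True)


x = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = np.array([[0, 1, 2, 2], [2, 2, 0, 1]])  # undirected edges 0-2 and 1-2
P = agnn_propagation(x, edge_index, beta=1.0)
x_out = P @ x  # X' = P X
print(P.round(3))
```

With beta=0 every neighbor (including the node itself) receives the same weight, so the layer degenerates to mean aggregation over \(\mathcal{N}(i) \cup \{ i \}\).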

Parameters
  • requires_grad (bool, optional) – If set to False, \(\beta\) will not be trainable. (default: True)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor]) → torch.Tensor[source]
reset_parameters()[source]
class APPNP(K: int, alpha: float, add_self_loops: bool = True, **kwargs)[source]

The approximate personalized propagation of neural predictions layer from the “Predict then Propagate: Graph Neural Networks meet Personalized PageRank” paper

\[ \begin{align}\begin{aligned}\mathbf{X}^{(0)} &= \mathbf{X}\\\mathbf{X}^{(k)} &= (1 - \alpha) \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X}^{(k-1)} + \alpha \mathbf{X}^{(0)}\\\mathbf{X}^{\prime} &= \mathbf{X}^{(K)},\end{aligned}\end{align} \]

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix.
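A dense NumPy sketch of the K-step propagation, assuming an unweighted, symmetric adjacency matrix without self-loops. The function name appnp is hypothetical, and the code only mirrors the equations, not the library's sparse implementation.

```python
import numpy as np


def appnp(x, adj, K, alpha):
    """Dense sketch: X^(k) = (1 - alpha) S X^(k-1) + alpha X^(0), S = D^-1/2 A^ D^-1/2."""
    a_hat = adj + np.eye(adj.shape[0])                    # A^ = A + I
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))         # D^_ii = sum_j A^_ij
    s = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    h = x                                                 # X^(0) = X
    for _ in range(K):
        h = (1 - alpha) * (s @ h) + alpha * x             # teleport back to X^(0)
    return h                                              # X' = X^(K)


adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)  # path graph 0 - 1 - 2
x = np.array([[1.0], [0.0], [0.0]])
print(appnp(x, adj, K=10, alpha=0.1).ravel())
```

Since only \(\mathbf{X}\) is propagated and no weights are involved, the layer adds no trainable parameters; a larger alpha keeps the output closer to the input features.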

Parameters
  • K (int) – Number of iterations \(K\).

  • alpha (float) – Teleport probability \(\alpha\).

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None) → torch.Tensor[source]
class ARMAConv(in_channels: int, out_channels: int, num_stacks: int = 1, num_layers: int = 1, shared_weights: bool = False, act: Callable = <function relu>, dropout: float = 0.0, bias: bool = True, **kwargs)[source]

The ARMA graph convolutional operator from the “Graph Neural Networks with Convolutional ARMA Filters” paper

\[\mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K \mathbf{X}_k^{(T)},\]

with \(\mathbf{X}_k^{(T)}\) being recursively defined by

\[\mathbf{X}_k^{(t+1)} = \sigma \left( \mathbf{\hat{L}} \mathbf{X}_k^{(t)} \mathbf{W} + \mathbf{X}^{(0)} \mathbf{V} \right),\]

where \(\mathbf{\hat{L}} = \mathbf{I} - \mathbf{L} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\) denotes the modified Laplacian, obtained from the normalized Laplacian \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\).
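A dense NumPy sketch of the recursion for K stacks of T steps each, averaged at the end. Each step gets its own W and V here, a choice made for the sketch (the shared_weights option controls whether the layer reuses them); the first W maps in_channels to out_channels and later ones map out_channels to out_channels. Dropout on the skip connection and the bias are omitted, and arma_stack is a hypothetical name.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def arma_stack(x, adj, ws, vs, act=relu):
    """One ARMA stack: X^(t+1) = act(L^ X^(t) W_t + X^(0) V_t) for t = 0..T-1 (dense sketch)."""
    deg = adj.sum(axis=1)
    d_inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    l_hat = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]  # L^ = D^-1/2 A D^-1/2
    h = x
    for w, v in zip(ws, vs):
        h = act(l_hat @ h @ w + x @ v)  # skip connection from the input X^(0)
    return h


rng = np.random.default_rng(0)
adj = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
x = rng.normal(size=(3, 4))
K, T, c_in, c_out = 2, 3, 4, 8  # num_stacks, num_layers, in_channels, out_channels
stacks = []
for _ in range(K):
    ws = [rng.normal(size=(c_in if t == 0 else c_out, c_out)) for t in range(T)]
    vs = [rng.normal(size=(c_in, c_out)) for _ in range(T)]
    stacks.append(arma_stack(x, adj, ws, vs))
x_out = np.mean(stacks, axis=0)  # X' = (1/K) sum_k X_k^(T)
print(x_out.shape)  # (3, 8)
```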

Parameters
  • in_channels (int) – Size of each input sample \(\mathbf{x}^{(t)}\).

  • out_channels (int) – Size of each output sample \(\mathbf{x}^{(t+1)}\).

  • num_stacks (int, optional) – Number of parallel stacks \(K\). (default: 1)

  • num_layers (int, optional) – Number of layers \(T\). (default: 1)

  • act (callable, optional) – Activation function \(\sigma\). (default: torch.nn.functional.relu)

  • shared_weights (bool, optional) – If set to True, the layers in each stack will share the same parameters. (default: False)

  • dropout (float, optional) – Dropout probability of the skip connection. (default: 0.)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None) → torch.Tensor[source]
reset_parameters()[source]
class CGConv(channels: Union[int, Tuple[int, int]], dim: int = 0, aggr: str = 'add', batch_norm: bool = False, bias: bool = True, **kwargs)[source]

The crystal graph convolutional operator from the “Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties” paper

\[\mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right)\]

where \(\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]\) denotes the concatenation of central node features, neighboring node features and edge features. In addition, \(\sigma\) and \(g\) denote the sigmoid and softplus functions, respectively.
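A minimal NumPy sketch of a single update with "add" aggregation and without batch normalization. The weight shapes and the helper names cg_conv, softplus and sigmoid are assumptions for illustration. Because the sigmoid gate and the softplus core are both positive, every incoming message increases the residual input \(\mathbf{x}_i\) element-wise.

```python
import numpy as np


def softplus(z):
    return np.logaddexp(0.0, z)  # log(1 + exp(z)), numerically stable


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cg_conv(x, edge_index, edge_attr, w_f, b_f, w_s, b_s):
    """Sketch: x'_i = x_i + sum_j sigmoid(z_ij W_f + b_f) * softplus(z_ij W_s + b_s)."""
    src, dst = edge_index
    z = np.concatenate([x[dst], x[src], edge_attr], axis=1)  # z_ij = [x_i, x_j, e_ij]
    msgs = sigmoid(z @ w_f + b_f) * softplus(z @ w_s + b_s)  # gate times core
    out = x.copy()                                           # residual term x_i
    np.add.at(out, dst, msgs)
    return out


rng = np.random.default_rng(0)
c, d = 3, 2  # channels, edge feature dimensionality
x = rng.normal(size=(4, c))
edge_index = np.array([[1, 2, 3], [0, 0, 1]])
edge_attr = rng.normal(size=(3, d))
w_f, w_s = rng.normal(size=(2 * c + d, c)), rng.normal(size=(2 * c + d, c))
b_f, b_s = np.zeros(c), np.zeros(c)
out = cg_conv(x, edge_index, edge_attr, w_f, b_f, w_s, b_s)
print(out.shape)  # (4, 3)
```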

Parameters
  • channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • dim (int, optional) – Edge feature dimensionality. (default: 0)

  • aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "add")

  • batch_norm (bool, optional) – If set to True, will make use of batch normalization. (default: False)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_attr: Optional[torch.Tensor] = None, size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class ChebConv(in_channels, out_channels, K, normalization='sym', bias=True, **kwargs)[source]

The Chebyshev spectral graph convolutional operator from the “Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering” paper

\[\mathbf{X}^{\prime} = \sum_{k=1}^{K} \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)}\]

where \(\mathbf{Z}^{(k)}\) is computed recursively by

\[ \begin{align}\begin{aligned}\mathbf{Z}^{(1)} &= \mathbf{X}\\\mathbf{Z}^{(2)} &= \mathbf{\hat{L}} \cdot \mathbf{X}\\\mathbf{Z}^{(k)} &= 2 \cdot \mathbf{\hat{L}} \cdot \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)}\end{aligned}\end{align} \]

and \(\mathbf{\hat{L}}\) denotes the scaled and normalized Laplacian \(\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}\).
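The recursion can be traced in NumPy for the "sym" normalization. This sketch assumes \(\lambda_{\max} = 2\), the upper bound on the eigenvalues of the symmetric normalized Laplacian, so that \(\mathbf{\hat{L}}\) reduces to \(-\mathbf{D}^{-1/2}\mathbf{A}\mathbf{D}^{-1/2}\). The name cheb_conv is hypothetical, and thetas holds one weight matrix per order, so its length plays the role of \(K\).

```python
import numpy as np


def cheb_conv(x, adj, thetas, lambda_max=2.0):
    """Dense sketch of X' = sum_k Z^(k) Theta^(k) with the "sym" Laplacian."""
    n = adj.shape[0]
    deg = adj.sum(axis=1)
    d_inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    lap = np.eye(n) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]  # L = I - D^-1/2 A D^-1/2
    l_hat = 2.0 * lap / lambda_max - np.eye(n)   # scaled Laplacian
    z_prev, z = None, x                          # Z^(1) = X
    out = z @ thetas[0]
    for k in range(1, len(thetas)):
        if k == 1:
            z_next = l_hat @ z                   # Z^(2) = L^ X
        else:
            z_next = 2.0 * l_hat @ z - z_prev    # Z^(k) = 2 L^ Z^(k-1) - Z^(k-2)
        z_prev, z = z, z_next
        out = out + z @ thetas[k]
    return out


adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
x = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
thetas = [np.eye(2)] * 3  # K = 3, identity weights for readability
print(cheb_conv(x, adj, thetas))
```

Each additional order reaches one hop further, so a filter with \(K\) terms mixes information from up to \(K - 1\) hops away.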

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • K (int) – Chebyshev filter size \(K\).

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: "sym"):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x, edge_index, edge_weight: Optional[torch.Tensor] = None, batch: Optional[torch.Tensor] = None, lambda_max: Optional[torch.Tensor] = None)[source]
reset_parameters()[source]
class ClusterGCNConv(in_channels: int, out_channels: int, diag_lambda: float = 0.0, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The ClusterGCN graph convolutional operator from the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = \left( \mathbf{\hat{A}} + \lambda \cdot \textrm{diag}(\mathbf{\hat{A}}) \right) \mathbf{X} \mathbf{W}_1 + \mathbf{X} \mathbf{W}_2\]

where \(\mathbf{\hat{A}} = {(\mathbf{D} + \mathbf{I})}^{-1}(\mathbf{A} + \mathbf{I})\).
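A dense NumPy sketch of the propagation matrix with diagonal enhancement, followed by the full update. The name cluster_gcn_propagation and the random weights are hypothetical; bias and sparse storage are omitted. Because \(\mathbf{A} + \mathbf{I}\) is divided row-wise by \(\mathbf{D} + \mathbf{I}\), each row of \(\mathbf{\hat{A}}\) sums to 1 before the enhancement.

```python
import numpy as np


def cluster_gcn_propagation(adj, diag_lambda=0.0):
    """Dense sketch of A^ + lambda * diag(A^), where A^ = (D + I)^-1 (A + I)."""
    n = adj.shape[0]
    deg = adj.sum(axis=1)                                 # D_ii = sum_j A_ij
    a_hat = (adj + np.eye(n)) / (deg + 1.0)[:, None]      # row-normalize by D + I
    return a_hat + diag_lambda * np.diag(np.diag(a_hat))  # diagonal enhancement


rng = np.random.default_rng(0)
adj = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)  # star centered at node 0
x = rng.normal(size=(3, 4))
w1, w2 = rng.normal(size=(4, 2)), rng.normal(size=(4, 2))
x_out = cluster_gcn_propagation(adj, diag_lambda=1.0) @ x @ w1 + x @ w2
print(x_out.shape)  # (3, 2)
```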

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • diag_lambda (float, optional) – Diagonal enhancement value \(\lambda\). (default: 0.)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class DNAConv(channels: int, heads: int = 1, groups: int = 1, dropout: float = 0.0, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The dynamic neighborhood aggregation operator from the “Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks” paper

\[\mathbf{x}_v^{(t)} = h_{\mathbf{\Theta}}^{(t)} \left( \mathbf{x}_{v \leftarrow v}^{(t)}, \left\{ \mathbf{x}_{v \leftarrow w}^{(t)} : w \in \mathcal{N}(v) \right\} \right)\]

based on (multi-head) dot-product attention

\[\mathbf{x}_{v \leftarrow w}^{(t)} = \textrm{Attention} \left( \mathbf{x}^{(t-1)}_v \, \mathbf{\Theta}_Q^{(t)}, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_K^{(t)}, \, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_V^{(t)} \right)\]

with \(\mathbf{\Theta}_Q^{(t)}, \mathbf{\Theta}_K^{(t)}, \mathbf{\Theta}_V^{(t)}\) denoting (grouped) projection matrices for query, key and value information, respectively. \(h^{(t)}_{\mathbf{\Theta}}\) is implemented as a non-trainable version of torch_geometric.nn.conv.GCNConv.

Note

In contrast to other layers, this operator expects node features as shape [num_nodes, num_layers, channels].

Parameters
  • channels (int) – Size of each input/output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • groups (int, optional) – Number of groups to use for all linear projections. (default: 1)

  • dropout (float, optional) – Dropout probability of attention coefficients. (default: 0.)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None) → torch.Tensor[source]
Parameters

x – The input node features of shape [num_nodes, num_layers, channels].

reset_parameters()[source]
class DynamicEdgeConv(nn: Callable, k: int, aggr: str = 'max', num_workers: int = 1, **kwargs)[source]

The dynamic edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper (see torch_geometric.nn.conv.EdgeConv), where the graph is dynamically constructed using nearest neighbors in the feature space.

Parameters
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps pair-wise concatenated node features x of shape [-1, 2 * in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • k (int) – Number of nearest neighbors.

  • aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "max")

  • num_workers (int, optional) – Number of workers to use for k-NN computation. Has no effect if batch is not None or the input lies on the GPU. (default: 1)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
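The dynamic graph can be sketched as a k-NN search in feature space that returns an edge_index in the same [2, num_edges] layout used elsewhere. The helper knn_edge_index is hypothetical, computes all pairwise distances densely, and excludes each point from its own neighbor list; whether the library counts a point as its own neighbor is not stated here.

```python
import numpy as np


def knn_edge_index(x, k):
    """Connect each node i to its k nearest neighbors j in feature space (edges j -> i)."""
    d2 = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)  # squared pairwise distances
    np.fill_diagonal(d2, np.inf)                             # exclude self-matches (assumption)
    nbrs = np.argsort(d2, axis=1)[:, :k]                     # [N, k] neighbor ids per node
    dst = np.repeat(np.arange(x.shape[0]), k)                # each target repeated k times
    src = nbrs.ravel()
    return np.stack([src, dst])                              # [2, N * k]


x = np.array([[0.0], [1.0], [3.0], [10.0]])  # four points on a line
print(knn_edge_index(x, k=1))  # sources [1, 0, 1, 2], targets [0, 1, 2, 3]
```

Recomputing this graph from the current features at every layer is what makes the operator dynamic; the resulting edge_index then feeds an EdgeConv-style update.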

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], batch: Union[torch.Tensor, None, Tuple[torch.Tensor, torch.Tensor]] = None) → torch.Tensor[source]
reset_parameters()[source]
ECConv

alias of torch_geometric.nn.conv.nn_conv.NNConv

class EdgeConv(nn: Callable, aggr: str = 'max', **kwargs)[source]

The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP.
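A NumPy sketch of the edge features and the aggregation. The formula is written with a sum, while the default aggr="max" takes an element-wise maximum instead; both are shown. The stand-in h (one linear map followed by ReLU) and the name edge_conv are hypothetical, and filling nodes without incoming edges with zeros is an assumed convention.

```python
import numpy as np


def edge_conv(x, edge_index, h, aggr="max"):
    """Sketch: x'_i = AGG_{j in N(i)} h([x_i || x_j - x_i]) over edges j -> i."""
    src, dst = edge_index
    feats = np.concatenate([x[dst], x[src] - x[dst]], axis=1)  # [E, 2 * in_channels]
    msgs = h(feats)                                             # [E, out_channels]
    if aggr == "max":
        out = np.full((x.shape[0], msgs.shape[1]), -np.inf)
        np.maximum.at(out, dst, msgs)
        out[np.isinf(out)] = 0.0  # assumed convention for nodes without incoming edges
    else:
        out = np.zeros((x.shape[0], msgs.shape[1]))
        np.add.at(out, dst, msgs)  # the sum written in the formula
    return out


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3))  # maps 2 * in_channels = 4 to out_channels = 3


def h(z):
    """Hypothetical stand-in for an MLP defined by torch.nn.Sequential."""
    return np.maximum(z @ w, 0.0)


x = rng.normal(size=(3, 2))
edge_index = np.array([[1, 2, 0], [0, 0, 1]])  # 1 -> 0, 2 -> 0, 0 -> 1
print(edge_conv(x, edge_index, h).shape)        # (3, 3)
```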

Parameters
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps pair-wise concatenated node features x of shape [-1, 2 * in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "max")

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor]) → torch.Tensor[source]
reset_parameters()[source]
class FastRGCNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int, num_bases: Optional[int] = None, num_blocks: Optional[int] = None, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs)[source]

See RGCNConv.

forward(x: Union[torch.Tensor, None, Tuple[Optional[torch.Tensor], torch.Tensor]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_type: Optional[torch.Tensor] = None)[source]
class FeaStConv(in_channels: int, out_channels: int, heads: int = 1, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The (translation-invariant) feature-steered convolutional operator from the “FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis” paper

\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \sum_{h=1}^H q_h(\mathbf{x}_i, \mathbf{x}_j) \mathbf{W}_h \mathbf{x}_j\]

with \(q_h(\mathbf{x}_i, \mathbf{x}_j) = \mathrm{softmax}_h (\mathbf{u}_h^{\top} (\mathbf{x}_j - \mathbf{x}_i) + c_h)\), where \(H\) denotes the number of attention heads (the softmax normalizes over these \(H\) heads), and \(\mathbf{W}_h\), \(\mathbf{u}_h\) and \(c_h\) are trainable parameters.
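A NumPy sketch of the update. It normalizes \(q\) over the \(H\) heads of each edge, following the FeaStNet paper's soft assignment of neighbors to weight matrices; treat this as an interpretive assumption. Self-loops and bias are not added here, and feast_conv is a hypothetical name.

```python
import numpy as np


def feast_conv(x, edge_index, u, c, w):
    """Sketch of x'_i = 1/|N(i)| sum_j sum_h q_h(x_i, x_j) W_h x_j.

    u: [in_channels, H], c: [H], w: [H, in_channels, out_channels].
    q is a softmax over the H heads of each edge (assumption).
    """
    src, dst = edge_index
    logits = (x[src] - x[dst]) @ u + c                    # u_h^T (x_j - x_i) + c_h, [E, H]
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    q = np.exp(logits)
    q = q / q.sum(axis=1, keepdims=True)                  # soft assignment of each edge to heads
    proj = np.einsum("ei,hio->eho", x[src], w)            # W_h x_j for every head, [E, H, out]
    msgs = (q[:, :, None] * proj).sum(axis=1)             # [E, out]
    out = np.zeros((x.shape[0], w.shape[2]))
    np.add.at(out, dst, msgs)
    deg = np.bincount(dst, minlength=x.shape[0])[:, None]
    return out / np.maximum(deg, 1)                       # 1/|N(i)| normalization


rng = np.random.default_rng(0)
x = np.array([[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]])
edge_index = np.array([[1, 2], [0, 0]])     # nodes 1 and 2 send to node 0
u, c = rng.normal(size=(2, 3)), np.zeros(3)  # H = 3 heads
w = rng.normal(size=(3, 2, 4))               # out_channels = 4
print(feast_conv(x, edge_index, u, c, w).shape)  # (3, 4)
```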

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • heads (int, optional) – Number of attention heads \(H\). (default: 1)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor]) → torch.Tensor[source]
reset_parameters()[source]
class GATConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The graph attentional operator from the “Graph Attention Networks” paper

\[\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},\]

where the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] \right)\right)}.\]
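A single-head NumPy sketch of the coefficients, splitting \(\mathbf{a}\) into the halves that act on \(\mathbf{\Theta}\mathbf{x}_i\) and \(\mathbf{\Theta}\mathbf{x}_j\). Dropout, multiple heads and bias are omitted, and gat_attention is a hypothetical name. Including the diagonal in the mask plays the role of add_self_loops=True.

```python
import numpy as np


def gat_attention(x, edge_index, theta, a, negative_slope=0.2):
    """Single-head GAT: alpha_{i,j} over N(i) ∪ {i} and x'_i = sum_j alpha_ij Θ x_j."""
    n = x.shape[0]
    h = x @ theta.T                              # Θ x for every node, [N, out]
    c = h.shape[1]
    a_i, a_j = a[:c], a[c:]                      # a^T [Θx_i || Θx_j] = a_i·Θx_i + a_j·Θx_j
    e = (h @ a_i)[:, None] + (h @ a_j)[None, :]  # e[i, j] before LeakyReLU
    e = np.where(e > 0, e, negative_slope * e)   # LeakyReLU
    mask = np.eye(n, dtype=bool)                 # self-loops
    mask[edge_index[1], edge_index[0]] = True    # edge j -> i
    e = np.where(mask, e, -np.inf)
    e = e - e.max(axis=1, keepdims=True)         # numerical stability
    alpha = np.exp(e)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    return alpha, alpha @ h


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
theta = rng.normal(size=(2, 3))  # out_channels = 2
a = rng.normal(size=4)           # 2 * out_channels
edge_index = np.array([[1, 2, 3, 0], [0, 0, 0, 1]])
alpha, x_out = gat_attention(x, edge_index, theta, a)
print(alpha.round(3))
```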
Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], size: Optional[Tuple[int, int]] = None, return_attention_weights=None)[source]
Parameters

return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)

reset_parameters()[source]
class GCNConv(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)[source]

The graph convolutional operator from the “Semi-supervised Classification with Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},\]

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix.
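A dense NumPy sketch of the normalization, including the improved variant. It builds a dense matrix from edge_index and ignores edge weights, caching and bias; gcn_conv is a hypothetical name, not the library's implementation. With one-hot input features and an identity weight matrix, the output is the normalized matrix itself, which makes the coefficients easy to inspect.

```python
import numpy as np


def gcn_conv(x, edge_index, theta, improved=False):
    """Dense sketch of X' = D^-1/2 A^ D^-1/2 X Θ, with A^ = A + I (or A + 2I if improved)."""
    n = x.shape[0]
    a = np.zeros((n, n))
    np.add.at(a, (edge_index[1], edge_index[0]), 1.0)  # a[i, j] counts edges j -> i
    a_hat = a + (2.0 if improved else 1.0) * np.eye(n)
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))      # D^_ii = sum_j A^_ij
    norm = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    return norm @ x @ theta


edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])  # undirected path 0 - 1 - 2
x = np.eye(3)
theta = np.eye(3)
print(gcn_conv(x, edge_index, theta).round(3))
print(gcn_conv(x, edge_index, theta, improved=True).round(3))
```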

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None) → torch.Tensor[source]
reset_parameters()[source]
class GENConv(in_channels: int, out_channels: int, aggr: str = 'softmax', t: float = 1.0, learn_t: bool = False, p: float = 1.0, learn_p: bool = False, msg_norm: bool = False, learn_msg_scale: bool = False, norm: str = 'batch', num_layers: int = 2, eps: float = 1e-07, **kwargs)[source]

The GENeralized Graph Convolution (GENConv) from the “DeeperGCN: All You Need to Train Deeper GCNs” paper. It supports softmax and power mean aggregation. Node embeddings are computed as:

\[\mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_i + \mathrm{AGG} \left( \left\{ \mathrm{ReLU} \left( \mathbf{x}_j + \mathbf{e}_{ji} \right) +\epsilon : j \in \mathcal{N}(i) \right\} \right) \right)\]
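The message construction and the softmax aggregation can be sketched in NumPy. The aggregation formula used here, \(\sum_j \mathrm{softmax}_j(t \cdot \mathbf{m}_j) \odot \mathbf{m}_j\) computed per feature channel, follows the DeeperGCN paper as understood here and should be treated as an assumption. The MLP, message normalization and power mean aggregation are omitted, and gen_messages and softmax_aggregate are hypothetical names.

```python
import numpy as np


def gen_messages(x, edge_index, edge_attr, eps=1e-7):
    """Message construction ReLU(x_j + e_ji) + eps for every edge j -> i."""
    return np.maximum(x[edge_index[0]] + edge_attr, 0.0) + eps


def softmax_aggregate(msgs, index, num_nodes, t=1.0):
    """Per-channel softmax aggregation: sum_j softmax_j(t * m_j) * m_j for each target node."""
    out = np.zeros((num_nodes, msgs.shape[1]))
    for i in range(num_nodes):
        m = msgs[index == i]
        if len(m) == 0:
            continue                            # nodes without messages stay at 0
        w = np.exp(t * (m - m.max(axis=0)))     # per-channel shift for stability
        w = w / w.sum(axis=0)
        out[i] = (w * m).sum(axis=0)
    return out


x = np.array([[1.0, -1.0], [2.0, 0.5], [0.0, 3.0]])
edge_index = np.array([[1, 2], [0, 0]])  # nodes 1 and 2 send to node 0
edge_attr = np.zeros((2, 2))
msgs = gen_messages(x, edge_index, edge_attr)
for t in (0.0, 1.0, 100.0):
    print(t, softmax_aggregate(msgs, edge_index[1], 3, t=t)[0])
```

Setting t to 0 recovers mean aggregation, while a large t approaches max aggregation, which is why learning t (learn_t=True) lets the layer interpolate between the two.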

Note

For an example of using GENConv, see examples/ogbn_proteins_deepgcn.py.

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • aggr (str, optional) – The aggregation scheme to use ("softmax", "softmax_sg", "power", "add", "mean", "max"). (default: "softmax")

  • t (float, optional) – Initial inverse temperature for softmax aggregation. (default: 1.0)

  • learn_t (bool, optional) – If set to True, will learn the value t for softmax aggregation dynamically. (default: False)

  • p (float, optional) – Initial power for power mean aggregation. (default: 1.0)

  • learn_p (bool, optional) – If set to True, will learn the value p for power mean aggregation dynamically. (default: False)

  • msg_norm (bool, optional) – If set to True, will use message normalization. (default: False)

  • learn_msg_scale (bool, optional) – If set to True, will learn the scaling factor of message normalization. (default: False)

  • norm (str, optional) – Norm layer of MLP layers ("batch", "layer", "instance"). (default: "batch")

  • num_layers (int, optional) – The number of MLP layers. (default: 2)

  • eps (float, optional) – The epsilon value of the message construction function. (default: 1e-7)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.GenMessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_attr: Optional[torch.Tensor] = None, size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class GINConv(nn: Callable, eps: float = 0.0, train_eps: bool = False, **kwargs)[source]

The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper

\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)\]

or

\[\mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP.
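The node-wise and matrix forms are equivalent, which a NumPy sketch can check numerically. The one-layer stand-in for \(h_{\mathbf{\Theta}}\) and the random inputs are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(3, 4))


def h(z):
    """Hypothetical stand-in for the MLP h_Θ."""
    return np.maximum(z @ w, 0.0)


x = rng.normal(size=(4, 3))
edge_index = np.array([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])  # path 0-1-2-3, both directions
eps = 0.1

# Node-wise form: (1 + eps) * x_i + sum_{j in N(i)} x_j
agg = np.zeros_like(x)
np.add.at(agg, edge_index[1], x[edge_index[0]])
out_nodewise = h((1 + eps) * x + agg)

# Matrix form: (A + (1 + eps) I) X, with A[i, j] = 1 for each edge j -> i
a = np.zeros((4, 4))
a[edge_index[1], edge_index[0]] = 1.0
out_matrix = h((a + (1 + eps) * np.eye(4)) @ x)

print(np.allclose(out_nodewise, out_matrix))  # True
```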

Parameters
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • eps (float, optional) – (Initial) \(\epsilon\)-value. (default: 0.)

  • train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class GINEConv(nn: Callable, eps: float = 0.0, train_eps: bool = False, **kwargs)[source]

The modified GINConv operator from the “Strategies for Pre-training Graph Neural Networks” paper

\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU} ( \mathbf{x}_j + \mathbf{e}_{j,i} ) \right)\]

that is able to incorporate edge features \(\mathbf{e}_{j,i}\) into the aggregation procedure.
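A NumPy sketch of the update. Since \(\mathbf{x}_j + \mathbf{e}_{j,i}\) is an element-wise sum, it assumes edge features have the same width as node features. The name gine_conv is hypothetical, and h may be any callable.

```python
import numpy as np


def gine_conv(x, edge_index, edge_attr, h, eps=0.0):
    """Sketch: h((1 + eps) * x_i + sum_{j in N(i)} ReLU(x_j + e_{j,i}))."""
    msgs = np.maximum(x[edge_index[0]] + edge_attr, 0.0)  # edge_attr must match x's width
    agg = np.zeros_like(x)
    np.add.at(agg, edge_index[1], msgs)
    return h((1 + eps) * x + agg)


x = np.array([[1.0, -2.0], [0.5, 0.5], [-1.0, 1.0]])
edge_index = np.array([[1, 2], [0, 0]])         # nodes 1 and 2 send to node 0
edge_attr = np.array([[0.0, 1.0], [2.0, -3.0]])
out = gine_conv(x, edge_index, edge_attr, h=lambda z: z)  # identity h for readability
print(out)
```

For node 0 the two messages are ReLU([0.5, 1.5]) = [0.5, 1.5] and ReLU([1, -2]) = [1, 0], so its output is [1, -2] + [1.5, 1.5] = [2.5, -0.5]; nodes 1 and 2 receive nothing and keep their features.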

Parameters
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • eps (float, optional) – (Initial) \(\epsilon\)-value. (default: 0.)

  • train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_attr: Optional[torch.Tensor] = None, size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class GMMConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, dim: int, kernel_size: int, separate_gaussians: bool = False, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The Gaussian mixture model convolutional operator from the “Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs” paper

\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \frac{1}{K} \sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{i,j}) \odot \mathbf{\Theta}_k \mathbf{x}_j,\]

where

\[\mathbf{w}_k(\mathbf{e}) = \exp \left( -\frac{1}{2} {\left( \mathbf{e} - \mathbf{\mu}_k \right)}^{\top} \Sigma_k^{-1} \left( \mathbf{e} - \mathbf{\mu}_k \right) \right)\]

denotes a weighting function based on trainable mean vector \(\mathbf{\mu}_k\) and diagonal covariance matrix \(\mathbf{\Sigma}_k\).
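A NumPy sketch of the weighting function and the mean-aggregated update for the case separate_gaussians=False, where each \(\mathbf{w}_k\) is a scalar per edge. The diagonal covariance is parameterized by per-dimension standard deviations, which is an assumption of the sketch; root_weight and bias are omitted, and gaussian_weights and gmm_conv are hypothetical names.

```python
import numpy as np


def gaussian_weights(e, mu, sigma):
    """w_k(e) = exp(-1/2 (e - mu_k)^T Sigma_k^-1 (e - mu_k)) for diagonal Sigma_k.

    e: [E, dim] pseudo-coordinates, mu: [K, dim] means,
    sigma: [K, dim] standard deviations (Sigma_k = diag(sigma_k ** 2)). Returns [E, K].
    """
    diff = e[:, None, :] - mu[None, :, :]  # [E, K, dim]
    return np.exp(-0.5 * np.sum(diff ** 2 / sigma[None] ** 2, axis=-1))


def gmm_conv(x, edge_index, e, mu, sigma, thetas):
    """Mean over N(i) of (1/K) sum_k w_k(e_ij) * Theta_k x_j; thetas: [K, in, out]."""
    src, dst = edge_index
    w = gaussian_weights(e, mu, sigma)               # [E, K]
    proj = np.einsum("ei,kio->eko", x[src], thetas)  # Theta_k x_j, [E, K, out]
    msgs = (w[:, :, None] * proj).mean(axis=1)       # (1/K) sum_k, [E, out]
    out = np.zeros((x.shape[0], thetas.shape[2]))
    np.add.at(out, dst, msgs)
    count = np.bincount(dst, minlength=x.shape[0])[:, None]
    return out / np.maximum(count, 1)                # 1/|N(i)|


x = np.array([[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]])
edge_index = np.array([[1, 2], [0, 0]])
e = np.array([[0.5], [0.5]])                      # 1-D pseudo-coordinates (dim = 1)
mu, sigma = np.array([[0.5]]), np.array([[1.0]])  # K = 1 kernel centered on e
thetas = np.eye(2)[None]                          # Theta_1 = identity
print(gmm_conv(x, edge_index, e, mu, sigma, thetas)[0])  # mean of x_1 and x_2, i.e. [2, 3]
```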

Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • dim (int) – Pseudo-coordinate dimensionality.

  • kernel_size (int) – Number of kernels \(K\).

  • separate_gaussians (bool, optional) – If set to True, will learn separate GMMs for every pair of input and output channel, inspired by traditional CNNs. (default: False)

  • aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "mean")

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_attr: Optional[torch.Tensor] = None, size: Optional[Tuple[int, int]] = None)[source]
reset_parameters()[source]
class GatedGraphConv(out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]

The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper

\[ \begin{align}\begin{aligned}\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}\\\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}\\\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)})\end{aligned}\end{align} \]

up to representation \(\mathbf{h}_i^{(L)}\). The number of input channels of \(\mathbf{x}_i\) must be less than or equal to out_channels, since \(\mathbf{x}_i\) is zero-padded to out_channels to form \(\mathbf{h}_i^{(0)}\).

Parameters
  • out_channels (int) – Size of each output sample.

  • num_layers (int) – The sequence length \(L\).

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None) → torch.Tensor[source]
reset_parameters()[source]
class GraphConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]

The graph neural network operator from the “Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks” paper

\[\mathbf{x}^{\prime}_i = \mathbf{\Theta}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta}_2 \mathbf{x}_j.\]
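
The formula can be checked with a small framework-free sketch. It is an illustration under assumptions, not the PyG implementation: features are lists of floats, Θ_1 and Θ_2 are nested lists, edges are (source j, target i) pairs, and aggregation is the default "add". The names matvec and graph_conv are hypothetical.

```python
def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def graph_conv(x, edges, theta1, theta2):
    """x'_i = Theta_1 x_i + sum over j in N(i) of Theta_2 x_j ("add" aggregation)."""
    out = [matvec(theta1, xi) for xi in x]  # root term for every node
    for j, i in edges:  # each edge carries a message from source j to target i
        msg = matvec(theta2, x[j])
        out[i] = [a + b for a, b in zip(out[i], msg)]
    return out


gc_out = graph_conv([[1.0], [2.0], [3.0]], [(0, 1), (2, 1)], [[1.0]], [[10.0]])
```

Node 1 receives 2 + 10·1 + 10·3 = 42, while nodes 0 and 2 only keep their root term.
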
Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None, size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class GravNetConv(in_channels: int, out_channels: int, space_dimensions: int, propagate_dimensions: int, k: int, num_workers: int = 1, **kwargs)[source]

The GravNet operator from the “Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks” paper, where the graph is dynamically constructed using nearest neighbors. The neighbors are constructed in a learnable low-dimensional projection of the feature space. A second projection of the input feature space is then propagated from the neighbors to each vertex using distance weights that are derived by applying a Gaussian function to the distances.

Parameters
  • in_channels (int) – The number of input channels.

  • out_channels (int) – The number of output channels.

  • space_dimensions (int) – The dimensionality of the space used to construct the neighbors; referred to as \(S\) in the paper.

  • propagate_dimensions (int) – The number of features to be propagated between the vertices; referred to as \(F_{\textrm{LR}}\) in the paper.

  • k (int) – The number of nearest neighbors.

  • num_workers (int) – Number of workers to use for k-NN computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], batch: Union[torch.Tensor, None, Tuple[torch.Tensor, torch.Tensor]] = None) → torch.Tensor[source]
reset_parameters()[source]
class HypergraphConv(in_channels, out_channels, use_attention=False, heads=1, concat=True, negative_slope=0.2, dropout=0, bias=True, **kwargs)[source]

The hypergraph convolutional operator from the “Hypergraph Convolution and Hypergraph Attention” paper

\[\mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W} \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta}\]

where \(\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}\) is the incidence matrix, \(\mathbf{W}\) is the diagonal hyperedge weight matrix, and \(\mathbf{D}\) and \(\mathbf{B}\) are the corresponding degree matrices.
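
The propagation rule can be sketched without a framework. Θ is taken as the identity. The incidence pairs are (node, hyperedge) tuples, which are assumed here to mirror the two rows of hyperedge_index. Hyperedge weights are assumed to be positive. The function name hypergraph_propagate is hypothetical.

```python
def hypergraph_propagate(x, incidence, num_hyperedges, weight=None):
    """X' = D^-1 H W B^-1 H^T X with Theta taken as the identity.

    incidence lists (node, hyperedge) pairs, i.e. the non-zero entries of H.
    Hyperedge weights are assumed to be positive.
    """
    n, f = len(x), len(x[0])
    w = weight if weight is not None else [1.0] * num_hyperedges
    D = [0.0] * n               # D_ii = sum_e W_ee H_ie (weighted node degree)
    B = [0.0] * num_hyperedges  # B_ee = sum_i H_ie (hyperedge size)
    for v, e in incidence:
        D[v] += w[e]
        B[e] += 1.0
    # B^-1 H^T X: average the features of the nodes in each hyperedge
    edge_feat = [[0.0] * f for _ in range(num_hyperedges)]
    for v, e in incidence:
        for k in range(f):
            edge_feat[e][k] += x[v][k] / B[e]
    # D^-1 H W (...): weighted average of the hyperedges each node belongs to
    out = [[0.0] * f for _ in range(n)]
    for v, e in incidence:
        for k in range(f):
            out[v][k] += w[e] * edge_feat[e][k] / D[v]
    return out


# Hypothetical example: hyperedge 0 = {0, 1}, hyperedge 1 = {1, 2}.
hg_out = hypergraph_propagate([[1.0], [3.0], [5.0]], [(0, 0), (1, 0), (1, 1), (2, 1)], 2)
```

Node 1 belongs to both hyperedges, so it averages their means (2 and 4) to 3.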

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • use_attention (bool, optional) – If set to True, attention will be added to this layer. (default: False)

  • heads (int, optional) – Number of multi-head attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x, hyperedge_index, hyperedge_weight=None)[source]
Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X}\)

  • hyperedge_index (LongTensor) – Hyperedge indices from \(\mathbf{H}\).

  • hyperedge_weight (Tensor, optional) – Sparse hyperedge weights from \(\mathbf{W}\). (default: None)

reset_parameters()[source]
class LEConv(in_channels, out_channels, bias=True, **kwargs)[source]

The local extremum graph neural network operator from the “ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations” paper, which finds the importance of nodes with respect to their neighbors using the difference operator:

\[\mathbf{x}^{\prime}_i = \mathbf{x}_i \cdot \mathbf{\Theta}_1 + \sum_{j \in \mathcal{N}(i)} a_{j,i} (\mathbf{x}_i \cdot \mathbf{\Theta}_2 - \mathbf{x}_j \cdot \mathbf{\Theta}_3)\]
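
A scalar-feature sketch of the local extremum update follows. With one-dimensional features, Θ_1, Θ_2 and Θ_3 reduce to the scalars t1, t2 and t3. When no edge weights are passed, a_{j,i} is assumed to be 1. The function name le_conv is hypothetical.

```python
def le_conv(x, edges, t1, t2, t3, edge_weight=None):
    """x'_i = t1*x_i + sum over j in N(i) of a_ji * (t2*x_i - t3*x_j).

    Scalar features, so Theta_1..Theta_3 reduce to the scalars t1..t3;
    a_ji is assumed to be 1 when no edge weights are given.
    """
    out = [t1 * xi for xi in x]
    for k, (j, i) in enumerate(edges):  # edge k sends from j to i
        a = 1.0 if edge_weight is None else edge_weight[k]
        out[i] += a * (t2 * x[i] - t3 * x[j])
    return out


le_out = le_conv([1.0, 4.0], [(0, 1)], 1.0, 1.0, 1.0)
```

With unit weights, node 1 adds its difference to node 0 (4 - 1 = 3) to its own value, giving 7.
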
Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x, edge_index, edge_weight=None)[source]
reset_parameters()[source]
class MFConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, max_degree: int = 10, bias=True, **kwargs)[source]

The graph neural network operator from the “Convolutional Networks on Graphs for Learning Molecular Fingerprints” paper

\[\mathbf{x}^{\prime}_i = \mathbf{W}^{(\deg(i))}_1 \mathbf{x}_i + \mathbf{W}^{(\deg(i))}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j\]

which trains a distinct weight matrix for each possible vertex degree.
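
The degree-dependent weights can be illustrated with scalar features. In this sketch, w1[d] and w2[d] play the roles of W_1^{(d)} and W_2^{(d)} for d = 0, ..., max_degree. Degrees above max_degree are assumed to reuse the weights of max_degree. The function name mf_conv is hypothetical.

```python
def mf_conv(x, edges, w1, w2, max_degree):
    """x'_i = w1[deg(i)] * x_i + w2[deg(i)] * sum over j in N(i) of x_j (scalar features).

    w1 and w2 hold max_degree + 1 weights each; degrees larger than
    max_degree are clamped (an assumption of this sketch).
    """
    n = len(x)
    deg = [0] * n
    agg = [0.0] * n
    for j, i in edges:  # in-degree and neighbor sum of each target node
        deg[i] += 1
        agg[i] += x[j]
    out = []
    for i in range(n):
        d = min(deg[i], max_degree)
        out.append(w1[d] * x[i] + w2[d] * agg[i])
    return out


mf_out = mf_conv([1.0, 2.0, 3.0], [(0, 2), (1, 2)], [1.0, 10.0], [0.0, 100.0], max_degree=1)
```

Node 2 has degree 2, which is clamped to 1, so it uses the second weight pair: 10·3 + 100·(1 + 2) = 330.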

Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • max_degree (int, optional) – The maximum node degree to consider when updating weights (default: 10)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class NNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, nn: Callable, aggr: str = 'add', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The continuous kernel-based convolutional operator from the “Neural Message Passing for Quantum Chemistry” paper. This convolution is also known as the edge-conditioned convolution from the “Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs” paper (see torch_geometric.nn.conv.ECConv for an alias):

\[\mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e. an MLP.
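
The edge-conditioned update can be sketched without a framework. The per-edge output of h_Θ (in_channels * out_channels numbers) is reshaped row-major into an [in_channels, out_channels] matrix, which x_j then multiplies. Aggregation is the default "add". The names nn_conv and edge_net are hypothetical, and edge_net stands in for a trained network.

```python
def nn_conv(x, edges, edge_attr, h, theta, in_channels, out_channels):
    """x'_i = Theta x_i + sum over j in N(i) of x_j . h(e_ij) ("add" aggregation)."""
    # Root term: theta is an out_channels x in_channels matrix
    out = [[sum(theta[o][c] * xi[c] for c in range(in_channels)) for o in range(out_channels)]
           for xi in x]
    for (j, i), e in zip(edges, edge_attr):
        flat = h(e)  # in_channels * out_channels numbers for this edge
        W = [flat[c * out_channels:(c + 1) * out_channels] for c in range(in_channels)]
        for o in range(out_channels):
            out[i][o] += sum(x[j][c] * W[c][o] for c in range(in_channels))
    return out


def edge_net(e):
    # Hypothetical h_Theta: maps a 1-d edge feature to a 1 x 2 weight matrix
    return [e[0], 2.0 * e[0]]


nn_out = nn_conv([[1.0], [2.0]], [(0, 1)], [[3.0]], edge_net, [[1.0], [0.0]], 1, 2)
```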

Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps edge features edge_attr of shape [-1, num_edge_features] to shape [-1, in_channels * out_channels], e.g., defined by torch.nn.Sequential.

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")

  • root_weight (bool, optional) – If set to False, the layer will not add the transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_attr: Optional[torch.Tensor] = None, size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class PNAConv(in_channels: int, out_channels: int, aggregators: List[str], scalers: List[str], deg: torch.Tensor, edge_dim: Optional[int] = None, towers: int = 1, pre_layers: int = 1, post_layers: int = 1, divide_input: bool = False, **kwargs)[source]

The Principal Neighbourhood Aggregation graph convolution operator from the “Principal Neighbourhood Aggregation for Graph Nets” paper

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus} h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right) \right)\]

with

\[\begin{split}\bigoplus = \underbrace{\begin{bmatrix} 1 \\ S(\mathbf{D}, \alpha=1) \\ S(\mathbf{D}, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}},\end{split}\]

where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote MLPs.
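
The degree scalers can be sketched as follows. This follows the PNA paper's definition S(d, α) = (log(d + 1) / δ)^α, where δ is the average of log(d + 1) over the training nodes. Here δ is computed from a degree histogram like the deg argument described below. Setting α = 0 gives the identity scaler. Clamping isolated nodes to degree 1 is an assumption of this sketch, and the helper names are hypothetical.

```python
import math


def average_log_degree(deg_hist):
    """delta: mean of log(d + 1) over training nodes; deg_hist[d] counts nodes of in-degree d."""
    total = sum(deg_hist)
    return sum(count * math.log(d + 1) for d, count in enumerate(deg_hist)) / total


def scale(aggregated, d, delta, alpha):
    """Multiply an aggregated vector by S(d, alpha) = (log(d + 1) / delta) ** alpha."""
    d = max(d, 1)  # assumption: clamp so that log(d + 1) > 0 for isolated nodes
    factor = (math.log(d + 1) / delta) ** alpha
    return [factor * v for v in aggregated]


delta = average_log_degree([0, 4])                  # all four training nodes have in-degree 1
amplified = scale([1.0, 2.0], 3, delta, alpha=1)    # factor log(4) / log(2) = 2
attenuated = scale([1.0, 2.0], 3, delta, alpha=-1)  # factor 1 / 2
```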

Note

For an example of using PNAConv, see examples/pna.py.

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • aggregators (list of str) – Set of aggregation function identifiers, namely "sum", "mean", "min", "max", "var" and "std".

  • scalers (list of str) – Set of scaling function identifiers, namely "identity", "amplification", "attenuation", "linear" and "inverse_linear".

  • deg (Tensor) – Histogram of in-degrees of nodes in the training set, used by scalers to normalize.

  • edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

  • towers (int, optional) – Number of towers (default: 1).

  • pre_layers (int, optional) – Number of transformation layers before aggregation (default: 1).

  • post_layers (int, optional) – Number of transformation layers after aggregation (default: 1).

  • divide_input (bool, optional) – Whether the input features should be split between towers or not (default: False).

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_attr: Optional[torch.Tensor] = None) → torch.Tensor[source]
reset_parameters()[source]
class PPFConv(local_nn: Optional[Callable] = None, global_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs)[source]

The PPFNet operator from the “PPFNet: Global Context Aware Local Features for Robust 3D Point Matching” paper

\[\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \| \mathbf{d_{j,i}} \|, \angle(\mathbf{n}_i, \mathbf{d_{j,i}}), \angle(\mathbf{n}_j, \mathbf{d_{j,i}}), \angle(\mathbf{n}_i, \mathbf{n}_j) ) \right)\]

where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote neural networks, i.e. MLPs, and \(h_{\mathbf{\Theta}}\) takes in node features together with the point pair features computed as in torch_geometric.transforms.PointPairFeatures.

Parameters
  • local_nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x concatenated with the four point pair features of shape [-1, in_channels + 4] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential. (default: None)

  • global_nn (torch.nn.Module, optional) – A neural network \(\gamma_{\mathbf{\Theta}}\) that maps aggregated node features of shape [-1, out_channels] to shape [-1, final_out_channels], e.g., defined by torch.nn.Sequential. (default: None)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, None, Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], pos: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], normal: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor]) → torch.Tensor[source]
reset_parameters()[source]
class PointConv(local_nn: Optional[Callable] = None, global_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs)[source]

The PointNet set layer from the “PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation” and “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” papers

\[\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \mathbf{p}_j - \mathbf{p}_i) \right),\]

where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote neural networks, i.e. MLPs, and \(\mathbf{P} \in \mathbb{R}^{N \times D}\) defines the position of each point.
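
The max-pooled point update can be sketched in plain Python. Positions and features are lists of floats, and edges are (source j, target i) pairs. Each node is always included in its own neighborhood, which mirrors the N(i) ∪ {i} term. The max is taken element-wise. The names point_conv and local_nn are hypothetical, and local_nn simply concatenates x_j with p_j - p_i instead of running an MLP.

```python
def point_conv(x, pos, edges, h, gamma):
    """x'_i = gamma(element-wise max over j in N(i) | {i} of h(x_j, p_j - p_i))."""
    n = len(x)
    neighbors = [[i] for i in range(n)]  # j = i is always included (self-loop)
    for j, i in edges:
        if j != i:
            neighbors[i].append(j)
    out = []
    for i in range(n):
        msgs = [h(x[j], [pj - pi for pj, pi in zip(pos[j], pos[i])]) for j in neighbors[i]]
        pooled = [max(column) for column in zip(*msgs)]  # element-wise max over neighbors
        out.append(gamma(pooled))
    return out


def local_nn(x_j, rel):
    # Hypothetical h_Theta: concatenate features and relative position
    return x_j + rel


pc_out = point_conv([[1.0], [5.0]], [[0.0], [2.0]], [(0, 1)], local_nn, lambda v: v)
```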

Parameters
  • local_nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x and relative spatial coordinates pos_j - pos_i of shape [-1, in_channels + num_dimensions] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential. (default: None)

  • global_nn (torch.nn.Module, optional) – A neural network \(\gamma_{\mathbf{\Theta}}\) that maps aggregated node features of shape [-1, out_channels] to shape [-1, final_out_channels], e.g., defined by torch.nn.Sequential. (default: None)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, None, Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], pos: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor]) → torch.Tensor[source]
reset_parameters()[source]
class RGCNConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, num_relations: int, num_bases: Optional[int] = None, num_blocks: Optional[int] = None, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The relational graph convolutional operator from the “Modeling Relational Data with Graph Convolutional Networks” paper

\[\mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,\]

where \(\mathcal{R}\) denotes the set of relations, i.e. edge types. The edge types need to be given as a one-dimensional torch.long tensor which stores a relation identifier \(\in \{ 0, \ldots, |\mathcal{R}| - 1\}\) for each edge.
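
The relation-wise normalization can be sketched with scalar features. Each Θ_r becomes the scalar theta_rel[r], and Θ_root becomes theta_root. There is no basis or block decomposition. Each neighbor sum is divided by the neighbor count within that relation, as the 1/|N_r(i)| factor states. The function name rgcn_conv is hypothetical.

```python
def rgcn_conv(x, edges, edge_type, num_relations, theta_root, theta_rel):
    """x'_i = theta_root*x_i + sum_r sum_{j in N_r(i)} theta_rel[r] * x_j / |N_r(i)|."""
    n = len(x)
    sums = [[0.0] * num_relations for _ in range(n)]
    counts = [[0] * num_relations for _ in range(n)]
    for (j, i), r in zip(edges, edge_type):  # group incoming messages by relation
        sums[i][r] += x[j]
        counts[i][r] += 1
    out = []
    for i in range(n):
        v = theta_root * x[i]
        for r in range(num_relations):
            if counts[i][r]:  # relations without neighbors contribute nothing
                v += theta_rel[r] * sums[i][r] / counts[i][r]
        out.append(v)
    return out


rg_out = rgcn_conv([1.0, 3.0, 10.0], [(0, 2), (1, 2), (0, 1)], [0, 0, 1], 2, 1.0, [1.0, 2.0])
```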

Note

This implementation is as memory-efficient as possible by iterating over each individual relation type. Therefore, it may result in low GPU utilization in case the graph has a large number of relations. As an alternative approach, FastRGCNConv does not iterate over each individual type, but may consume a large amount of memory to compensate. We advise checking out both implementations to see which one fits your needs.

Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities. In case no input features are given, this argument should correspond to the number of nodes in your graph.

  • out_channels (int) – Size of each output sample.

  • num_relations (int) – Number of relations.

  • num_bases (int, optional) – If not None, this layer will use the basis-decomposition regularization scheme where num_bases denotes the number of bases to use. (default: None)

  • num_blocks (int, optional) – If not None, this layer will use the block-diagonal-decomposition regularization scheme where num_blocks denotes the number of blocks to use. (default: None)

  • aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "mean")

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, None, Tuple[Optional[torch.Tensor], torch.Tensor]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_type: Optional[torch.Tensor] = None)[source]
Parameters
  • x – The input node features. Can be either a [num_nodes, in_channels] node feature matrix, or an optional one-dimensional node index tensor (in which case input features are treated as trainable node embeddings). Furthermore, x can be of type tuple denoting source and destination node features.

  • edge_type – The one-dimensional relation type/index for each edge in edge_index. Should only be None in case edge_index is of type torch_sparse.tensor.SparseTensor. (default: None)

reset_parameters()[source]
class SAGEConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, normalize: bool = False, bias: bool = True, **kwargs)[source]

The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper

\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j\]
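
The mean aggregation and the optional ℓ2-normalization can be sketched without a framework. Isolated nodes are given a zero mean, which is an assumption of this sketch. W_1 and W_2 are nested lists. The names matvec and sage_conv are hypothetical.

```python
import math


def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in W]


def sage_conv(x, edges, W1, W2, normalize=False):
    """x'_i = W1 x_i + W2 mean_{j in N(i)} x_j, optionally l2-normalized."""
    n, f = len(x), len(x[0])
    total = [[0.0] * f for _ in range(n)]
    count = [0] * n
    for j, i in edges:  # accumulate source features at each target
        count[i] += 1
        total[i] = [a + b for a, b in zip(total[i], x[j])]
    out = []
    for i in range(n):
        mean = [t / count[i] for t in total[i]] if count[i] else [0.0] * f
        v = [a + b for a, b in zip(matvec(W1, x[i]), matvec(W2, mean))]
        if normalize:
            norm = math.sqrt(sum(c * c for c in v))
            v = [c / norm for c in v] if norm > 0 else v
        out.append(v)
    return out


x = [[1.0, 0.0], [0.0, 1.0], [3.0, 4.0]]
edges = [(0, 2), (1, 2)]
eye = [[1.0, 0.0], [0.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
mean_only = sage_conv(x, edges, zeros, eye)                       # node 2 -> [0.5, 0.5]
root_normalized = sage_conv(x, edges, eye, zeros, normalize=True)  # node 2 -> [0.6, 0.8]
```
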
Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • normalize (bool, optional) – If set to True, output features will be \(\ell_2\)-normalized, i.e., \(\frac{\mathbf{x}^{\prime}_i} {\| \mathbf{x}^{\prime}_i \|_2}\). (default: False)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class SGConv(in_channels: int, out_channels: int, K: int = 1, cached: bool = False, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

The simple graph convolutional operator from the “Simplifying Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta},\]

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix.
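
The parameter-free part of this operator can be computed directly for scalar node features. This part is \({\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X}\), which cached=True stores. The learned Θ would be applied afterwards. The sketch below builds a dense matrix and is not meant for large graphs, and the function name sgc_propagate is hypothetical.

```python
import math


def sgc_propagate(x, edges, num_nodes, K):
    """Return (D^-1/2 A_hat D^-1/2)^K X for scalar node features, with A_hat = A + I."""
    A = [[0.0] * num_nodes for _ in range(num_nodes)]
    for j, i in edges:
        A[i][j] = 1.0  # row i collects from source node j
    for i in range(num_nodes):
        A[i][i] = 1.0  # inserted self-loops
    deg = [sum(row) for row in A]  # D_hat_ii = sum_j A_hat_ij, always >= 1
    S = [[A[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(num_nodes)]
         for i in range(num_nodes)]
    for _ in range(K):  # K propagation steps
        x = [sum(S[i][j] * x[j] for j in range(num_nodes)) for i in range(num_nodes)]
    return x


# Undirected two-node graph, so both edge directions are listed.
smoothed = sgc_propagate([1.0, 3.0], [(0, 1), (1, 0)], num_nodes=2, K=2)
```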

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • K (int, optional) – Number of hops \(K\). (default: 1)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \({\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None) → torch.Tensor[source]
reset_parameters()[source]
class SignedConv(in_channels: int, out_channels: int, first_aggr: bool, bias: bool = True, **kwargs)[source]

The signed graph convolutional operator from the “Signed Graph Convolutional Network” paper

\[ \begin{align}\begin{aligned}\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w , \mathbf{x}_v \right]\\\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})} \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w , \mathbf{x}_v \right]\end{aligned}\end{align} \]

if first_aggr is set to True, and

\[ \begin{align}\begin{aligned}\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})}, \mathbf{x}_v^{(\textrm{pos})} \right]\\\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})}, \mathbf{x}_v^{(\textrm{neg})} \right]\end{aligned}\end{align} \]

otherwise. In case first_aggr is False, the layer expects x to be a tensor where x[:, :in_channels] denotes the positive node features \(\mathbf{X}^{(\textrm{pos})}\) and x[:, in_channels:] denotes the negative node features \(\mathbf{X}^{(\textrm{neg})}\).
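
The first_aggr=True case can be sketched with scalar features. In this sketch, theta_pos and theta_neg are length-2 weight vectors applied to the concatenation [neighbor mean, x_v], and nodes without positive or negative neighbors are given a zero mean. Both are assumptions made here. The function names are hypothetical.

```python
def mean_over(x, edges, n):
    """Mean of source features per target node; 0.0 for nodes without such edges."""
    total, count = [0.0] * n, [0] * n
    for j, i in edges:
        total[i] += x[j]
        count[i] += 1
    return [t / c if c else 0.0 for t, c in zip(total, count)]


def signed_conv_first(x, pos_edges, neg_edges, theta_pos, theta_neg):
    """first_aggr=True with scalar features.

    theta_pos / theta_neg are length-2 weight vectors applied to the
    concatenation [mean over N+(v) or N-(v), x_v].
    """
    n = len(x)
    pos_mean = mean_over(x, pos_edges, n)
    neg_mean = mean_over(x, neg_edges, n)
    x_pos = [theta_pos[0] * m + theta_pos[1] * xv for m, xv in zip(pos_mean, x)]
    x_neg = [theta_neg[0] * m + theta_neg[1] * xv for m, xv in zip(neg_mean, x)]
    return x_pos, x_neg


x_pos, x_neg = signed_conv_first([2.0, 4.0, 6.0], [(0, 2)], [(1, 2)], [1.0, 1.0], [1.0, 1.0])
```

Placing x_pos and x_neg side by side for each node gives the [positive | negative] feature layout that a subsequent layer with first_aggr=False expects.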

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • first_aggr (bool) – Denotes which aggregation formula to use.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], pos_edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], neg_edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor])[source]
reset_parameters()[source]
class SplineConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, dim: int, kernel_size: Union[int, List[int]], is_open_spline: bool = True, degree: int = 1, aggr: str = 'mean', root_weight: bool = True, bias: bool = True, **kwargs)[source]

The spline-based convolutional operator from the “SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels” paper

\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),\]

where \(h_{\mathbf{\Theta}}\) denotes a kernel function defined over the weighted B-Spline tensor product basis.

Note

Pseudo-coordinates must lie in the fixed interval \([0, 1]\) for this method to work as intended.

Parameters
  • in_channels (int or tuple) – Size of each input sample. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • dim (int) – Pseudo-coordinate dimensionality.

  • kernel_size (int or [int]) – Size of the convolving kernel.

  • is_open_spline (bool or [bool], optional) – If set to False, the operator will use a closed B-spline basis in this dimension. (default: True)

  • degree (int, optional) – B-spline basis degrees. (default: 1)

  • aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "mean")

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]], edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_attr: Optional[torch.Tensor] = None, size: Optional[Tuple[int, int]] = None) → torch.Tensor[source]
reset_parameters()[source]
class TAGConv(in_channels: int, out_channels: int, K: int = 3, bias: bool = True, normalize: bool = True, **kwargs)[source]

The topology adaptive graph convolutional networks operator from the “Topology Adaptive Graph Convolutional Networks” paper

\[\mathbf{X}^{\prime} = \sum_{k=0}^K {\left(\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\right)}^k \mathbf{X} \mathbf{\Theta}_{k},\]

where \(\mathbf{A}\) denotes the adjacency matrix and \(D_{ii} = \sum_{j=0} A_{ij}\) its diagonal degree matrix.
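
A scalar-feature sketch of the multi-hop sum follows. It assumes that the k-th term applies the symmetrically normalized adjacency D^{-1/2} A D^{-1/2} k times, without self-loops, and that each Θ_k is a scalar. This is a dense illustration, not the PyG implementation, and the function name tag_conv is hypothetical.

```python
import math


def tag_conv(x, edges, num_nodes, thetas):
    """X' = sum_{k=0}^{K} (D^-1/2 A D^-1/2)^k X theta_k for scalar features, K = len(thetas) - 1."""
    A = [[0.0] * num_nodes for _ in range(num_nodes)]
    for j, i in edges:
        A[i][j] = 1.0  # no self-loops are added here
    deg = [sum(row) for row in A]
    S = [[A[i][j] / math.sqrt(deg[i] * deg[j]) if deg[i] * deg[j] > 0 else 0.0
          for j in range(num_nodes)] for i in range(num_nodes)]
    out = [thetas[0] * v for v in x]  # k = 0: the features themselves
    h = list(x)
    for theta in thetas[1:]:
        h = [sum(S[i][j] * h[j] for j in range(num_nodes)) for i in range(num_nodes)]  # one more hop
        out = [o + theta * v for o, v in zip(out, h)]
    return out


tag_out = tag_conv([1.0, 3.0], [(0, 1), (1, 0)], num_nodes=2, thetas=[1.0, 10.0, 100.0])
```

On this two-node graph, each hop swaps the two values, so node 0 gets 1 + 10·3 + 100·1 = 131.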

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • K (int, optional) – Number of hops \(K\). (default: 3)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • normalize (bool, optional) – Whether to apply symmetric normalization. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor], edge_weight: Optional[torch.Tensor] = None) → torch.Tensor[source]
reset_parameters()[source]
class XConv(in_channels: int, out_channels: int, dim: int, kernel_size: int, hidden_channels: Optional[int] = None, dilation: int = 1, bias: bool = True, num_workers: int = 1)[source]

The convolutional operator on \(\mathcal{X}\)-transformed points from the “PointCNN: Convolution On X-Transformed Points” paper

\[\mathbf{x}^{\prime}_i = \mathrm{Conv}\left(\mathbf{K}, \gamma_{\mathbf{\Theta}}(\mathbf{P}_i - \mathbf{p}_i) \times \left( h_\mathbf{\Theta}(\mathbf{P}_i - \mathbf{p}_i) \, \Vert \, \mathbf{x}_i \right) \right),\]

where \(\mathbf{K}\) and \(\mathbf{P}_i\) denote the trainable filter and neighboring point positions of \(\mathbf{x}_i\), respectively. \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) describe neural networks, i.e. MLPs, where \(h_{\mathbf{\Theta}}\) individually lifts each point into a higher-dimensional space, and \(\gamma_{\mathbf{\Theta}}\) computes the \(\mathcal{X}\)-transformation matrix based on all points in a neighborhood.

Parameters
  • in_channels (int) – Size of each input sample.

  • out_channels (int) – Size of each output sample.

  • dim (int) – Point cloud dimensionality.

  • kernel_size (int) – Size of the convolving kernel, i.e. number of neighbors including self-loops.

  • hidden_channels (int, optional) – Output size of \(h_{\mathbf{\Theta}}\), i.e. dimensionality of lifted points. If set to None, will be automatically set to in_channels / 4. (default: None)

  • dilation (int, optional) – The factor by which the neighborhood is extended, from which kernel_size neighbors are then uniformly sampled. Can be interpreted as the dilation rate of classical convolutional operators. (default: 1)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • num_workers (int) – Number of workers to use for k-NN computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

forward(x: torch.Tensor, pos: torch.Tensor, batch: Optional[torch.Tensor] = None)[source]
reset_parameters()[source]
class MetaLayer(edge_model=None, node_model=None, global_model=None)[source]

A meta layer for building any kind of graph network, inspired by the “Relational Inductive Biases, Deep Learning, and Graph Networks” paper.

A graph network takes a graph as input and returns an updated graph as output (with the same connectivity). The input graph has node features x, edge features edge_attr as well as global-level features u. The output graph has the same structure, but updated features.

Edge features, node features as well as global features are updated by calling the modules edge_model, node_model and global_model, respectively.

To allow for batch-wise graph processing, all callable functions take an additional argument batch, which determines the assignment of edges or nodes to their specific graphs.

Parameters
  • edge_model (Module, optional) – A callable which updates a graph’s edge features based on its source and target node features, its current edge features and its global features. (default: None)

  • node_model (Module, optional) – A callable which updates a graph’s node features based on its current node features, its graph connectivity, its edge features and its global features. (default: None)

  • global_model (Module, optional) – A callable which updates a graph’s global features based on its node features, its graph connectivity, its edge features and its current global features. (default: None)

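import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer

# Note: each Lin(..., ...) below is a placeholder; replace the ...
# with the input/output feature sizes of your own data.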

class EdgeModel(torch.nn.Module):
    def __init__(self):
        super(EdgeModel, self).__init__()
        self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, src, dest, edge_attr, u, batch):
        # src, dest: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dest, edge_attr, u[batch]], 1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self):
        super(NodeModel, self).__init__()
        self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
        self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out, u[batch]], dim=1)
        return self.node_mlp_2(out)

class GlobalModel(torch.nn.Module):
    def __init__(self):
        super(GlobalModel, self).__init__()
        self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        out = torch.cat([u, scatter_mean(x, batch, dim=0)], dim=1)
        return self.global_mlp(out)

op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
forward(x, edge_index, edge_attr=None, u=None, batch=None)[source]
reset_parameters()[source]

Dense Convolutional Layers

class DenseGCNConv(in_channels, out_channels, improved=False, bias=True)[source]

See torch_geometric.nn.conv.GCNConv.

forward(x, adj, mask=None, add_loop=True)[source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

  • add_loop (bool, optional) – If set to False, the layer will not automatically add self-loops to the adjacency matrices. (default: True)

reset_parameters()[source]
class DenseSAGEConv(in_channels, out_channels, normalize=False, bias=True)[source]

See torch_geometric.nn.conv.SAGEConv.

forward(x, adj, mask=None)[source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

reset_parameters()[source]
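The following sketch covers only the neighborhood mean that a SAGE-style layer computes on a dense adjacency matrix, plus the optional L2 normalization suggested by the normalize argument. How the layer combines this with the root node features and its weights is not shown. The helper name dense_mean_aggregate is hypothetical.

```python
import math

def dense_mean_aggregate(x, adj, normalize=False):
    """Hypothetical helper: mean over neighbors in a dense adjacency matrix.

    Node i receives sum_j adj[i][j] * x[j] divided by its degree
    (clamped to at least 1, so isolated nodes get a zero vector).
    """
    f = len(x[0])
    out = []
    for row in adj:
        deg = max(sum(row), 1)
        agg = [sum(row[j] * x[j][c] for j in range(len(x))) / deg for c in range(f)]
        if normalize:
            # Optional L2 normalization of each output row; zero rows stay zero.
            norm = math.sqrt(sum(v * v for v in agg)) or 1.0
            agg = [v / norm for v in agg]
        out.append(agg)
    return out
```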
class DenseGraphConv(in_channels, out_channels, aggr='add', bias=True)[source]

See torch_geometric.nn.conv.GraphConv.

forward(x, adj, mask=None)[source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

reset_parameters()[source]
class DenseGINConv(nn, eps=0, train_eps=False)[source]

See torch_geometric.nn.conv.GINConv.

Return type

Tensor

forward(x, adj, mask=None, add_loop=True)[source]
Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

  • add_loop (bool, optional) – If set to False, the layer will not automatically add self-loops to the adjacency matrices. (default: True)
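The following is a minimal sketch of a GIN-style dense update, assuming the layer follows the update of torch_geometric.nn.conv.GINConv, \(\mathrm{nn}\big((1 + \epsilon) \, \mathbf{x}_i + \sum_j A_{ij} \mathbf{x}_j\big)\). It also assumes that add_loop=False drops the \((1 + \epsilon)\,\mathbf{x}_i\) root term. The helper name dense_gin_propagate is hypothetical, and the nn argument stands in for the MLP passed to the layer.

```python
def dense_gin_propagate(x, adj, eps=0.0, add_loop=True, nn=lambda v: v):
    """Hypothetical helper: GIN-style update on a dense adjacency matrix.

    Computes nn((1 + eps) * x_i + sum_j adj[i][j] * x_j) for every node i;
    with add_loop=False the (1 + eps) * x_i root term is left out.
    `nn` stands in for the layer's MLP (identity by default).
    """
    f = len(x[0])
    out = []
    for i, row in enumerate(adj):
        # Sum of neighbor features weighted by the adjacency entries.
        agg = [sum(row[j] * x[j][c] for j in range(len(x))) for c in range(f)]
        if add_loop:
            agg = [(1 + eps) * x[i][c] + agg[c] for c in range(f)]
        out.append(nn(agg))
    return out
```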

reset_parameters()[source]

Normalization Layers

class BatchNorm(in_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]

Applies batch normalization over a batch of node features as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension over all nodes inside the mini-batch.

Parameters
  • in_channels (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
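The following plain-Python sketch applies the formula above using batch statistics only, as in training mode. Each channel is normalized with its mean and biased variance over all nodes of the mini-batch. Running statistics and momentum are not modeled, and the function name batch_norm is hypothetical.

```python
import math

def batch_norm(x, eps=1e-5, gamma=None, beta=None):
    """Normalize each feature channel over all nodes in the mini-batch.

    x: list of N node feature rows with F channels. Uses batch statistics
    (biased variance); running statistics are not tracked.
    """
    n, f = len(x), len(x[0])
    gamma = gamma or [1.0] * f  # affine scale, defaults to identity
    beta = beta or [0.0] * f    # affine shift, defaults to zero
    mean = [sum(row[c] for row in x) / n for c in range(f)]
    var = [sum((row[c] - mean[c]) ** 2 for row in x) / n for c in range(f)]
    return [
        [(row[c] - mean[c]) / math.sqrt(var[c] + eps) * gamma[c] + beta[c] for c in range(f)]
        for row in x
    ]
```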

forward(x: torch.Tensor) → torch.Tensor[source]
reset_parameters()[source]
class GraphSizeNorm[source]

Applies Graph Size Normalization over each individual graph in a batch of node features as described in the “Benchmarking Graph Neural Networks” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i}{\sqrt{|\mathcal{V}|}}\]
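In the formula above, \(|\mathcal{V}|\) is the number of nodes in the graph that node \(i\) belongs to. The following sketch reads this from the batch vector and treats a missing batch as a single graph. The function name graph_size_norm is hypothetical.

```python
import math
from collections import Counter

def graph_size_norm(x, batch=None):
    """Divide each node's features by sqrt(|V|) of the graph it belongs to."""
    if batch is None:
        batch = [0] * len(x)  # all nodes belong to one graph
    sizes = Counter(batch)    # number of nodes per graph index
    return [[v / math.sqrt(sizes[b]) for v in row] for row, b in zip(x, batch)]
```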
forward(x, batch=None)[source]
class InstanceNorm(in_channels, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)[source]

Applies instance normalization over each individual example in a batch of node features as described in the “Instance Normalization: The Missing Ingredient for Fast Stylization” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch.

Parameters
  • in_channels (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: False)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses instance statistics in both training and eval modes. (default: False)
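The following sketch shows how instance normalization differs from the batch-wide version. Here the mean and (biased) variance of each channel are computed separately inside every graph identified by the batch vector. Affine parameters and running statistics are omitted, and the function name instance_norm is hypothetical.

```python
import math
from collections import defaultdict

def instance_norm(x, batch, eps=1e-5):
    """Normalize each channel separately within every graph of the batch."""
    groups = defaultdict(list)  # graph index -> node indices
    for i, b in enumerate(batch):
        groups[b].append(i)
    f = len(x[0])
    out = [None] * len(x)
    for idx in groups.values():
        n = len(idx)
        # Per-graph, per-channel statistics.
        mean = [sum(x[i][c] for i in idx) / n for c in range(f)]
        var = [sum((x[i][c] - mean[c]) ** 2 for i in idx) / n for c in range(f)]
        for i in idx:
            out[i] = [(x[i][c] - mean[c]) / math.sqrt(var[c] + eps) for c in range(f)]
    return out
```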

forward(x: torch.Tensor, batch: Optional[torch.Tensor] = None) → torch.Tensor[source]
class LayerNorm(in_channels, eps=1e-05, affine=True)[source]

Applies layer normalization over each individual example in a batch of node features as described in the “Layer Normalization” paper

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard-deviation are calculated across all nodes and all node channels separately for each object in a mini-batch.

Parameters
  • in_channels (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
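In contrast to instance normalization, each graph here shares a single mean and variance across all of its nodes and channels. The following sketch follows the sentence above that describes this, assumes a biased variance, and omits the affine parameters. The function name graph_layer_norm is hypothetical.

```python
import math
from collections import defaultdict

def graph_layer_norm(x, batch, eps=1e-5):
    """Normalize each graph with one mean/variance over all its nodes and channels."""
    groups = defaultdict(list)  # graph index -> node indices
    for i, b in enumerate(batch):
        groups[b].append(i)
    out = [None] * len(x)
    for idx in groups.values():
        values = [v for i in idx for v in x[i]]  # every entry of this graph
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        for i in idx:
            out[i] = [(v - mean) / math.sqrt(var + eps) for v in x[i]]
    return out
```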

forward(x: torch.Tensor, batch: Optional[torch.Tensor] = None) → torch.Tensor[source]
reset_parameters()[source]
class MessageNorm(learn_scale: bool = False)[source]

Applies message normalization over the aggregated messages as described in the “DeeperGCNs: All You Need to Train Deeper GCNs” paper

\[\mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_{i} + s \cdot {\| \mathbf{x}_i \|}_2 \cdot \frac{\mathbf{m}_{i}}{{\|\mathbf{m}_i\|}_2} \right)\]
Parameters

learn_scale (bool, optional) – If set to True, will learn the scaling factor \(s\) of message normalization. (default: False)
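The following sketch computes the argument of the MLP in the formula above for one node. The aggregated message \(\mathbf{m}_i\) is rescaled to the norm of \(\mathbf{x}_i\) times \(s\) and then added to \(\mathbf{x}_i\). The MLP is omitted, the zero-message guard is an assumption, and the function name message_norm_update is hypothetical. It makes no claim about where the layer's forward() stops and the surrounding convolution begins.

```python
import math

def message_norm_update(x_i, m_i, s=1.0):
    """Return x_i + s * ||x_i||_2 * m_i / ||m_i||_2 (the MLP input in the formula)."""
    x_norm = math.sqrt(sum(v * v for v in x_i))
    m_norm = math.sqrt(sum(v * v for v in m_i)) or 1.0  # guard against a zero message
    return [xv + s * x_norm * mv / m_norm for xv, mv in zip(x_i, m_i)]
```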

forward(x: torch.Tensor, msg: torch.Tensor, p: int = 2)[source]
reset_parameters()[source]
class PairNorm(scale: float = 1.0, scale_individually: bool = False, eps: float = 1e-05)[source]

Applies pair normalization over node features as described in the “PairNorm: Tackling Oversmoothing in GNNs” paper

\[ \begin{align}\begin{aligned}\mathbf{x}_i^c &= \mathbf{x}_i - \frac{1}{n} \sum_{j=1}^n \mathbf{x}_j\\\mathbf{x}_i^{\prime} &= s \cdot \frac{\mathbf{x}_i^c}{\sqrt{\frac{1}{n} \sum_{j=1}^n {\| \mathbf{x}_j^c \|}^2_2}}\end{aligned}\end{align} \]
Parameters
  • scale (float, optional) – Scaling factor \(s\) of normalization. (default: 1.0)

  • scale_individually (bool, optional) – If set to True, will compute the scaling step as \(\mathbf{x}^{\prime}_i = s \cdot \frac{\mathbf{x}_i^c}{{\| \mathbf{x}_i^c \|}_2}\). (default: False)

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
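The following sketch applies PairNorm to a single graph (batch=None): it first centers the features and then rescales them, either with the shared root-mean-square norm or per node when scale_individually is set. Where exactly eps enters each denominator is an assumption of this sketch, and the function name pair_norm is hypothetical.

```python
import math

def pair_norm(x, scale=1.0, scale_individually=False, eps=1e-5):
    """PairNorm on one graph: center the node features, then rescale them."""
    n, f = len(x), len(x[0])
    mean = [sum(row[c] for row in x) / n for c in range(f)]
    xc = [[row[c] - mean[c] for c in range(f)] for row in x]  # x_i^c
    sq_norms = [sum(v * v for v in row) for row in xc]        # ||x_i^c||_2^2
    if scale_individually:
        # Each node is scaled by its own norm: s * x_i^c / ||x_i^c||_2.
        return [[scale * v / (math.sqrt(q) + eps) for v in row] for row, q in zip(xc, sq_norms)]
    # Shared denominator: root of the mean squared norm over all nodes.
    denom = math.sqrt(sum(sq_norms) / n + eps)
    return [[scale * v / denom for v in row] for row in xc]
```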

forward(x: torch.Tensor, batch: Optional[torch.Tensor] = None) → torch.Tensor[source]

Global Pooling Layers

class GlobalAttention(gate_nn, nn=None)[source]

Global soft attention layer from the “Gated Graph Sequence Neural Networks” paper

\[\mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left( h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \odot h_{\mathbf{\Theta}} ( \mathbf{x}_n ),\]

where \(h_{\mathrm{gate}} \colon \mathbb{R}^F \to \mathbb{R}\) and \(h_{\mathbf{\Theta}}\) denote neural networks such as MLPs.

Parameters
  • gate_nn (torch.nn.Module) – A neural network \(h_{\mathrm{gate}}\) that computes attention scores by mapping node features x of shape [-1, in_channels] to shape [-1, 1], e.g., defined by torch.nn.Sequential.

  • nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels] before combining them with the attention scores, e.g., defined by torch.nn.Sequential. (default: None)
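The following sketch computes the readout formula from precomputed network outputs. The gate scores are softmax-normalized within each graph, and the result weights the (optionally transformed) node features. The function name global_attention_pool is hypothetical, and the two neural networks are replaced by their outputs.

```python
import math
from collections import defaultdict

def global_attention_pool(gate, h, batch):
    """Attention readout: r_g = sum over nodes n in graph g of softmax(gate)_n * h_n.

    gate: scalar scores h_gate(x_n); h: feature rows h_Theta(x_n)
    (or x_n itself when no nn is given); batch: graph index per node.
    """
    groups = defaultdict(list)  # graph index -> node indices
    for n, b in enumerate(batch):
        groups[b].append(n)
    f = len(h[0])
    out = {}
    for g, idx in groups.items():
        top = max(gate[n] for n in idx)  # subtract the max for numerical stability
        w = [math.exp(gate[n] - top) for n in idx]
        total = sum(w)
        out[g] = [sum(wk / total * h[n][c] for wk, n in zip(w, idx)) for c in range(f)]
    return [out[g] for g in sorted(out)]
```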

forward(x, batch, size=None)[source]
reset_parameters()[source]
class Set2Set(in_channels, processing_steps, num_layers=1)[source]

The global pooling operator based on iterative content-based attention from the “Order Matters: Sequence to sequence for sets” paper

\[ \begin{align}\begin{aligned}\mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1})\\\alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)\\\mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i\\\mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t,\end{aligned}\end{align} \]

where \(\mathbf{q}^{*}_T\) defines the output of the layer with twice the dimensionality of the input.

Parameters
  • in_channels (int) – Size of each input sample.

  • processing_steps (int) – Number of iterations \(T\).

  • num_layers (int, optional) – Number of recurrent layers, e.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. (default: 1)

forward(x, batch)[source]
reset_parameters()[source]
global_add_pool(x, batch, size: Optional[int] = None)[source]

Returns batch-wise graph-level outputs by adding node features across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

\[\mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n\]
Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)

Return type

Tensor

global_max_pool(x, batch, size: Optional[int] = None)[source]

Returns batch-wise graph-level outputs by taking the channel-wise maximum across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

\[\mathbf{r}_i = \mathrm{max}_{n=1}^{N_i} \, \mathbf{x}_n\]
Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)

Return type

Tensor

global_mean_pool(x, batch, size: Optional[int] = None)[source]

Returns batch-wise graph-level outputs by averaging node features across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by

\[\mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n\]
Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)

Return type

Tensor
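The following plain-Python sketch computes the three readouts above (sum, channel-wise maximum, mean) from a batch vector. It assumes every graph index occurs at least once, so the size argument is not needed. The function name global_pools is hypothetical.

```python
from collections import defaultdict

def global_pools(x, batch):
    """Return (add, mean, max) graph-level outputs for node features x.

    batch[n] is the graph index of node n; graphs are returned in index order.
    """
    groups = defaultdict(list)  # graph index -> node feature rows
    for row, b in zip(x, batch):
        groups[b].append(row)
    add, mean, mx = [], [], []
    for g in sorted(groups):
        rows = groups[g]
        cols = list(zip(*rows))  # channel-wise view of this graph's nodes
        add.append([sum(c) for c in cols])
        mean.append([sum(c) / len(rows) for c in cols])
        mx.append([max(c) for c in cols])
    return add, mean, mx
```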

global_sort_pool(x, batch, k)[source]

The global pooling operator from the “An End-to-End Deep Learning Architecture for Graph Classification” paper, where node features are sorted in descending order based on their last feature channel. The first \(k\) nodes form the output of the layer.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • k (int) – The number of nodes to hold for each graph.

Return type

Tensor
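The following sketch shows the sorting rule described above. The behavior for small graphs is an assumption: graphs with fewer than \(k\) nodes are padded with zero rows, and each graph's \(k \times F\) block is flattened into one vector of length \(k \cdot F\). The function name global_sort_pool_sketch is hypothetical.

```python
from collections import defaultdict

def global_sort_pool_sketch(x, batch, k):
    """Sort each graph's nodes by their last channel (descending) and keep k.

    Assumption: graphs with fewer than k nodes are zero-padded, and each
    graph's k x F block is flattened into one vector of length k * F.
    """
    groups = defaultdict(list)  # graph index -> node feature rows
    for row, b in zip(x, batch):
        groups[b].append(row)
    f = len(x[0])
    out = []
    for g in sorted(groups):
        rows = sorted(groups[g], key=lambda r: r[-1], reverse=True)[:k]
        rows += [[0.0] * f] * (k - len(rows))  # zero-padding up to k rows
        out.append([v for r in rows for v in r])
    return out
```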

Pooling Layers

class ASAPooling(in_channels, ratio=0.5, GNN=None, dropout=0, negative_slope=0.2, add_self_loops=False, **kwargs)[source]

The Adaptive Structure Aware Pooling operator from the “ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations” paper.

Parameters
  • in_channels (int) – Size of each input sample.

  • ratio (float, optional) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\). (default: 0.5)

  • GNN (torch.nn.Module, optional) – A graph neural network layer for using intra-cluster properties. Especially helpful for graphs with higher degree of neighborhood (one of torch_geometric.nn.conv.GraphConv, torch_geometric.nn.conv.GCNConv or any GNN which supports the edge_weight parameter). (default: None)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • add_self_loops (bool, optional) – If set to True, will add self loops to the new graph connectivity. (default: False)

  • **kwargs (optional) – Additional parameters for initializing the graph neural network layer.

forward(x, edge_index, edge_weight=None, batch=None)[source]
reset_parameters()[source]
class EdgePooling(in_channels, edge_score_method=None, dropout=0, add_to_edge_score=0.5)[source]

The edge pooling operator from the “Towards Graph Pooling by Edge Contraction” and “Edge Contraction Pooling for Graph Neural Networks” papers.

In short, a score is computed for each edge. Edges are contracted iteratively according to that score unless one of their nodes has already been part of a contracted edge.

To reproduce the configuration from the “Towards Graph Pooling by Edge Contraction” paper, use either EdgePooling.compute_edge_score_softmax() or EdgePooling.compute_edge_score_tanh(), and set add_to_edge_score to 0.

To reproduce the configuration from the “Edge Contraction Pooling for Graph Neural Networks” paper, set dropout to 0.2.

Parameters
  • in_channels (int) – Size of each input sample.

  • edge_score_method (function, optional) – The function to apply to compute the edge score from raw edge scores. By default, this is the softmax over all incoming edges for each node. This function takes in a raw_edge_score tensor of shape [num_edges], an edge_index tensor and the number of nodes num_nodes, and produces a new tensor of the same size as raw_edge_score describing normalized edge scores. Included functions are EdgePooling.compute_edge_score_softmax(), EdgePooling.compute_edge_score_tanh(), and EdgePooling.compute_edge_score_sigmoid(). (default: EdgePooling.compute_edge_score_softmax())

  • dropout (float, optional) – The probability with which to drop edge scores during training. (default: 0)

  • add_to_edge_score (float, optional) – This is added to each computed edge score. Adding this greatly helps with unpool stability. (default: 0.5)
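The default score normalization, "softmax over all incoming edges for each node", can be sketched as a softmax grouped by the target node of each edge. The sketch assumes targets are the second row of edge_index. The function name edge_score_softmax is hypothetical, and add_to_edge_score would be added to its output afterwards.

```python
import math
from collections import defaultdict

def edge_score_softmax(raw_edge_score, edge_index):
    """Softmax of raw edge scores over the incoming edges of each target node.

    edge_index is a pair of lists (sources, targets); edges sharing the same
    target node are normalized together.
    """
    _, targets = edge_index
    groups = defaultdict(list)  # target node -> edge ids
    for e, t in enumerate(targets):
        groups[t].append(e)
    out = [0.0] * len(raw_edge_score)
    for edges in groups.values():
        top = max(raw_edge_score[e] for e in edges)  # for numerical stability
        w = {e: math.exp(raw_edge_score[e] - top) for e in edges}
        total = sum(w.values())
        for e in edges:
            out[e] = w[e] / total
    return out
```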

static compute_edge_score_sigmoid(raw_edge_score, edge_index, num_nodes)[source]
static compute_edge_score_softmax(raw_edge_score, edge_index, num_nodes)[source]
static compute_edge_score_tanh(raw_edge_score, edge_index, num_nodes)[source]
forward(x, edge_index, batch)[source]

Forward computation which computes the raw edge score, normalizes it, and merges the edges.

Parameters
  • x (Tensor) – The node features.

  • edge_index (LongTensor) – The edge indices.

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

Return types:
  • x (Tensor) - The pooled node features.

  • edge_index (LongTensor) - The coarsened edge indices.

  • batch (LongTensor) - The coarsened batch vector.

  • unpool_info (unpool_description) - Information that is consumed by EdgePooling.unpool() for unpooling.

reset_parameters()[source]
unpool(x, unpool_info)[source]

Unpools a previous edge pooling step.

For unpooling, x should have the same shape as the node features produced by this layer’s forward() function. It then returns the unpooled x together with the new edge_index and batch.

Parameters
  • x (Tensor) – The node features.

  • unpool_info (unpool_description) – Information that has been produced by EdgePooling.forward().

Return types:
  • x (Tensor) - The unpooled node features.

  • edge_index (LongTensor) - The new edge indices.

  • batch (LongTensor) - The new batch vector.

unpool_description

alias of UnpoolDescription

class SAGPooling(in_channels, ratio=0.5, GNN=GraphConv, min_score=None, multiplier=1, nonlinearity=torch.tanh, **kwargs)[source]

The self-attention pooling operator from the “Self-Attention Graph Pooling” and “Understanding Attention and Generalization in Graph Neural Networks” papers

if min_score \(\tilde{\alpha}\) is None:

\[ \begin{align}\begin{aligned}\mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})\\\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}\end{aligned}\end{align} \]

if min_score \(\tilde{\alpha}\) is a value in [0, 1]:

\[ \begin{align}\begin{aligned}\mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A}))\\\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},\end{aligned}\end{align} \]

where nodes are dropped based on the learned node scores \(\mathbf{y}\). These projection scores are computed by a graph neural network layer.

Parameters
  • in_channels (int) – Size of each input sample.

  • ratio (float) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\). This value is ignored if min_score is not None. (default: 0.5)

  • GNN (torch.nn.Module, optional) – A graph neural network layer for calculating projection scores (one of torch_geometric.nn.conv.GraphConv, torch_geometric.nn.conv.GCNConv, torch_geometric.nn.conv.GATConv or torch_geometric.nn.conv.SAGEConv). (default: torch_geometric.nn.conv.GraphConv)

  • min_score (float, optional) – Minimal node score \(\tilde{\alpha}\) which is used to compute indices of pooled nodes \(\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}\). When this value is not None, the ratio argument is ignored. (default: None)

  • multiplier (float, optional) – Coefficient by which features get multiplied after pooling. This can be useful for large graphs and when min_score is used. (default: 1)

  • nonlinearity (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.tanh)

  • **kwargs (optional) – Additional parameters for initializing the graph neural network layer.

forward(x, edge_index, edge_attr=None, batch=None, attn=None)[source]
reset_parameters()[source]
class TopKPooling(in_channels, ratio=0.5, min_score=None, multiplier=1, nonlinearity=torch.tanh)[source]

\(\mathrm{top}_k\) pooling operator from the “Graph U-Nets”, “Towards Sparse Hierarchical Graph Classifiers” and “Understanding Attention and Generalization in Graph Neural Networks” papers

if min_score \(\tilde{\alpha}\) is None:

\[ \begin{align}\begin{aligned}\mathbf{y} &= \frac{\mathbf{X}\mathbf{p}}{\| \mathbf{p} \|}\\\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}\end{aligned}\end{align} \]

if min_score \(\tilde{\alpha}\) is a value in [0, 1]:

\[ \begin{align}\begin{aligned}\mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})\\\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},\end{aligned}\end{align} \]

where nodes are dropped based on a learnable projection score \(\mathbf{p}\).

Parameters
  • in_channels (int) – Size of each input sample.

  • ratio (float) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\). This value is ignored if min_score is not None. (default: 0.5)

  • min_score (float, optional) – Minimal node score \(\tilde{\alpha}\) which is used to compute indices of pooled nodes \(\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}\). When this value is not None, the ratio argument is ignored. (default: None)

  • multiplier (float, optional) – Coefficient by which features get multiplied after pooling. This can be useful for large graphs and when min_score is used. (default: 1)

  • nonlinearity (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.tanh)
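The following sketch implements the min_score=None branch of the formulas above on a single graph. Scores are \(y_i = \mathbf{x}_i \cdot \mathbf{p} / \|\mathbf{p}\|\), the \(k = \lceil \mathrm{ratio} \cdot N \rceil\) best nodes are kept, and each kept node is gated by \(\tanh(y_i)\). Filtering the adjacency to \(\mathbf{A}_{\mathbf{i},\mathbf{i}}\) is not shown, and the function name topk_pool is hypothetical.

```python
import math

def topk_pool(x, p, ratio=0.5):
    """Top-k pooling on one graph: score, select, and gate the kept nodes.

    Returns the kept node indices (best first) and their gated features.
    """
    p_norm = math.sqrt(sum(v * v for v in p))
    y = [sum(a * b for a, b in zip(row, p)) / p_norm for row in x]  # projection scores
    k = math.ceil(ratio * len(x))
    perm = sorted(range(len(x)), key=lambda i: y[i], reverse=True)[:k]
    return perm, [[v * math.tanh(y[i]) for v in x[i]] for i in perm]
```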

forward(x, edge_index, edge_attr=None, batch=None, attn=None)[source]
reset_parameters()[source]
avg_pool(cluster, data, transform=None)[source]

Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster. Final node features are defined by the average features of all nodes within the same cluster. See torch_geometric.nn.pool.max_pool() for more details.

Parameters
  • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • data (Data) – Graph data object.

  • transform (callable, optional) – A function/transform that takes in the coarsened and pooled torch_geometric.data.Data object and returns a transformed version. (default: None)

Return type

torch_geometric.data.Data

avg_pool_neighbor_x(data, flow='source_to_target')[source]

Average pools neighboring node features, where each feature in data.x is replaced by the average feature values from the central node and its neighbors.

avg_pool_x(cluster, x, batch, size: Optional[int] = None)[source]

Average pools node features according to the clustering defined in cluster. See torch_geometric.nn.pool.max_pool_x() for more details.

Parameters
  • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – The maximum number of clusters in a single example. (default: None)

Return type

(Tensor, LongTensor) if size is None, else Tensor
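The following sketch shows the averaging step of cluster pooling. It simplifies by ignoring the batch vector, which amounts to assuming that cluster ids are already unique across the mini-batch. Pooled rows are returned in ascending cluster-id order. The function name avg_pool_x_sketch is hypothetical.

```python
from collections import defaultdict

def avg_pool_x_sketch(cluster, x):
    """Average node features per cluster id, in ascending cluster-id order."""
    groups = defaultdict(list)  # cluster id -> feature rows of its nodes
    for c, row in zip(cluster, x):
        groups[c].append(row)
    return [[sum(col) / len(rows) for col in zip(*rows)] for _, rows in sorted(groups.items())]
```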

fps(x, batch=None, ratio=0.5, random_start=True)[source]

A sampling algorithm from the “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” paper, which iteratively samples the point that is most distant from the set of points sampled so far.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • ratio (float, optional) – Sampling ratio. (default: 0.5)

  • random_start (bool, optional) – If set to False, use the first node in \(\mathbf{X}\) as the starting node. (default: True)

Return type

LongTensor

import torch
from torch_geometric.nn import fps

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
index = fps(x, batch, ratio=0.5)
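The following is a plain-Python sketch of greedy farthest point sampling on one point cloud. It starts from a fixed index, which corresponds to random_start=False, and assumes \(\lceil \mathrm{ratio} \cdot N \rceil\) points are drawn. The function name farthest_point_sampling is hypothetical. For the four corner points in the example above, it first picks index 0 and then the opposite corner, index 3.

```python
import math

def farthest_point_sampling(points, ratio=0.5, start=0):
    """Greedy farthest point sampling on a single point cloud.

    Repeatedly picks the point whose distance to the already sampled set
    is largest, starting from index `start`.
    """
    n = len(points)
    k = math.ceil(ratio * n)
    selected = [start]
    # dist[i] = distance from point i to its closest already-selected point
    dist = [math.dist(p, points[start]) for p in points]
    while len(selected) < k:
        nxt = max(range(n), key=lambda i: dist[i])
        selected.append(nxt)
        dist = [min(d, math.dist(p, points[nxt])) for d, p in zip(dist, points)]
    return selected
```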
graclus(edge_index, weight: Optional[torch.Tensor] = None, num_nodes: Optional[int] = None)[source]

A greedy clustering algorithm from the “Weighted Graph Cuts without Eigenvectors: A Multilevel Approach” paper, which picks an unmarked vertex and matches it with the unmarked neighbor that maximizes the edge weight. The GPU algorithm is adapted from the “A GPU Algorithm for Greedy Graph Matching” paper.

Parameters
  • edge_index (LongTensor) – The edge indices.

  • weight (Tensor, optional) – One-dimensional edge weights. (default: None)

  • num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)

Return type

LongTensor
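The following simplified sketch illustrates the greedy matching idea. Nodes are visited in index order, which is an assumption; the real implementation may use a different or randomized order. Each unmatched node is paired with its unmatched neighbor of largest edge weight, and both receive the visiting node's index as cluster id. The function name greedy_matching and its edge-dictionary input are hypothetical.

```python
def greedy_matching(num_nodes, edges):
    """Greedy graph matching in the spirit of graclus (simplified sketch).

    edges: dict mapping (u, v) -> weight of an undirected edge.
    Unmatched nodes without a free neighbor stay singleton clusters.
    """
    neighbors = {i: [] for i in range(num_nodes)}
    for (u, v), w in edges.items():
        neighbors[u].append((w, v))
        neighbors[v].append((w, u))
    cluster = [-1] * num_nodes  # -1 marks an unmatched node
    for u in range(num_nodes):
        if cluster[u] != -1:
            continue
        cluster[u] = u  # the visiting node opens its own cluster
        free = [(w, v) for w, v in neighbors[u] if cluster[v] == -1]
        if free:
            _, v = max(free)  # heaviest edge to an unmatched neighbor
            cluster[v] = u
    return cluster
```

Because the visiting order is fixed, the heavy edge between nodes 1 and 2 in a path 0-1-2-3 can be skipped when node 0 claims node 1 first. This is why the matching is only greedy, not optimal.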

knn(x, y, k, batch_x=None, batch_y=None, cosine=False, num_workers=1)[source]

Finds for each element in y the k nearest points in x.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • y (Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).

  • k (int) – The number of neighbors.

  • batch_x (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • batch_y (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node to a specific example. (default: None)

  • cosine (boolean, optional) – If True, will use the cosine distance instead of euclidean distance to find nearest neighbors. (default: False)

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch_x or batch_y is not None, or the input lies on the GPU. (default: 1)

Return type

LongTensor

import torch
from torch_geometric.nn import knn

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)
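The following brute-force sketch shows the Euclidean case of the search for a single example without batch vectors. It assumes the returned assignment lists query indices (into y) in the first row and matched point indices (into x) in the second. The function name knn_brute_force is hypothetical. With the points from the example above, each query at \((\pm 1, 0)\) is matched to the two corners on its own side.

```python
import math

def knn_brute_force(x, y, k):
    """For each query y[i], return its k nearest points x[j] as [rows, cols].

    rows holds query indices into y, cols the matched indices into x.
    """
    rows, cols = [], []
    for i, q in enumerate(y):
        # Sort all candidate points by distance to the query (O(N log N) each).
        nearest = sorted(range(len(x)), key=lambda j: math.dist(q, x[j]))[:k]
        rows += [i] * len(nearest)
        cols += nearest
    return [rows, cols]
```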
knn_graph(x, k, batch=None, loop=False, flow='source_to_target', cosine=False, num_workers=1)[source]

Computes graph edges to the nearest k points.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • k (int) – The number of neighbors.

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • loop (bool, optional) – If True, the graph will contain self-loops. (default: False)

  • flow (string, optional) – The flow direction when using in combination with message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

  • cosine (boolean, optional) – If True, will use the cosine distance instead of euclidean distance to find nearest neighbors. (default: False)

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

Return type

LongTensor

import torch
from torch_geometric.nn import knn_graph

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch = torch.tensor([0, 0, 0, 0])
edge_index = knn_graph(x, k=2, batch=batch, loop=False)
max_pool(cluster, data, transform=None)[source]

Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster. All nodes within the same cluster will be represented as one node. Final node features are defined by the maximum features of all nodes within the same cluster, node positions are averaged and edge indices are defined to be the union of the edge indices of all nodes within the same cluster.

Parameters
  • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • data (Data) – Graph data object.

  • transform (callable, optional) – A function/transform that takes in the coarsened and pooled torch_geometric.data.Data object and returns a transformed version. (default: None)

Return type

torch_geometric.data.Data

max_pool_neighbor_x(data, flow='source_to_target')[source]

Max pools neighboring node features, where each feature in data.x is replaced by its maximum value over the central node and its neighbors.

max_pool_x(cluster, x, batch, size: Optional[int] = None)[source]

Max pools node features according to the clustering defined in cluster.

Parameters
  • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – The maximum number of clusters in a single example. This property is useful to obtain a batch-wise dense representation, e.g. for applying FC layers, but should only be used if the size of the maximum number of clusters per example is known in advance. (default: None)

Return type

(Tensor, LongTensor) if size is None, else Tensor

nearest(x, y, batch_x=None, batch_y=None)[source]

Clusters the points in x by assigning each of them to its nearest query point in y.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • y (Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).

  • batch_x (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • batch_y (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node to a specific example. (default: None)

Return type

LongTensor

import torch
from torch_geometric.nn import nearest

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
cluster = nearest(x, y, batch_x, batch_y)
radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32, num_workers=1)[source]

Finds for each element in y all points in x within distance r.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • y (Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).

  • r (float) – The radius.

  • batch_x (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • batch_y (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node to a specific example. (default: None)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to return for each element in y. (default: 32)

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch_x or batch_y is not None, or the input lies on the GPU. (default: 1)

Return type

LongTensor

import torch
from torch_geometric.nn import radius

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
assign_index = radius(x, y, 1.5, batch_x, batch_y)
radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32, flow='source_to_target', num_workers=1)[source]

Computes graph edges to all points within a given distance.

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • r (float) – The radius.

  • batch (LongTensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • loop (bool, optional) – If True, the graph will contain self-loops. (default: False)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to return for each node. (default: 32)

  • flow (string, optional) – The flow direction when using in combination with message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch is not None, or the input lies on the GPU. (default: 1)

Return type

LongTensor

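import torch
from torch_geometric.nn import radius_graph

# Four nodes at the corners of a square with side length 2.
x = torch.tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.float)
batch = torch.tensor([0, 0, 0, 0])
# edge_index has shape [2, num_edges]. All corners are at least 2 apart, so
# r=1.5 yields no edges here, while r=2.5 would connect adjacent corners.
edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
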
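
A brute-force sketch of radius_graph() makes the roles of loop and flow explicit. The helper radius_graph_sketch is hypothetical; it assumes that with "source_to_target" the first row holds the neighbor j and the second row the center node i, so that messages travel from j to i, and that distances equal to r count as neighbors.

```python
import math
from typing import List, Optional, Sequence


def radius_graph_sketch(
    x: Sequence[Sequence[float]],
    r: float,
    batch: Optional[Sequence[int]] = None,
    loop: bool = False,
    max_num_neighbors: int = 32,
    flow: str = "source_to_target",
) -> List[List[int]]:
    """Brute-force radius graph (hypothetical helper)."""
    b = list(batch) if batch is not None else [0] * len(x)
    src: List[int] = []
    dst: List[int] = []
    for i, p in enumerate(x):  # i: the center node
        found = 0
        for j, q in enumerate(x):  # j: a candidate neighbor
            if b[i] != b[j] or (i == j and not loop):
                continue
            if math.dist(p, q) <= r:
                if flow == "source_to_target":
                    # Messages travel from neighbor j to center i.
                    src.append(j)
                    dst.append(i)
                else:
                    src.append(i)
                    dst.append(j)
                found += 1
                if found == max_num_neighbors:
                    break
    return [src, dst]
```

Because distance is symmetric, both flow settings produce the same set of node pairs; only the row in which each endpoint appears changes.

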
voxel_grid(pos, batch, size, start=None, end=None)[source]

Voxel grid pooling as used, e.g., in the “Dynamic Edge-Conditioned Filters in Convolutional Networks on Graphs” paper. It overlays a regular grid of user-defined size over a point cloud and clusters all points that fall within the same voxel.

Parameters
  • pos (Tensor) – Node position matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}\).

  • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (float or [float] or Tensor) – Size of a voxel (in each dimension).

  • start (float or [float] or Tensor, optional) – Start coordinates of the grid (in each dimension). If set to None, will be set to the minimum coordinates found in pos. (default: None)

  • end (float or [float] or Tensor, optional) – End coordinates of the grid (in each dimension). If set to None, will be set to the maximum coordinates found in pos. (default: None)

Return type

LongTensor
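
The voxel assignment can be sketched in pure Python: each point gets the integer cell floor((pos - start) / size) along every axis, and the example index from batch is folded in so that points from different examples never share a cluster. The helper voxel_grid_sketch is hypothetical, Tensor-valued arguments are not covered, and the row-major numbering is an assumption chosen for illustration. The library may number voxels differently; only the grouping matters.

```python
import math
from typing import List, Sequence


def _per_dim(value, dim: int) -> List[float]:
    # Broadcast a scalar to every dimension, mirroring "float or [float]".
    if isinstance(value, (list, tuple)):
        return [float(v) for v in value]
    return [float(value)] * dim


def voxel_grid_sketch(pos: Sequence[Sequence[float]], batch: Sequence[int],
                      size, start=None, end=None) -> List[int]:
    dim = len(pos[0])
    size_d = _per_dim(size, dim)
    start_d = (_per_dim(start, dim) if start is not None
               else [min(p[d] for p in pos) for d in range(dim)])
    end_d = (_per_dim(end, dim) if end is not None
             else [max(p[d] for p in pos) for d in range(dim)])
    # Number of voxels along each axis of the grid.
    counts = [int(math.floor((end_d[d] - start_d[d]) / size_d[d])) + 1 for d in range(dim)]
    clusters = []
    for p, b in zip(pos, batch):
        idx = b  # the example index is the slowest-varying coordinate
        for d in range(dim):
            cell = int(math.floor((p[d] - start_d[d]) / size_d[d]))
            idx = idx * counts[d] + cell
        clusters.append(idx)
    return clusters
```
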

Dense Pooling Layers

dense_diff_pool(x, adj, s, mask=None)[source]

Differentiable pooling operator from the “Hierarchical Graph Representation Learning with Differentiable Pooling” paper

\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]

based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns pooled node feature matrix, coarsened adjacency matrix and two auxiliary objectives: (1) The link prediction loss

\[\mathcal{L}_{LP} = {\| \mathbf{A} - \mathrm{softmax}(\mathbf{S}) {\mathrm{softmax}(\mathbf{S})}^{\top} \|}_F,\]

and (2) the entropy regularization

\[\mathcal{L}_E = \frac{1}{N} \sum_{n=1}^N H(\mathbf{S}_n).\]

Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).

  • s (Tensor) – Assignment tensor \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) with number of clusters \(C\). The softmax does not have to be applied beforehand, since it is executed within this method.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

Return type

(Tensor, Tensor, Tensor, Tensor)
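
For a single graph (no batch dimension, no mask), the four returned quantities can be computed directly from the formulas above with plain nested lists. This sketch follows the formulas exactly as written; the helper names are hypothetical, and the library may additionally normalize the link prediction loss or average over the batch, so treat the loss scale here as an assumption.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def softmax_rows(s: Matrix) -> Matrix:
    out = []
    for row in s:
        m = max(row)  # subtract the row maximum for numerical stability
        e = [math.exp(v - m) for v in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def diff_pool_sketch(x: Matrix, adj: Matrix, s: Matrix) -> Tuple[Matrix, Matrix, float, float]:
    s = softmax_rows(s)                   # softmax(S), shape N x C
    st = transpose(s)                     # softmax(S)^T, shape C x N
    x_out = matmul(st, x)                 # X' = softmax(S)^T X
    adj_out = matmul(matmul(st, adj), s)  # A' = softmax(S)^T A softmax(S)
    sst = matmul(s, st)
    n = len(adj)
    # Frobenius norm of A - softmax(S) softmax(S)^T.
    link_loss = math.sqrt(sum((adj[i][j] - sst[i][j]) ** 2 for i in range(n) for j in range(n)))
    # Mean row entropy H(S_n); the 1e-15 guards against log(0).
    ent_loss = sum(-sum(p * math.log(p + 1e-15) for p in row) for row in s) / n
    return x_out, adj_out, link_loss, ent_loss
```

With all-zero logits every node is split evenly between the clusters, so each pooled feature is the average node feature scaled by the cluster share and the entropy reaches its maximum, log C.
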

dense_mincut_pool(x, adj, s, mask=None)[source]

MinCut pooling operator from the “Mincut Pooling in Graph Neural Networks” paper

\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]

based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns pooled node feature matrix, coarsened symmetrically normalized adjacency matrix and two auxiliary objectives: (1) The minCUT loss

\[\mathcal{L}_c = - \frac{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{A} \mathbf{S})} {\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{D} \mathbf{S})}\]

where \(\mathbf{D}\) is the degree matrix, and (2) the orthogonality loss

\[\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F.\]

Parameters
  • x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (Tensor) – Symmetrically normalized adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).

  • s (Tensor) – Assignment tensor \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) with number of clusters \(C\). The softmax does not have to be applied beforehand, since it is executed within this method.

  • mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

Return type

(Tensor, Tensor, Tensor, Tensor)
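
The two auxiliary objectives can be checked on a toy graph with the sketch below, which applies the softmax to S and then evaluates the minCUT and orthogonality formulas literally for a single dense graph. Batching and masking are omitted, and all helper names are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def softmax_rows(s: Matrix) -> Matrix:
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def trace(a: Matrix) -> float:
    return sum(a[i][i] for i in range(len(a)))


def frobenius(a: Matrix) -> float:
    return math.sqrt(sum(v * v for row in a for v in row))


def mincut_losses_sketch(adj: Matrix, s: Matrix) -> Tuple[float, float]:
    s = softmax_rows(s)
    st = transpose(s)
    n, c = len(s), len(s[0])
    # Degree matrix D: row sums of A on the diagonal.
    deg = [[sum(adj[i]) if i == j else 0.0 for j in range(n)] for i in range(n)]
    cut_loss = -trace(matmul(matmul(st, adj), s)) / trace(matmul(matmul(st, deg), s))
    sts = matmul(st, s)
    norm = frobenius(sts)
    ortho_loss = frobenius([[sts[i][j] / norm - (1.0 if i == j else 0.0) / math.sqrt(c)
                             for j in range(c)] for i in range(c)])
    return cut_loss, ortho_loss
```

For two connected nodes, placing them in separate clusters cuts their edge (minCUT loss near 0) but keeps the clusters balanced (orthogonality loss near 0), whereas merging them into one cluster reaches the optimal minCUT loss of -1 at the cost of a large orthogonality penalty.
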

Unpooling Layers

knn_interpolate(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3, num_workers=1)[source]

The k-NN interpolation from the “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” paper. For each point \(y\) with position \(\mathbf{p}(y)\), its interpolated features \(\mathbf{f}(y)\) are given by

\[\mathbf{f}(y) = \frac{\sum_{i=1}^k w(x_i) \mathbf{f}(x_i)}{\sum_{i=1}^k w(x_i)} \textrm{, where } w(x_i) = \frac{1}{d(\mathbf{p}(y), \mathbf{p}(x_i))^2}\]

and \(\{ x_1, \ldots, x_k \}\) denoting the \(k\) nearest points to \(y\).

Parameters
  • x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • pos_x (Tensor) – Node position matrix \(\in \mathbb{R}^{N \times d}\).

  • pos_y (Tensor) – Upsampled node position matrix \(\in \mathbb{R}^{M \times d}\).

  • batch_x (LongTensor, optional) – Batch vector \(\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node from \(\mathbf{X}\) to a specific example. (default: None)

  • batch_y (LongTensor, optional) – Batch vector \(\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node from \(\mathbf{Y}\) to a specific example. (default: None)

  • k (int, optional) – Number of neighbors. (default: 3)

  • num_workers (int) – Number of workers to use for computation. Has no effect in case batch_x or batch_y is not None, or the input lies on the GPU. (default: 1)
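
The interpolation formula can be traced by hand with the sketch below. It finds the k nearest source points by brute force and averages their features with weights 1/d². The helper knn_interpolate_sketch is hypothetical, ignores batch vectors, and clamps the squared distance at 1e-16 so that a query coinciding with a source point does not divide by zero (the clamp value is an assumption of the sketch).

```python
from typing import List, Sequence


def knn_interpolate_sketch(x: Sequence[Sequence[float]],
                           pos_x: Sequence[Sequence[float]],
                           pos_y: Sequence[Sequence[float]],
                           k: int = 3) -> List[List[float]]:
    out = []
    for p in pos_y:
        # (squared distance, index) of the k nearest source points
        nearest = sorted(
            (sum((a - b) ** 2 for a, b in zip(p, q)), j) for j, q in enumerate(pos_x)
        )[:k]
        # w(x_i) = 1 / d(p(y), p(x_i))^2, clamped to avoid division by zero
        weights = [1.0 / max(d2, 1e-16) for d2, _ in nearest]
        total = sum(weights)
        out.append([
            sum(w * x[j][f] for w, (_, j) in zip(weights, nearest)) / total
            for f in range(len(x[0]))
        ])
    return out
```

For source points at positions 0 and 2 with features 0 and 10, a query at position 0.5 gets weights 4 and 4/9, so its interpolated feature is (4/9 · 10) / (40/9) = 1.
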

Models

class ARGA(encoder, discriminator, decoder=None)[source]

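The Adversarially Regularized Graph Auto-Encoder model from the “Adversarially Regularized Graph Autoencoder for Graph Embedding” paper.

Parameters
  • encoder (Module) – The encoder module.

  • discriminator (Module) – The discriminator module.

  • decoder (Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)
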
discriminator_loss(z)[source]

Computes the loss of the discriminator.

Parameters

z (Tensor) – The latent space \(\mathbf{Z}\).

reg_loss(z)[source]

Computes the regularization loss of the encoder.

Parameters

z (Tensor) – The latent space \(\mathbf{Z}\).

reset_parameters()[source]
class ARGVA(encoder, discriminator, decoder=None)[source]

The Adversarially Regularized Variational Graph Auto-Encoder model from the “Adversarially Regularized Graph Autoencoder for Graph Embedding” paper.

Parameters
  • encoder (Module) – The encoder module to compute \(\mu\) and \(\log\sigma^2\).

  • discriminator (Module) – The discriminator module.

  • decoder (Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)

encode(*args, **kwargs)[source]
kl_loss(mu=None, logstd=None)[source]
reparametrize(mu, logstd)[source]
class DeepGCNLayer(conv=None, norm=None, act=None, block='res+', dropout=0.0, ckpt_grad=False)[source]

The skip connection operations from the “DeepGCNs: Can GCNs Go as Deep as CNNs?” and “All You Need to Train Deeper GCNs” papers. The implemented skip connections include the pre-activation residual connection ("res+"), the residual connection ("res"), the dense connection ("dense") and no connections ("plain").

  • Res+ ("res+"):

\[\text{Normalization}\to\text{Activation}\to\text{Dropout}\to \text{GraphConv}\to\text{Res}\]
  • Res ("res") / Dense ("dense") / Plain ("plain"):

\[\text{GraphConv}\to\text{Normalization}\to\text{Activation}\to \text{Res/Dense/Plain}\to\text{Dropout}\]

Note

For an example of using DeepGCNLayer together with GENConv, see examples/ogbn_proteins_deepgcn.py.

Parameters
  • conv (torch.nn.Module, optional) – The GCN operator. (default: None)

  • norm (torch.nn.Module, optional) – The normalization layer. (default: None)

  • act (torch.nn.Module, optional) – The activation layer. (default: None)

  • block (string, optional) – The skip connection operation to use ("res+", "res", "dense" or "plain"). (default: "res+")

  • dropout (float, optional) – The dropout probability. (default: 0.)

  • ckpt_grad (bool, optional) – If set to True, will checkpoint this part of the model. Checkpointing works by trading compute for memory, since intermediate activations do not need to be kept in memory. Set this to True in case you encounter out-of-memory errors while going deep. (default: False)
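
The two orderings in the diagrams above can be compared with a small sketch that treats every component as a plain function on a feature list. The arguments conv, norm, act and dropout are hypothetical stand-ins rather than the library's modules (the graph convolution here ignores edge_index entirely), and list concatenation stands in for feature-wise concatenation in the "dense" case.

```python
from typing import Callable, List

Vec = List[float]
Fn = Callable[[Vec], Vec]


def deepgcn_block_sketch(x: Vec, conv: Fn, norm: Fn, act: Fn, dropout: Fn,
                         block: str = "res+") -> Vec:
    if block == "res+":
        # Normalization -> Activation -> Dropout -> GraphConv -> Res
        h = conv(dropout(act(norm(x))))
        return [a + b for a, b in zip(x, h)]
    # GraphConv -> Normalization -> Activation -> Res/Dense/Plain -> Dropout
    h = act(norm(conv(x)))
    if block == "res":
        h = [a + b for a, b in zip(x, h)]
    elif block == "dense":
        h = x + h  # concatenate input and output features
    elif block != "plain":
        raise ValueError(f"unknown block type: {block}")
    return dropout(h)
```

In "res+" the residual adds the raw input back after the convolution, so the activation never clips the skip path. In "plain" the negative input feature is simply zeroed by the activation.
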

forward(*args, **kwargs)[source]
reset_parameters()[source]
class DeepGraphInfomax(hidden_channels, encoder, summary, corruption)[source]

The Deep Graph Infomax model from the “Deep Graph Infomax” paper based on user-defined encoder and summary model \(\mathcal{E}\) and \(\mathcal{R}\) respectively, and a corruption function \(\mathcal{C}\).

Parameters
  • hidden_channels (int) – The latent space dimensionality.

  • encoder (Module) – The encoder module \(\mathcal{E}\).

  • summary (callable) – The readout function \(\mathcal{R}\).

  • corruption (callable) – The corruption function \(\mathcal{C}\).

discriminate(z, summary, sigmoid=True)[source]

Given the patch-summary pair z and summary, computes the probability scores assigned to this patch-summary pair.

Parameters
  • z (Tensor) – The latent space.

  • summary (Tensor) – The summary vector.

  • sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)

forward(*args, **kwargs)[source]

Returns the latent space for the input arguments, their corruptions and their summary representation.

loss(pos_z, neg_z, summary)[source]

Computes the mutual information maximization objective.
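
The Deep Graph Infomax paper trains a bilinear discriminator σ(zᵀ W s) to score real patch-summary pairs high and corrupted ones low, using a binary cross-entropy objective. The sketch below shows that objective on plain lists. The weight matrix W, the eps guard and the helper names are assumptions of the sketch, not the module's attributes.

```python
import math
from typing import List, Sequence


def _sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def discriminate_sketch(z: Sequence[Sequence[float]], summary: Sequence[float],
                        weight: Sequence[Sequence[float]]) -> List[float]:
    """Scores sigma(z_i^T W s) for every patch representation z_i."""
    ws = [sum(w_ij * s_j for w_ij, s_j in zip(row, summary)) for row in weight]
    return [_sigmoid(sum(a * b for a, b in zip(z_i, ws))) for z_i in z]


def dgi_loss_sketch(pos_z, neg_z, summary, weight, eps: float = 1e-15) -> float:
    pos = discriminate_sketch(pos_z, summary, weight)  # real pairs, label 1
    neg = discriminate_sketch(neg_z, summary, weight)  # corrupted pairs, label 0
    pos_loss = -sum(math.log(p + eps) for p in pos) / len(pos)
    neg_loss = -sum(math.log(1.0 - p + eps) for p in neg) / len(neg)
    return pos_loss + neg_loss
```

An uninformative discriminator (every score 0.5) yields a loss of 2 log 2; the loss approaches 0 as real and corrupted patches become separable with respect to the summary.
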

reset_parameters()[source]
test(train_z, train_y, test_z, test_y, solver='lbfgs', multi_class='auto', *args, **kwargs)[source]

Evaluates latent space quality via a logistic regression downstream task.

class DimeNet(hidden_channels, out_channels, num_blocks, num_bilinear, num_spherical, num_radial, cutoff=5.0, envelope_exponent=5, num_before_skip=1, num_after_skip=2, num_output_layers=3, act=<function swish>)[source]

The directional message passing neural network (DimeNet) from the “Directional Message Passing for Molecular Graphs” paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion.

Note

For an example of using a pretrained DimeNet variant, see examples/qm9_pretrained_dimenet.py.

Parameters
  • hidden_channels (int) – Hidden embedding size.

  • out_channels (int) – Size of each output sample.

  • num_blocks (int) – Number of building blocks.

  • num_bilinear (int) – Size of the bilinear layer tensor.

  • num_spherical (int) – Number of spherical harmonics.

  • num_radial (int) – Number of radial basis functions.

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)

  • envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)

  • act (function, optional) – The activation function. (default: swish)

forward(z, pos, batch=None)[source]
static from_qm9_pretrained(root, dataset, target)[source]
reset_parameters()[source]
triplets(edge_index, num_nodes)[source]
url = 'https://github.com/klicperajo/dimenet/raw/master/pretrained'
class GAE(encoder, decoder=None)[source]

The Graph Auto-Encoder model from the “Variational Graph Auto-Encoders” paper based on user-defined encoder and decoder models.

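Parameters
  • encoder (Module) – The encoder module.

  • decoder (Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)
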
decode(*args, **kwargs)[source]

Runs the decoder and computes edge probabilities.

encode(*args, **kwargs)[source]

Runs the encoder and computes node-wise latent variables.

recon_loss(z, pos_edge_index)[source]

Given latent variables z, computes the binary cross entropy loss for positive edges pos_edge_index and negative sampled edges.

Parameters
  • z (Tensor) – The latent space \(\mathbf{Z}\).

  • pos_edge_index (LongTensor) – The positive edges to train against.
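
The reconstruction loss is a binary cross-entropy over inner-product edge probabilities: positive edges should score close to 1, negative edges close to 0. In this sketch the negative pairs are passed in explicitly instead of being sampled, and edges are given as (source, target) tuples. Both choices, like the helper names and the eps guard, are assumptions for illustration.

```python
import math
from typing import List, Sequence, Tuple

Pair = Tuple[int, int]


def _edge_probs(z: Sequence[Sequence[float]], pairs: Sequence[Pair]) -> List[float]:
    # sigmoid(z_i . z_j) for every pair (i, j)
    return [1.0 / (1.0 + math.exp(-sum(a * b for a, b in zip(z[i], z[j])))) for i, j in pairs]


def recon_loss_sketch(z, pos_pairs: Sequence[Pair], neg_pairs: Sequence[Pair],
                      eps: float = 1e-15) -> float:
    pos = _edge_probs(z, pos_pairs)
    neg = _edge_probs(z, neg_pairs)
    pos_loss = -sum(math.log(p + eps) for p in pos) / len(pos)
    neg_loss = -sum(math.log(1.0 - p + eps) for p in neg) / len(neg)
    return pos_loss + neg_loss
```
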

reset_parameters()[source]
test(z, pos_edge_index, neg_edge_index)[source]

Given latent variables z, positive edges pos_edge_index and negative edges neg_edge_index, computes area under the ROC curve (AUC) and average precision (AP) scores.

Parameters
  • z (Tensor) – The latent space \(\mathbf{Z}\).

  • pos_edge_index (LongTensor) – The positive edges to evaluate against.

  • neg_edge_index (LongTensor) – The negative edges to evaluate against.

class GNNExplainer(model, epochs: int = 100, lr: float = 0.01, num_hops: Optional[int] = None, log: bool = True)[source]

The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and small subsets of node features that play a crucial role in a GNN’s node-predictions.

Note

For an example of using GNN-Explainer, see examples/gnn_explainer.py.

Parameters
  • model (torch.nn.Module) – The GNN module to explain.

  • epochs (int, optional) – The number of epochs to train. (default: 100)

  • lr (float, optional) – The learning rate to apply. (default: 0.01)

  • num_hops (int, optional) – The number of hops the model is aggregating information from. If set to None, will automatically try to detect this information based on the number of MessagePassing layers inside model. (default: None)

  • log (bool, optional) – If set to False, will not log any learning progress. (default: True)

coeffs = {'edge_ent': 1.0, 'edge_size': 0.005, 'node_feat_ent': 0.1, 'node_feat_size': 1.0}
explain_node(node_idx, x, edge_index, **kwargs)[source]

Learns and returns a node feature mask and an edge mask that play a crucial role in explaining the prediction made by the GNN for node node_idx.

Parameters
  • node_idx (int) – The node to explain.

  • x (Tensor) – The node feature matrix.

  • edge_index (LongTensor) – The edge indices.

  • **kwargs (optional) – Additional arguments passed to the GNN module.

Return type

(Tensor, Tensor)

property num_hops
visualize_subgraph(node_idx, edge_index, edge_mask, y=None, threshold=None, **kwargs)[source]

Visualizes the subgraph around node_idx given an edge mask edge_mask.

Parameters
  • node_idx (int) – The node id to explain.

  • edge_index (LongTensor) – The edge indices.

  • edge_mask (Tensor) – The edge mask.

  • y (Tensor, optional) – The ground-truth node-prediction labels used as node colorings. (default: None)

  • threshold (float, optional) – Sets a threshold for visualizing important edges. If set to None, will visualize all edges with transparency indicating the importance of edges. (default: None)

  • **kwargs (optional) – Additional arguments passed to nx.draw().

Return type

matplotlib.axes.Axes, networkx.DiGraph

class GraphUNet(in_channels, hidden_channels, out_channels, depth, pool_ratios=0.5, sum_res=True, act=<function relu>)[source]

The Graph U-Net model from the “Graph U-Nets” paper which implements a U-Net-like architecture with graph pooling and unpooling operations.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • out_channels (int) – Size of each output sample.

  • depth (int) – The depth of the U-Net architecture.

  • pool_ratios (float or [float], optional) – Graph pooling ratio for each depth. (default: 0.5)

  • sum_res (bool, optional) – If set to False, will use concatenation for integration of skip connections instead of summation. (default: True)

  • act (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.nn.functional.relu)

augment_adj(edge_index, edge_weight, num_nodes)[source]
forward(x, edge_index, batch=None)[source]
reset_parameters()[source]
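
The Graph U-Nets paper increases graph connectivity before pooling by using the second graph power, so that nodes within two hops become directly connected. The sketch below computes the edge set of (A + I)² with self-loops removed on an unweighted graph. The function name is hypothetical, edge weights are ignored, and treating this as exactly what augment_adj() returns is an assumption.

```python
from typing import Iterable, Set, Tuple

Edge = Tuple[int, int]


def augment_adj_sketch(edges: Iterable[Edge], num_nodes: int) -> Set[Edge]:
    # Neighbor sets of A + I (self-loops added).
    nbrs = [{i} for i in range(num_nodes)]
    for i, j in edges:
        nbrs[i].add(j)
    # (A + I)^2 has a nonzero (i, j) entry iff some k links i -> k -> j.
    two_hop: Set[Edge] = set()
    for i in range(num_nodes):
        for k in nbrs[i]:
            for j in nbrs[k]:
                if i != j:  # drop self-loops again
                    two_hop.add((i, j))
    return two_hop
```

On the path 0 – 1 – 2, the augmented graph keeps the original edges and adds the two-hop edges (0, 2) and (2, 0), so pooling node 1 away no longer disconnects its neighbors.

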
class InnerProductDecoder[source]

The inner product decoder from the “Variational Graph Auto-Encoders” paper

\[\sigma(\mathbf{Z}\mathbf{Z}^{\top})\]

where \(\mathbf{Z} \in \mathbb{R}^{N \times d}\) denotes the latent space produced by the encoder.

forward(z, edge_index, sigmoid=True)[source]

Decodes the latent variables z into edge probabilities for the given node-pairs edge_index.

Parameters
  • z (Tensor) – The latent space \(\mathbf{Z}\).

  • edge_index (LongTensor) – The edge indices.

  • sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)

forward_all(z, sigmoid=True)[source]

Decodes the latent variables z into a probabilistic dense adjacency matrix.

Parameters
  • z (Tensor) – The latent space \(\mathbf{Z}\).

  • sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)
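
Both decoding modes reduce to inner products of latent vectors: forward() evaluates σ(z_i · z_j) only for the listed node pairs, while forward_all() evaluates the full matrix σ(ZZᵀ). The sketch below mirrors that on plain lists; decode_edges and decode_all are hypothetical names, and edge_index is a [2, E] list of lists here rather than a LongTensor.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def _sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def decode_edges(z: Sequence[Sequence[float]], edge_index: Sequence[Sequence[int]],
                 use_sigmoid: bool = True) -> List[float]:
    src, dst = edge_index
    # One inner product per requested node pair.
    vals = [sum(a * b for a, b in zip(z[i], z[j])) for i, j in zip(src, dst)]
    return [_sigmoid(v) for v in vals] if use_sigmoid else vals


def decode_all(z: Sequence[Sequence[float]], use_sigmoid: bool = True) -> Matrix:
    # Dense Z Z^T: every pair of nodes, O(N^2) memory.
    adj = [[sum(a * b for a, b in zip(zi, zj)) for zj in z] for zi in z]
    return [[_sigmoid(v) for v in row] for row in adj] if use_sigmoid else adj
```

Orthogonal latent vectors give a probability of exactly 0.5, which is the decoder's "no information" point.
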

class JumpingKnowledge(mode, channels=None, num_layers=None)[source]

The Jumping Knowledge layer aggregation module from the “Representation Learning on Graphs with Jumping Knowledge Networks” paper based on either concatenation ("cat")

\[\mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)}\]

max pooling ("max")

\[\max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right)\]

or weighted summation

\[\sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)}\]

with attention scores \(\alpha_v^{(t)}\) obtained from a bi-directional LSTM ("lstm").

Parameters
  • mode (string) – The aggregation scheme to use ("cat", "max" or "lstm").

  • channels (int, optional) – The number of channels per representation. Only needs to be set for LSTM-style aggregation. (default: None)

  • num_layers (int, optional) – The number of layers to aggregate. Only needs to be set for LSTM-style aggregation. (default: None)

forward(xs)[source]

Aggregates representations across different layers.

Parameters

xs (list or tuple) – List containing layer-wise representations.
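
The parameter-free modes are easy to state per node: "cat" concatenates a node's representations from all T layers (giving T times as many channels), and "max" takes the element-wise maximum across layers. The sketch below uses nested lists shaped [layer][node][channel]; the function name is hypothetical, and the "lstm" mode is omitted because it needs trained LSTM weights.

```python
from itertools import chain
from typing import List, Sequence

Matrix = List[List[float]]


def jumping_knowledge_sketch(xs: Sequence[Sequence[Sequence[float]]], mode: str) -> Matrix:
    num_nodes = len(xs[0])
    if mode == "cat":
        # x_v^(1) || ... || x_v^(T)
        return [list(chain.from_iterable(layer[v] for layer in xs)) for v in range(num_nodes)]
    if mode == "max":
        # element-wise max over the T layer representations of node v
        return [[max(vals) for vals in zip(*(layer[v] for layer in xs))]
                for v in range(num_nodes)]
    raise ValueError("this sketch only covers 'cat' and 'max'")
```
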

reset_parameters()[source]
class MetaPath2Vec(edge_index_dict, embedding_dim, metapath, walk_length, context_size, walks_per_node=1, num_negative_samples=1, num_nodes_dict=None, sparse=False)[source]

The MetaPath2Vec model from the “metapath2vec: Scalable Representation Learning for Heterogeneous Networks” paper where random walks based on a given metapath are sampled in a heterogeneous graph, and node embeddings are learned via negative sampling optimization.

Note

For an example of using MetaPath2Vec, see examples/metapath2vec.py.

Parameters
  • edge_index_dict (dict) – Dictionary holding edge indices for each (source_node_type, relation_type, target_node_type) present in the heterogeneous graph.

  • embedding_dim (int) – The size of each embedding vector.

  • metapath (list) – The metapath described as a list of (source_node_type, relation_type, target_node_type) tuples.

  • walk_length (int) – The walk length.

  • context_size (int) – The actual context size which is considered for positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes.

  • walks_per_node (int, optional) – The number of walks to sample for each node. (default: 1)

  • num_negative_samples (int, optional) – The number of negative samples to use for each positive sample. (default: 1)

  • num_nodes_dict (dict, optional) – Dictionary holding the number of nodes for each node type. (default: None)

  • sparse (bool, optional) – If set to True, gradients w.r.t. the weight matrix will be sparse. (default: False)
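
The context_size parameter can be illustrated by sliding a window of that size along one sampled walk. Each window is one training example, so a single walk yields several examples rooted at different source nodes, which is the reuse the description refers to. Treating each window's first node as the source and the rest as its positive context is an assumption of this sketch, and context_windows is a hypothetical helper.

```python
from typing import List, Sequence


def context_windows(walk: Sequence[int], context_size: int) -> List[List[int]]:
    if not 1 <= context_size <= len(walk):
        raise ValueError("context_size must be between 1 and the number of nodes in the walk")
    # Every contiguous run of context_size nodes becomes one positive sample.
    return [list(walk[i:i + context_size]) for i in range(len(walk) - context_size + 1)]
```

A walk over n nodes produces n - context_size + 1 windows, so smaller windows extract more (but shorter) positive samples from the same walk.
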

forward(node_type, batch=None)[source]

Returns the embeddings for the nodes in batch of type node_type.

loader(**kwargs)[source]
loss(pos_rw, neg_rw)[source]

Computes the loss given positive and negative random walks.

neg_sample(batch)[source]
pos_sample(batch)[source]
reset_parameters()[source]
sample(batch)[source]
test(train_z, train_y, test_z, test_y, solver='lbfgs', multi_class='auto', *args, **kwargs)[source]

Evaluates latent space quality via a logistic regression downstream task.

class Node2Vec(edge_index, embedding_dim, walk_length, context_size, walks_per_node=1, p=1, q=1, num_negative_samples=1, num_nodes=None, sparse=False)[source]

The Node2Vec model from the “node2vec: Scalable Feature Learning for Networks” paper where random walks of length walk_length are sampled in a given graph, and node embeddings are learned via negative sampling optimization.

Note

For an example of using Node2Vec, see examples/node2vec.py.

Parameters
  • edge_index (LongTensor) – The edge indices.

  • embedding_dim (int) – The size of each embedding vector.

  • walk_length (int) – The walk length.

  • context_size (int) – The actual context size which is considered for positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes.

  • walks_per_node (int, optional) – The number of walks to sample for each node. (default: 1)

  • p (float, optional) – Return parameter that controls the likelihood of immediately revisiting a node in the walk; larger values make backtracking less likely. (default: 1)

  • q (float, optional) – In-out parameter that interpolates between a breadth-first (q > 1) and a depth-first (q < 1) exploration strategy. (default: 1)

  • num_negative_samples (int, optional) – The number of negative samples to use for each positive sample. (default: 1)

  • num_nodes (int, optional) – The number of nodes. (default: None)

  • sparse (bool, optional) – If set to True, gradients w.r.t. the weight matrix will be sparse. (default: False)
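
The roles of p and q follow the second-order transition rule of the node2vec paper. After stepping from prev to cur, each candidate next node gets the unnormalized weight 1/p if it returns to prev, 1 if it is also a neighbor of prev, and 1/q otherwise. The sketch below computes these weights for an unweighted graph given as a dict of neighbor sets. It illustrates the rule only, not the library's own walk sampler, and node2vec_weights is a hypothetical helper.

```python
from typing import Dict, Hashable, Set


def node2vec_weights(prev: Hashable, cur: Hashable, adj: Dict[Hashable, Set[Hashable]],
                     p: float = 1.0, q: float = 1.0) -> Dict[Hashable, float]:
    weights = {}
    for nxt in adj[cur]:
        if nxt == prev:
            weights[nxt] = 1.0 / p  # step straight back
        elif nxt in adj[prev]:
            weights[nxt] = 1.0      # stay at distance 1 from prev (BFS-like)
        else:
            weights[nxt] = 1.0 / q  # move further away (DFS-like)
    return weights
```

With p = q = 1 all weights are equal, so the walk is uniform; a small q (here 0.5) doubles the preference for moving outward.
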

forward(batch=None)[source]

Returns the embeddings for the nodes in batch.

loader(**kwargs)[source]
loss(pos_rw, neg_rw)[source]

Computes the loss given positive and negative random walks.

neg_sample(batch)[source]
pos_sample(batch)[source]
reset_parameters()[source]
sample(batch)[source]
test(train_z, train_y, test_z, test_y, solver='lbfgs', multi_class='auto', *args, **kwargs)[source]

Evaluates latent space quality via a logistic regression downstream task.

class RENet(num_nodes, num_rels, hidden_channels, seq_len, num_layers=1, dropout=0.0, bias=True)[source]

The Recurrent Event Network model from the “Recurrent Event Network for Reasoning over Temporal Knowledge Graphs” paper

\[f_{\mathbf{\Theta}}(\mathbf{e}_s, \mathbf{e}_r, \mathbf{h}^{(t-1)}(s, r))\]

based on an RNN encoder

\[\mathbf{h}^{(t)}(s, r) = \textrm{RNN}(\mathbf{e}_s, \mathbf{e}_r, g(\mathcal{O}^{(t)}_r(s)), \mathbf{h}^{(t-1)}(s, r))\]

where \(\mathbf{e}_s\) and \(\mathbf{e}_r\) denote entity and relation embeddings, and \(\mathcal{O}^{(t)}_r(s)\) represents the set of objects that interacted with subject \(s\) under relation \(r\) at timestamp \(t\). This model implements \(g\) as the mean aggregator and \(f_{\mathbf{\Theta}}\) as a linear projection.

Parameters
  • num_nodes (int) – The number of nodes in the knowledge graph.

  • num_rels (int) – The number of relations in the knowledge graph.

  • hidden_channels (int) – Hidden size of node and relation embeddings.

  • seq_len (int) – The sequence length of past events.

  • num_layers (int, optional) – The number of recurrent layers. (default: 1)

  • dropout (float, optional) – If non-zero, introduces a dropout layer before the final prediction. (default: 0.)

  • bias (bool, optional) – If set to False, all layers will not learn an additive bias. (default: True)

forward(data)[source]

Given a data batch, computes the forward pass.

Parameters

data (torch_geometric.data.Data) – The input data, holding subject sub, relation rel and object obj information with shape [batch_size]. In addition, data needs to hold history information for subjects, given by a vector of node indices h_sub and their relative timestamps h_sub_t and batch assignments h_sub_batch. The same information must be given for objects (h_obj, h_obj_t, h_obj_batch).

static pre_transform(seq_len)[source]

Precomputes history objects

\[\{ \mathcal{O}^{(t-k-1)}_r(s), \ldots, \mathcal{O}^{(t-1)}_r(s) \}\]

of a torch_geometric.datasets.icews.EventDataset with \(k\) denoting the sequence length seq_len.

reset_parameters()[source]
test(logits, y)[source]

Given ground-truth y, computes Mean Reciprocal Rank (MRR) and Hits at 1/3/10.
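
Both metrics derive from the rank of the true entity among all scored candidates: MRR averages 1/rank, and Hits@k is the fraction of queries whose rank is at most k. The sketch below computes them from plain score lists. Counting only strictly higher scores (so ties favor the true entity) is an assumption, and the library may break ties differently; mrr_hits_sketch is a hypothetical helper.

```python
from typing import Dict, List, Sequence, Tuple


def mrr_hits_sketch(logits: Sequence[Sequence[float]], y: Sequence[int],
                    ks: Tuple[int, ...] = (1, 3, 10)) -> Tuple[float, Dict[int, float]]:
    ranks: List[int] = []
    for scores, target in zip(logits, y):
        # rank = 1 + number of candidates scored strictly higher than the true one
        ranks.append(1 + sum(s > scores[target] for s in scores))
    mrr = sum(1.0 / r for r in ranks) / len(ranks)
    hits = {k: sum(r <= k for r in ranks) / len(ranks) for k in ks}
    return mrr, hits
```
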

class SchNet(hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0, readout='add', dipole=False, mean=None, std=None, atomref=None)[source]

The continuous-filter convolutional neural network SchNet from the “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions” paper that uses interaction blocks of the form

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu})^2)),\]

where \(h_{\mathbf{\Theta}}\) denotes an MLP and \(\mathbf{e}_{j,i}\) denotes the interatomic distances between atoms.

Note

For an example of using a pretrained SchNet variant, see examples/qm9_pretrained_schnet.py.

Parameters
  • hidden_channels (int, optional) – Hidden embedding size. (default: 128)

  • num_filters (int, optional) – The number of filters to use. (default: 128)

  • num_interactions (int, optional) – The number of interaction blocks. (default: 6)

  • num_gaussians (int, optional) – The number of gaussians \(\mu\). (default: 50)

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 10.0)

  • readout (string, optional) – Whether to apply "add" or "mean" global aggregation. (default: "add")

  • dipole (bool, optional) – If set to True, will use the magnitude of the dipole moment to make the final prediction, e.g., for target 0 of torch_geometric.datasets.QM9. (default: False)

  • mean (float, optional) – The mean of the property to predict. (default: None)

  • std (float, optional) – The standard deviation of the property to predict. (default: None)

  • atomref (torch.Tensor, optional) – The reference of single-atom properties. Expects a vector of shape (max_atomic_number, ). (default: None)
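
The Gaussian term inside h_Θ expands each interatomic distance into num_gaussians features, one per center μ, with the centers spread evenly between 0 and cutoff. The sketch below computes this expansion for a single distance. Setting γ from the spacing Δ between centers as γ = 1/(2Δ²) is an assumption of this sketch, and gaussian_smearing is a hypothetical helper.

```python
import math
from typing import List


def gaussian_smearing(dist: float, start: float = 0.0, stop: float = 10.0,
                      num_gaussians: int = 50) -> List[float]:
    step = (stop - start) / (num_gaussians - 1)  # spacing between centers
    centers = [start + k * step for k in range(num_gaussians)]
    gamma = 0.5 / step ** 2  # assumed width: one center spacing
    # exp(-gamma * (e - mu)^2) for every center mu
    return [math.exp(-gamma * (dist - mu) ** 2) for mu in centers]
```

A distance that lands exactly on a center activates that feature with value 1, and its neighbors one spacing away with exp(-0.5).
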

forward(z, pos, batch=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

static from_qm9_pretrained(root, dataset, target)[source]
reset_parameters()[source]
url = 'http://www.quantum-machine.org/datasets/trained_schnet_models.zip'
class SignedGCN(in_channels, hidden_channels, num_layers, lamb=5, bias=True)[source]

The signed graph convolutional network model from the “Signed Graph Convolutional Network” paper. Internally, this module uses the torch_geometric.nn.conv.SignedConv operator.

Parameters
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • num_layers (int) – Number of layers.

  • lamb (float, optional) – Balances the contributions of the overall objective. (default: 5)

  • bias (bool, optional) – If set to False, all layers will not learn an additive bias. (default: True)

create_spectral_features(pos_edge_index, neg_edge_index, num_nodes=None)[source]

Creates in_channels spectral node features based on positive and negative edges.

Parameters
  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

  • num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of pos_edge_index and neg_edge_index. (default: None)

discriminate(z, edge_index)[source]

Given node embeddings z, classifies the link relation between node pairs edge_index to be either positive, negative or non-existent.

Parameters
  • z (Tensor) – The node embeddings.

  • edge_index (LongTensor) – The edge indices.

forward(x, pos_edge_index, neg_edge_index)[source]

Computes node embeddings z based on positive edges pos_edge_index and negative edges neg_edge_index.

Parameters
  • x (Tensor) – The input node features.

  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

loss(z, pos_edge_index, neg_edge_index)[source]

Computes the overall objective.

Parameters
  • z (Tensor) – The node embeddings.

  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

neg_embedding_loss(z, neg_edge_index)[source]

Computes the triplet loss between negative node pairs and sampled unconnected node pairs.

Parameters
  • z (Tensor) – The node embeddings.

  • neg_edge_index (LongTensor) – The negative edge indices.

nll_loss(z, pos_edge_index, neg_edge_index)[source]

Computes the discriminator loss based on node embeddings z, and positive edges pos_edge_index and negative edges neg_edge_index.

Parameters
  • z (Tensor) – The node embeddings.

  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

pos_embedding_loss(z, pos_edge_index)[source]

Computes the triplet loss between positive node pairs and sampled unconnected node pairs.

Parameters
  • z (Tensor) – The node embeddings.

  • pos_edge_index (LongTensor) – The positive edge indices.

reset_parameters()[source]
split_edges(edge_index, test_ratio=0.2)[source]

Splits the edges edge_index into train and test edges.

Parameters
  • edge_index (LongTensor) – The edge indices.

  • test_ratio (float, optional) – The ratio of test edges. (default: 0.2)
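
A random split of this kind can be sketched as: shuffle the edge positions, send the first int(test_ratio · E) of them to the test set, and keep the rest for training. The seeded random.Random, the list-of-tuples edge format and the name split_edges_sketch are assumptions for illustration, not the method's actual interface.

```python
import random
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def split_edges_sketch(edges: Sequence[Edge], test_ratio: float = 0.2,
                       seed: int = 0) -> Tuple[List[Edge], List[Edge]]:
    rng = random.Random(seed)
    perm = list(range(len(edges)))
    rng.shuffle(perm)
    num_test = int(test_ratio * len(edges))
    test_ids = set(perm[:num_test])  # positions held out for testing
    train = [e for i, e in enumerate(edges) if i not in test_ids]
    test = [e for i, e in enumerate(edges) if i in test_ids]
    return train, test
```
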

test(z, pos_edge_index, neg_edge_index)[source]

Evaluates node embeddings z on positive and negative test edges by computing AUC and F1 scores.

Parameters
  • z (Tensor) – The node embeddings.

  • pos_edge_index (LongTensor) – The positive edge indices.

  • neg_edge_index (LongTensor) – The negative edge indices.

class VGAE(encoder, decoder=None)[source]

The Variational Graph Auto-Encoder model from the “Variational Graph Auto-Encoders” paper.

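Parameters
  • encoder (Module) – The encoder module to compute \(\mu\) and \(\log\sigma^2\).

  • decoder (Module, optional) – The decoder module. If set to None, will default to the torch_geometric.nn.models.InnerProductDecoder. (default: None)
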
encode(*args, **kwargs)[source]
kl_loss(mu=None, logstd=None)[source]

Computes the KL loss, either for the passed arguments mu and logstd, or based on latent variables from last encoding.

Parameters
  • mu (Tensor, optional) – The latent space for \(\mu\). If set to None, uses the last computation of \(\mu\). (default: None)

  • logstd (Tensor, optional) – The latent space for \(\log\sigma\). If set to None, uses the last computation of \(\log\sigma\). (default: None)

reparametrize(mu, logstd)[source]
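
The reparametrization trick and the KL term can be written out for plain lists as below. This sketch assumes that σ = exp(logstd), that the KL loss is the standard closed form of KL(N(μ, σ²) ‖ N(0, I)) summed over latent dimensions and averaged over nodes, and that evaluation mode simply returns μ. The helper names are hypothetical.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def reparametrize_sketch(mu: Matrix, logstd: Matrix, rng: random.Random,
                         training: bool = True) -> Matrix:
    if not training:
        return [list(row) for row in mu]  # use the mean at evaluation time
    # z = mu + eps * sigma with eps ~ N(0, 1) and sigma = exp(logstd)
    return [[m + rng.gauss(0.0, 1.0) * math.exp(s) for m, s in zip(m_row, s_row)]
            for m_row, s_row in zip(mu, logstd)]


def kl_loss_sketch(mu: Matrix, logstd: Matrix) -> float:
    # -1/2 * sum(1 + 2 log(sigma) - mu^2 - sigma^2), averaged over nodes
    per_node = [sum(1 + 2 * s - m * m - math.exp(s) ** 2 for m, s in zip(m_row, s_row))
                for m_row, s_row in zip(mu, logstd)]
    return -0.5 * sum(per_node) / len(per_node)
```

The KL term vanishes exactly when μ = 0 and σ = 1, i.e. when the posterior matches the standard normal prior.
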

DataParallel Layers

class DataParallel(module, device_ids=None, output_device=None)[source]

Implements data parallelism at the module level.

This container parallelizes the application of the given module by splitting a list of torch_geometric.data.Data objects and copying them as torch_geometric.data.Batch objects to each device. In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.

The batch size should be larger than the number of GPUs used.

The parallelized module must have its parameters and buffers on device_ids[0].

Note

You need to use the torch_geometric.data.DataListLoader for this module.

Parameters
  • module (Module) – Module to be parallelized.

  • device_ids (list of int or torch.device) – CUDA devices. (default: all devices)

  • output_device (int or torch.device) – Device location of output. (default: device_ids[0])

forward(data_list)[source]
scatter(data_list, device_ids)[source]
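
To see why the batch size should exceed the number of GPUs, consider one way of distributing a list of graphs: contiguous chunks balanced by cumulative node count, where each graph goes to the device in whose share its node range's midpoint falls. The function split_by_nodes is hypothetical, and whether scatter() uses exactly this rule is an assumption. The point is that with fewer graphs than devices, some device necessarily receives nothing.

```python
from typing import List, Sequence


def split_by_nodes(num_nodes: Sequence[int], num_devices: int) -> List[List[int]]:
    total = sum(num_nodes)
    chunks: List[List[int]] = [[] for _ in range(num_devices)]
    start = 0
    for i, n in enumerate(num_nodes):
        # Place each graph on the device where the midpoint of its node range falls.
        mid = start + n / 2
        device = min(int(num_devices * mid / total), num_devices - 1)
        chunks[device].append(i)
        start += n
    return chunks
```

Balancing by node count rather than by graph count keeps one large graph from overloading a device that also receives many small ones.
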