torch_geometric.nn¶
Convolutional Layers¶
class MessagePassing(aggr='add', flow='source_to_target', node_dim=0)[source]¶
Base class for creating message passing layers
\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),\]where \(\square\) denotes a differentiable, permutation invariant function, e.g., sum, mean or max, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs. See here for the accompanying tutorial.
Parameters:
- aggr (string, optional) – The aggregation scheme to use ("add", "mean" or "max"). (default: "add")
- flow (string, optional) – The flow direction of message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")
- node_dim (int, optional) – The axis along which to propagate. (default: 0)
aggregate(inputs, index, dim_size)[source]¶ Aggregates messages from neighbors as \(\square_{j \in \mathcal{N}(i)}\).
By default, delegates the call to scatter functions that support "add", "mean" and "max" operations as specified in __init__() by the aggr argument.
message(x_j)[source]¶ Constructs messages to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge \((j,i) \in \mathcal{E}\) if flow="source_to_target" and \((i,j) \in \mathcal{E}\) if flow="target_to_source". Can take any argument which was initially passed to propagate(). In addition, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
propagate(edge_index, size=None, **kwargs)[source]¶ The initial call to start propagating messages.
Parameters:
- edge_index (Tensor) – The indices of a general (sparse) assignment matrix with shape [N, M] (can be directed or undirected).
- size (list or tuple, optional) – The size [N, M] of the assignment matrix. If set to None, the size will be automatically inferred and assumed to be quadratic. (default: None)
- **kwargs – Any additional data which is needed to construct and aggregate messages, and to update node embeddings.
update(inputs)[source]¶ Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().
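As an illustration of the interface above, the following is a minimal, hypothetical subclass that sums linearly transformed neighbor features; the layer name, feature sizes and tensors are made up for demonstration and are not part of the original reference:
import torch
from torch_geometric.nn import MessagePassing

class SumConv(MessagePassing):
    # Hypothetical layer: x_i' = sum_{j in N(i)} W x_j, using "add" aggregation.
    def __init__(self, in_channels, out_channels):
        super(SumConv, self).__init__(aggr='add', flow='source_to_target')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels], edge_index has shape [2, E].
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=self.lin(x))

    def message(self, x_j):
        # x_j holds the transformed features of the source node of every edge (j, i).
        return x_j

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels] and contains the "add"-aggregated messages.
        return aggr_out

conv = SumConv(16, 32)
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out = conv(x, edge_index)  # shape [4, 32]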
class GCNConv(in_channels, out_channels, improved=False, cached=False, bias=True, normalize=True, **kwargs)[source]¶
The graph convolutional operator from the “Semi-supervised Classification with Graph Convolutional Networks” paper
\[\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},\]where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix.
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)
- cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
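A minimal usage sketch (not part of the original reference; feature sizes, tensors and variable names are illustrative only):
import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(in_channels=16, out_channels=32)
x = torch.randn(5, 16)                        # 5 nodes with 16 features each
edge_index = torch.tensor([[0, 1, 1, 2, 3],
                           [1, 0, 2, 1, 4]])  # [2, num_edges]
out = conv(x, edge_index)                     # shape [5, 32]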
class ChebConv(in_channels, out_channels, K, normalization='sym', bias=True, **kwargs)[source]¶
The Chebyshev spectral graph convolutional operator from the “Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering” paper
\[\mathbf{X}^{\prime} = \sum_{k=0}^{K-1} \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)}\]where \(\mathbf{Z}^{(k)}\) is computed recursively by
\[ \begin{align}\begin{aligned}\mathbf{Z}^{(0)} &= \mathbf{X}\\\mathbf{Z}^{(1)} &= \mathbf{\hat{L}} \cdot \mathbf{X}\\\mathbf{Z}^{(k)} &= 2 \cdot \mathbf{\hat{L}} \cdot \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)}\end{aligned}\end{align} \]and \(\mathbf{\hat{L}}\) denotes the scaled and normalized Laplacian \(\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}\).
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- K (int) – Chebyshev filter size, i.e. number of hops \(K\).
- normalization (str, optional) – The normalization scheme for the graph Laplacian (default: "sym"):
  1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)
  2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)
  3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)
  You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
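A usage sketch for the non-symmetric case (illustrative only; passing lambda_max as a keyword argument is assumed from the description above, and the value shown would normally be pre-computed, e.g. via LaplacianLambdaMax):
import torch
from torch_geometric.nn import ChebConv

conv = ChebConv(in_channels=16, out_channels=32, K=3, normalization='rw')
x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])
# Single graph, so lambda_max is a scalar; 2.0 is only a placeholder value here.
out = conv(x, edge_index, lambda_max=2.0)  # shape [5, 32]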
class SAGEConv(in_channels, out_channels, normalize=False, bias=True, **kwargs)[source]¶
The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper
\[ \begin{align}\begin{aligned}\mathbf{\hat{x}}_i &= \mathbf{\Theta} \cdot \mathrm{mean}_{j \in \mathcal{N(i) \cup \{ i \}}}(\mathbf{x}_j)\\\mathbf{x}^{\prime}_i &= \frac{\mathbf{\hat{x}}_i} {\| \mathbf{\hat{x}}_i \|_2}.\end{aligned}\end{align} \]Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- normalize (bool, optional) – If set to True, output features will be \(\ell_2\)-normalized. (default: False)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class GraphConv(in_channels, out_channels, aggr='add', bias=True, **kwargs)[source]¶
The graph neural network operator from the “Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks” paper
\[\mathbf{x}^{\prime}_i = \mathbf{\Theta}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta}_2 \mathbf{x}_j.\]Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class GatedGraphConv(out_channels, num_layers, aggr='add', bias=True, **kwargs)[source]¶
The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper
\[ \begin{align}\begin{aligned}\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}\\\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}\\\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)})\end{aligned}\end{align} \]up to representation \(\mathbf{h}_i^{(L)}\). The number of input channels of \(\mathbf{x}_i\) needs to be less than or equal to out_channels.
Parameters:
- out_channels (int) – Size of each output sample.
- num_layers (int) – The sequence length \(L\).
- aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class GATConv(in_channels, out_channels, heads=1, concat=True, negative_slope=0.2, dropout=0, bias=True, **kwargs)[source]¶
The graph attentional operator from the “Graph Attention Networks” paper
\[\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},\]where the attention coefficients \(\alpha_{i,j}\) are computed as
\[\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] \right)\right)}.\]Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- heads (int, optional) – Number of multi-head-attentions. (default: 1)
- concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)
- negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)
- dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
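A minimal usage sketch (illustrative only; shapes and values are made up for demonstration):
import torch
from torch_geometric.nn import GATConv

conv = GATConv(in_channels=16, out_channels=8, heads=4, concat=True, dropout=0.1)
x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))
out = conv(x, edge_index)  # shape [10, 4 * 8] since head outputs are concatenated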
class AGNNConv(requires_grad=True, **kwargs)[source]¶
Graph attentional propagation layer from the “Attention-based Graph Neural Network for Semi-Supervised Learning” paper
\[\mathbf{X}^{\prime} = \mathbf{P} \mathbf{X},\]where the propagation matrix \(\mathbf{P}\) is computed as
\[P_{i,j} = \frac{\exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_j))} {\sum_{k \in \mathcal{N}(i)\cup \{ i \}} \exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_k))}\]with trainable parameter \(\beta\).
Parameters:
class TAGConv(in_channels, out_channels, K=3, bias=True, **kwargs)[source]¶
The topology adaptive graph convolutional networks operator from the “Topology Adaptive Graph Convolutional Networks” paper
\[\mathbf{X}^{\prime} = \sum_{k=0}^K \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\mathbf{X} \mathbf{\Theta}_{k},\]where \(\mathbf{A}\) denotes the adjacency matrix and \(D_{ii} = \sum_{j=0} A_{ij}\) its diagonal degree matrix.
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- K (int, optional) – Number of hops \(K\). (default: 3)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class GINConv(nn, eps=0, train_eps=False, **kwargs)[source]¶
The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper
\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),\]where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e. an MLP.
Parameters:
- nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.
- eps (float, optional) – (Initial) \(\epsilon\) value. (default: 0)
- train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
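A minimal usage sketch (illustrative only; the MLP sizes and tensors are made up for demonstration):
import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv

mlp = Sequential(Linear(16, 32), ReLU(), Linear(32, 32))  # h_Theta: [-1, 16] -> [-1, 32]
conv = GINConv(mlp, eps=0, train_eps=True)
x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))
out = conv(x, edge_index)  # shape [10, 32]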
class ARMAConv(in_channels, out_channels, num_stacks=1, num_layers=1, shared_weights=False, act=<function relu>, dropout=0, bias=True)[source]¶
The ARMA graph convolutional operator from the “Graph Neural Networks with Convolutional ARMA Filters” paper
\[\mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K \mathbf{X}_k^{(T)},\]with \(\mathbf{X}_k^{(T)}\) being recursively defined by
\[\mathbf{X}_k^{(t+1)} = \sigma \left( \mathbf{\hat{L}} \mathbf{X}_k^{(t)} \mathbf{W} + \mathbf{X}^{(0)} \mathbf{V} \right),\]where \(\mathbf{\hat{L}} = \mathbf{I} - \mathbf{L} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\) denotes the modified Laplacian \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\).
Parameters: - in_channels (int) – Size of each input sample \(\mathbf{x}^{(t)}\).
- out_channels (int) – Size of each output sample \(\mathbf{x}^{(t+1)}\).
- num_stacks (int, optional) – Number of parallel stacks \(K\). (default: 1)
- num_layers (int, optional) – Number of layers \(T\). (default: 1)
- act (callable, optional) – Activation function \(\sigma\). (default: torch.nn.functional.relu)
- shared_weights (bool, optional) – If set to True, the layers in each stack will share the same parameters. (default: False)
- dropout (float, optional) – Dropout probability of the skip connection. (default: 0)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
class SGConv(in_channels, out_channels, K=1, cached=False, bias=True, **kwargs)[source]¶
The simple graph convolutional operator from the “Simplifying Graph Convolutional Networks” paper
\[\mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta},\]where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix.
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- K (int, optional) – Number of hops \(K\). (default: 1)
- cached (bool, optional) – If set to True, the layer will cache the computation of \({\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class APPNP(K, alpha, bias=True, **kwargs)[source]¶
The approximate personalized propagation of neural predictions layer from the “Predict then Propagate: Graph Neural Networks meet Personalized PageRank” paper
\[ \begin{align}\begin{aligned}\mathbf{X}^{(0)} &= \mathbf{X}\\\mathbf{X}^{(k)} &= (1 - \alpha) \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X}^{(k-1)} + \alpha \mathbf{X}^{(0)}\\\mathbf{X}^{\prime} &= \mathbf{X}^{(K)},\end{aligned}\end{align} \]where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix.
Parameters:
class RGCNConv(in_channels, out_channels, num_relations, num_bases, root_weight=True, bias=True, **kwargs)[source]¶
The relational graph convolutional operator from the “Modeling Relational Data with Graph Convolutional Networks” paper
\[\mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,\]where \(\mathcal{R}\) denotes the set of relations, i.e. edge types. Edge type needs to be a one-dimensional
torch.long
tensor which stores a relation identifier \(\in \{ 0, \ldots, |\mathcal{R}| - 1\}\) for each edge.Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- num_relations (int) – Number of relations.
- num_bases (int) – Number of bases used for basis-decomposition.
- root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
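A minimal usage sketch (illustrative only; the relation count, tensors and the keyword-free forward call are assumptions for demonstration):
import torch
from torch_geometric.nn import RGCNConv

conv = RGCNConv(in_channels=16, out_channels=32, num_relations=3, num_bases=2)
x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])
edge_type = torch.tensor([0, 2, 1, 0])  # one relation id in {0, ..., num_relations - 1} per edge
out = conv(x, edge_index, edge_type)    # shape [5, 32]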
class SignedConv(in_channels, out_channels, first_aggr, bias=True, **kwargs)[source]¶
The signed graph convolutional operator from the “Signed Graph Convolutional Network” paper
\[ \begin{align}\begin{aligned}\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w , \mathbf{x}_v \right]\\\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})} \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w , \mathbf{x}_v \right]\end{aligned}\end{align} \]
if first_aggr is set to True, and
\[ \begin{align}\begin{aligned}\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})} , \mathbf{x}_v^{(\textrm{pos})} \right]\\\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{pos})} \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})} , \mathbf{x}_v^{(\textrm{neg})} \right]\end{aligned}\end{align} \]
otherwise. In case first_aggr is False, the layer expects x to be a tensor where x[:, :in_channels] denotes the positive node features \(\mathbf{X}^{(\textrm{pos})}\) and x[:, in_channels:] denotes the negative node features \(\mathbf{X}^{(\textrm{neg})}\).
Parameters:
- in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- first_aggr (bool) – Denotes which aggregation formula to use.
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class DNAConv(channels, heads=1, groups=1, dropout=0, cached=False, bias=True, **kwargs)[source]¶
The dynamic neighborhood aggregation operator from the “Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks” paper
\[\mathbf{x}_v^{(t)} = h_{\mathbf{\Theta}}^{(t)} \left( \mathbf{x}_{v \leftarrow v}^{(t)}, \left\{ \mathbf{x}_{v \leftarrow w}^{(t)} : w \in \mathcal{N}(v) \right\} \right)\]based on (multi-head) dot-product attention
\[\mathbf{x}_{v \leftarrow w}^{(t)} = \textrm{Attention} \left( \mathbf{x}^{(t-1)}_v \, \mathbf{\Theta}_Q^{(t)}, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_K^{(t)}, \, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_V^{(t)} \right)\]with \(\mathbf{\Theta}_Q^{(t)}, \mathbf{\Theta}_K^{(t)}, \mathbf{\Theta}_V^{(t)}\) denoting (grouped) projection matrices for query, key and value information, respectively. \(h^{(t)}_{\mathbf{\Theta}}\) is implemented as a non-trainable version of
torch_geometric.nn.conv.GCNConv.
Note
In contrast to other layers, this operator expects node features as shape [num_nodes, num_layers, channels].
Parameters:
- channels (int) – Size of each input/output sample.
- heads (int, optional) – Number of multi-head-attentions. (default: 1)
- groups (int, optional) – Number of groups to use for all linear projections. (default: 1)
- dropout (float, optional) – Dropout probability of attention coefficients. (default: 0)
- cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class PointConv(local_nn=None, global_nn=None, **kwargs)[source]¶
The PointNet set layer from the “PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation” and “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” papers
\[\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \mathbf{p}_j - \mathbf{p}_i) \right),\]where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote neural networks, i.e. MLPs, and \(\mathbf{P} \in \mathbb{R}^{N \times D}\) defines the position of each point.
Parameters:
- local_nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x and relative spatial coordinates pos_j - pos_i of shape [-1, in_channels + num_dimensions] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential. (default: None)
- global_nn (torch.nn.Module, optional) – A neural network \(\gamma_{\mathbf{\Theta}}\) that maps aggregated node features of shape [-1, out_channels] to shape [-1, final_out_channels], e.g., defined by torch.nn.Sequential. (default: None)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class GMMConv(in_channels, out_channels, dim, kernel_size, separate_gaussians=False, aggr='mean', root_weight=True, bias=True, **kwargs)[source]¶
The Gaussian mixture model convolutional operator from the “Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs” paper
\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \frac{1}{K} \sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{i,j}) \odot \mathbf{\Theta}_k \mathbf{x}_j,\]where
\[\mathbf{w}_k(\mathbf{e}) = \exp \left( -\frac{1}{2} {\left( \mathbf{e} - \mathbf{\mu}_k \right)}^{\top} \Sigma_k^{-1} \left( \mathbf{e} - \mathbf{\mu}_k \right) \right)\]denotes a weighting function based on trainable mean vector \(\mathbf{\mu}_k\) and diagonal covariance matrix \(\mathbf{\Sigma}_k\).
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- dim (int) – Pseudo-coordinate dimensionality.
- kernel_size (int) – Number of kernels \(K\).
- separate_gaussians (bool, optional) – If set to True, will learn separate GMMs for every pair of input and output channel, inspired by traditional CNNs. (default: False)
- aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "mean")
- root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class SplineConv(in_channels, out_channels, dim, kernel_size, is_open_spline=True, degree=1, aggr='mean', root_weight=True, bias=True, **kwargs)[source]¶
The spline-based convolutional operator from the “SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels” paper
\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),\]where \(h_{\mathbf{\Theta}}\) denotes a kernel function defined over the weighted B-Spline tensor product basis.
Note
Pseudo-coordinates must lie in the fixed interval \([0, 1]\) for this method to work as intended.
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- dim (int) – Pseudo-coordinate dimensionality.
- kernel_size (int or [int]) – Size of the convolving kernel.
- is_open_spline (bool or [bool], optional) – If set to False, the operator will use a closed B-spline basis in this dimension. (default: True)
- degree (int, optional) – B-spline basis degrees. (default: 1)
- aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "mean")
- root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class NNConv(in_channels, out_channels, nn, aggr='add', root_weight=True, bias=True, **kwargs)[source]¶
The continuous kernel-based convolutional operator from the “Neural Message Passing for Quantum Chemistry” paper. This convolution is also known as the edge-conditioned convolution from the “Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs” paper (see torch_geometric.nn.conv.ECConv for an alias):
\[\mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),\]where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e. an MLP.
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps edge features edge_attr of shape [-1, num_edge_features] to shape [-1, in_channels * out_channels], e.g., defined by torch.nn.Sequential.
- aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")
- root_weight (bool, optional) – If set to False, the layer will not add the transformed root node features to the output. (default: True)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
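A minimal usage sketch (illustrative only; the edge network architecture, sizes and tensors are made up for demonstration):
import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import NNConv

in_channels, out_channels, num_edge_features = 16, 32, 4
# h_Theta maps each edge feature vector to an in_channels * out_channels weight matrix.
edge_nn = Sequential(Linear(num_edge_features, 64), ReLU(),
                     Linear(64, in_channels * out_channels))
conv = NNConv(in_channels, out_channels, edge_nn, aggr='add')

x = torch.randn(5, in_channels)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])
edge_attr = torch.randn(edge_index.size(1), num_edge_features)
out = conv(x, edge_index, edge_attr)  # shape [5, 32]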
ECConv¶
alias of torch_geometric.nn.conv.nn_conv.NNConv
class CGConv(channels, dim, aggr='add', bias=True, **kwargs)[source]¶
The crystal graph convolutional operator from the “Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties” paper
\[\mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right)\]where \(\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]\) denotes the concatenation of central node features, neighboring node features and edge features. In addition, \(\sigma\) and \(g\) denote the sigmoid and softplus functions, respectively.
Parameters: - channels (int) – Size of each input sample.
- dim (int) – Edge feature dimensionality.
- aggr (string, optional) – The aggregation operator to use ("add", "mean", "max"). (default: "add")
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class EdgeConv(nn, aggr='max', **kwargs)[source]¶
The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper
\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i),\]where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e. an MLP.
Parameters:
- nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps pair-wise concatenated node features x of shape [-1, 2 * in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.
- aggr (string, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "max")
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
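A minimal usage sketch (illustrative only; the MLP sizes and tensors are made up for demonstration):
import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import EdgeConv

# h_Theta maps the concatenation [x_i || x_j - x_i] of size 2 * in_channels to out_channels.
mlp = Sequential(Linear(2 * 3, 64), ReLU(), Linear(64, 32))
conv = EdgeConv(mlp, aggr='max')
x = torch.randn(10, 3)
edge_index = torch.randint(0, 10, (2, 30))
out = conv(x, edge_index)  # shape [10, 32]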
class DynamicEdgeConv(nn, k, aggr='max', **kwargs)[source]¶
The dynamic edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper (see torch_geometric.nn.conv.EdgeConv), where the graph is dynamically constructed using nearest neighbors in the feature space.
Parameters:
- nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps pair-wise concatenated node features x of shape [-1, 2 * in_channels] to shape [-1, out_channels], e.g. defined by torch.nn.Sequential.
- k (int) – Number of nearest neighbors.
- aggr (string) – The aggregation operator to use ("add", "mean", "max"). (default: "max")
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class XConv(in_channels, out_channels, dim, kernel_size, hidden_channels=None, dilation=1, bias=True, **kwargs)[source]¶
The convolutional operator on \(\mathcal{X}\)-transformed points from the “PointCNN: Convolution On X-Transformed Points” paper
\[\mathbf{x}^{\prime}_i = \mathrm{Conv}\left(\mathbf{K}, \gamma_{\mathbf{\Theta}}(\mathbf{P}_i - \mathbf{p}_i) \times \left( h_\mathbf{\Theta}(\mathbf{P}_i - \mathbf{p}_i) \, \Vert \, \mathbf{x}_i \right) \right),\]where \(\mathbf{K}\) and \(\mathbf{P}_i\) denote the trainable filter and neighboring point positions of \(\mathbf{x}_i\), respectively. \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) describe neural networks, i.e. MLPs, where \(h_{\mathbf{\Theta}}\) individually lifts each point into a higher-dimensional space, and \(\gamma_{\mathbf{\Theta}}\) computes the \(\mathcal{X}\)- transformation matrix based on all points in a neighborhood.
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- dim (int) – Point cloud dimensionality.
- kernel_size (int) – Size of the convolving kernel, i.e. number of neighbors including self-loops.
- hidden_channels (int, optional) – Output size of \(h_{\mathbf{\Theta}}\), i.e. dimensionality of lifted points. If set to None, will be automatically set to in_channels / 4. (default: None)
- dilation (int, optional) – The factor by which the neighborhood is extended, from which kernel_size neighbors are then uniformly sampled. Can be interpreted as the dilation rate of classical convolutional operators. (default: 1)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_cluster.knn_graph.
class PPFConv(local_nn=None, global_nn=None, **kwargs)[source]¶
The PPFNet operator from the “PPFNet: Global Context Aware Local Features for Robust 3D Point Matching” paper
\[\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \| \mathbf{d_{j,i}} \|, \angle(\mathbf{n}_i, \mathbf{d_{j,i}}), \angle(\mathbf{n}_j, \mathbf{d_{j,i}}), \angle(\mathbf{n}_i, \mathbf{n}_j) ) \right)\]where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote neural networks, i.e. MLPs, which take in node features and torch_geometric.transforms.PointPairFeatures.
Parameters:
- local_nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x and relative spatial coordinates pos_j - pos_i of shape [-1, in_channels + num_dimensions] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential. (default: None)
- global_nn (torch.nn.Module, optional) – A neural network \(\gamma_{\mathbf{\Theta}}\) that maps aggregated node features of shape [-1, out_channels] to shape [-1, final_out_channels], e.g., defined by torch.nn.Sequential. (default: None)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
forward(x, pos, norm, edge_index)[source]¶
Parameters:
- x (Tensor) – The node feature matrix. Allowed to be None.
. - pos (Tensor or tuple) – The node position matrix. Either given as tensor for use in general message passing or as tuple for use in message passing in bipartite graphs.
- norm (Tensor or tuple) – The normal vectors of each node. Either given as tensor for use in general message passing or as tuple for use in message passing in bipartite graphs.
- edge_index (LongTensor) – The edge indices.
class FeaStConv(in_channels, out_channels, heads=1, bias=True, **kwargs)[source]¶
The (translation-invariant) feature-steered convolutional operator from the “FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis” paper
\[\mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \sum_{h=1}^H q_h(\mathbf{x}_i, \mathbf{x}_j) \mathbf{W}_h \mathbf{x}_j\]with \(q_h(\mathbf{x}_i, \mathbf{x}_j) = \mathrm{softmax}_j (\mathbf{u}_h^{\top} (\mathbf{x}_j - \mathbf{x}_i) + c_h)\), where \(H\) denotes the number of attention heads, and \(\mathbf{W}_h\), \(\mathbf{u}_h\) and \(c_h\) are trainable parameters.
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- heads (int, optional) – Number of attention heads \(H\). (default: 1)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
class HypergraphConv(in_channels, out_channels, use_attention=False, heads=1, concat=True, negative_slope=0.2, dropout=0, bias=True)[source]¶
The hypergraph convolutional operator from the “Hypergraph Convolution and Hypergraph Attention” paper
\[\mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W} \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta}\]where \(\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}\) is the incidence matrix, \(\mathbf{W}\) is the diagonal hyperedge weight matrix, and \(\mathbf{D}\) and \(\mathbf{B}\) are the corresponding degree matrices.
Parameters: - in_channels (int) – Size of each input sample.
- out_channels (int) – Size of each output sample.
- use_attention (bool, optional) – If set to True, attention will be added to this layer. (default: False)
- heads (int, optional) – Number of multi-head-attentions. (default: 1)
- concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)
- negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)
- dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
class MetaLayer(edge_model=None, node_model=None, global_model=None)[source]¶
A meta layer for building any kind of graph network, inspired by the “Relational Inductive Biases, Deep Learning, and Graph Networks” paper.
A graph network takes a graph as input and returns an updated graph as output (with the same connectivity). The input graph has node features x, edge features edge_attr as well as global-level features u. The output graph has the same structure, but updated features.
Edge features, node features as well as global features are updated by calling the modules edge_model, node_model and global_model, respectively.
To allow for batch-wise graph processing, all callable functions take an additional argument batch, which determines the assignment of edges or nodes to their specific graphs.
Parameters:
- edge_model (Module, optional) – A callable which updates a graph’s edge features based on its source and target node features, its current edge features and its global features. (default: None)
- node_model (Module, optional) – A callable which updates a graph’s node features based on its current node features, its graph connectivity, its edge features and its global features. (default: None)
- global_model (Module, optional) – A callable which updates a graph’s global features based on its node features, its graph connectivity, its edge features and its current global features. (default: None)
import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer

class EdgeModel(torch.nn.Module):
    def __init__(self):
        super(EdgeModel, self).__init__()
        self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, src, dest, edge_attr, u, batch):
        # src, dest: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dest, edge_attr, u[batch]], 1)
        return self.edge_mlp(out)

class NodeModel(torch.nn.Module):
    def __init__(self):
        super(NodeModel, self).__init__()
        self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
        self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out, u[batch]], dim=1)
        return self.node_mlp_2(out)

class GlobalModel(torch.nn.Module):
    def __init__(self):
        super(GlobalModel, self).__init__()
        self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        out = torch.cat([u, scatter_mean(x, batch, dim=0)], dim=1)
        return self.global_mlp(out)

op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
Dense Convolutional Layers¶
class DenseGCNConv(in_channels, out_channels, improved=False, bias=True)[source]¶
See torch_geometric.nn.conv.GCNConv.
forward
(x, adj, mask=None, add_loop=True)[source]¶ Parameters: - x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).
- adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.
- mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)
- add_loop (bool, optional) – If set to False, the layer will not automatically add self-loops to the adjacency matrices. (default: True)
class DenseSAGEConv(in_channels, out_channels, normalize=False, bias=True)[source]¶
See torch_geometric.nn.conv.SAGEConv.
forward
(x, adj, mask=None, add_loop=True)[source]¶ Parameters: - x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).
- adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.
- mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)
- add_loop (bool, optional) – If set to False, the layer will not automatically add self-loops to the adjacency matrices. (default: True)
class DenseGraphConv(in_channels, out_channels, aggr='add', bias=True)[source]¶
See torch_geometric.nn.conv.GraphConv.
forward
(x, adj, mask=None)[source]¶ Parameters: - x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).
- adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.
- mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)
class DenseGINConv(nn, eps=0, train_eps=False)[source]¶
See torch_geometric.nn.conv.GINConv.
Return type: Tensor
forward
(x, adj, mask=None, add_loop=True)[source]¶ Parameters: - x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).
- adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). The adjacency tensor is broadcastable in the batch dimension, resulting in a shared adjacency matrix for the complete batch.
- mask (BoolTensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)
- add_loop (bool, optional) – If set to False, the layer will not automatically add self-loops to the adjacency matrices. (default: True)
Normalization Layers¶
class BatchNorm(in_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]¶
Applies batch normalization over a batch of node features as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]Parameters: - in_channels (int) – Size of each input sample.
- eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
- momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)
- affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
- track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
class InstanceNorm(in_channels, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)[source]¶
Applies instance normalization over each individual example in a batch of node features as described in the “Instance Normalization: The Missing Ingredient for Fast Stylization” paper
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]Parameters: - in_channels (int) – Size of each input sample.
- eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
- momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)
- affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: False)
- track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses instance statistics in both training and eval modes. (default: False)
Global Pooling Layers¶
global_add_pool
(x, batch, size=None)[source]¶ Returns batch-wise graph-level-outputs by adding node features across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by
\[\mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n\]Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).
- batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)
Return type: Tensor
global_mean_pool
(x, batch, size=None)[source]¶ Returns batch-wise graph-level-outputs by averaging node features across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by
\[\mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n\]Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).
- batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)
Return type: Tensor
global_max_pool
(x, batch, size=None)[source]¶ Returns batch-wise graph-level-outputs by taking the channel-wise maximum across the node dimension, so that for a single graph \(\mathcal{G}_i\) its output is computed by
\[\mathbf{r}_i = \mathrm{max}_{n=1}^{N_i} \, \mathbf{x}_n\]Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).
- batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- size (int, optional) – Batch-size \(B\). Automatically calculated if not given. (default: None)
Return type: Tensor
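A combined usage sketch for the three readout functions above (illustrative only; the node counts and batch assignment are made up for demonstration):
import torch
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

x = torch.randn(6, 16)                    # 6 nodes across two graphs
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # nodes 0-2 belong to graph 0, nodes 3-5 to graph 1
out_add = global_add_pool(x, batch)       # shape [2, 16]
out_mean = global_mean_pool(x, batch)     # shape [2, 16]
out_max = global_max_pool(x, batch)       # shape [2, 16]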
global_sort_pool
(x, batch, k)[source]¶ The global pooling operator from the “An End-to-End Deep Learning Architecture for Graph Classification” paper, where node features are first sorted individually and then sorted in descending order based on their last features. The first \(k\) nodes form the output of the layer.
Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
- batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- k (int) – The number of nodes to hold for each graph.
Return type: Tensor
class GlobalAttention(gate_nn, nn=None)[source]¶
Global soft attention layer from the “Gated Graph Sequence Neural Networks” paper
\[\mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left( h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \odot h_{\mathbf{\Theta}} ( \mathbf{x}_n ),\]where \(h_{\mathrm{gate}} \colon \mathbb{R}^F \to \mathbb{R}\) and \(h_{\mathbf{\Theta}}\) denote neural networks, i.e. MLPs.
Parameters:
- gate_nn (torch.nn.Module) – A neural network \(h_{\mathrm{gate}}\) that computes attention scores by mapping node features x of shape [-1, in_channels] to shape [-1, 1], e.g., defined by torch.nn.Sequential.
- nn (torch.nn.Module, optional) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels] before combining them with the attention scores, e.g., defined by torch.nn.Sequential. (default: None)
class Set2Set(in_channels, processing_steps, num_layers=1)[source]¶
The global pooling operator based on iterative content-based attention from the “Order Matters: Sequence to sequence for sets” paper
\[ \begin{align}\begin{aligned}\mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1})\\\alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)\\\mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i\\\mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t,\end{aligned}\end{align} \]where \(\mathbf{q}^{*}_T\) defines the output of the layer with twice the dimensionality as the input.
Parameters: - in_channels (int) – Size of each input sample.
- processing_steps (int) – Number of iterations \(T\).
- num_layers (int, optional) – Number of recurrent layers, e.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. (default: 1)
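A minimal usage sketch (illustrative only; tensors and sizes are made up for demonstration):
import torch
from torch_geometric.nn import Set2Set

pool = Set2Set(in_channels=16, processing_steps=3)
x = torch.randn(6, 16)
batch = torch.tensor([0, 0, 0, 1, 1, 1])
out = pool(x, batch)  # shape [2, 2 * 16]: the output has twice the input dimensionality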
Pooling Layers¶
class TopKPooling(in_channels, ratio=0.5, min_score=None, multiplier=1, nonlinearity=<built-in method tanh of type object>)[source]¶
\(\mathrm{top}_k\) pooling operator from the “Graph U-Nets”, “Towards Sparse Hierarchical Graph Classifiers” and “Understanding Attention and Generalization in Graph Neural Networks” papers
if min_score \(\tilde{\alpha}\) is None:
\[ \begin{align}\begin{aligned}\mathbf{y} &= \frac{\mathbf{X}\mathbf{p}}{\| \mathbf{p} \|}\\\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}\end{aligned}\end{align} \]if min_score \(\tilde{\alpha}\) is a value in [0, 1]:
\[ \begin{align}\begin{aligned}\mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})\\\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},\end{aligned}\end{align} \]where nodes are dropped based on a learnable projection score \(\mathbf{p}\).
Parameters: - in_channels (int) – Size of each input sample.
- ratio (float) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\). This value is ignored if min_score is not None. (default: 0.5)
- min_score (float, optional) – Minimal node score \(\tilde{\alpha}\) which is used to compute indices of pooled nodes \(\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}\). When this value is not None, the ratio argument is ignored. (default: None)
- multiplier (float, optional) – Coefficient by which features get multiplied after pooling. This can be useful for large graphs and when min_score is used. (default: 1)
- nonlinearity (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.tanh)
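A usage sketch (illustrative only; the six-value return signature shown is an assumption for this release and should be checked against the installed version):
import torch
from torch_geometric.nn import TopKPooling

pool = TopKPooling(in_channels=16, ratio=0.5)
x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))
batch = torch.zeros(10, dtype=torch.long)
# Assumed return values: pooled features, edges, edge attributes, batch vector,
# the permutation of kept nodes and their scores.
x, edge_index, edge_attr, batch, perm, score = pool(x, edge_index, batch=batch)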
class SAGPooling(in_channels, ratio=0.5, GNN=<class 'torch_geometric.nn.conv.graph_conv.GraphConv'>, min_score=None, multiplier=1, nonlinearity=<built-in method tanh of type object>, **kwargs)[source]¶
The self-attention pooling operator from the “Self-Attention Graph Pooling” and “Understanding Attention and Generalization in Graph Neural Networks” papers
if min_score \(\tilde{\alpha}\) is None:
\[ \begin{align}\begin{aligned}\mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})\\\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}\end{aligned}\end{align} \]
if min_score \(\tilde{\alpha}\) is a value in [0, 1]:
\[ \begin{align}\begin{aligned}\mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A}))\\\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}\\\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},\end{aligned}\end{align} \]
where nodes are dropped based on a learnable projection score \(\mathbf{p}\). Projection scores are learned based on a graph neural network layer.
Parameters: - in_channels (int) – Size of each input sample.
- ratio (float) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\). This value is ignored if min_score is not None. (default: 0.5)
- GNN (torch.nn.Module, optional) – A graph neural network layer for calculating projection scores (one of torch_geometric.nn.conv.GraphConv, torch_geometric.nn.conv.GCNConv, torch_geometric.nn.conv.GATConv or torch_geometric.nn.conv.SAGEConv). (default: torch_geometric.nn.conv.GraphConv)
- min_score (float, optional) – Minimal node score \(\tilde{\alpha}\) which is used to compute indices of pooled nodes \(\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}\). When this value is not None, the ratio argument is ignored. (default: None)
- multiplier (float, optional) – Coefficient by which features get multiplied after pooling. This can be useful for large graphs and when min_score is used. (default: 1)
- nonlinearity (torch.nn.functional, optional) – The nonlinearity to use. (default: torch.tanh)
- **kwargs (optional) – Additional parameters for initializing the graph neural network layer.
class EdgePooling(in_channels, edge_score_method=None, dropout=0, add_to_edge_score=0.5)[source]¶
The edge pooling operator from the “Towards Graph Pooling by Edge Contraction” and “Edge Contraction Pooling for Graph Neural Networks” papers.
In short, a score is computed for each edge. Edges are contracted iteratively according to that score unless one of their nodes has already been part of a contracted edge.
To duplicate the configuration from the “Towards Graph Pooling by Edge Contraction” paper, use either EdgePooling.compute_edge_score_softmax() or EdgePooling.compute_edge_score_tanh(), and set add_to_edge_score to 0.
To duplicate the configuration from the “Edge Contraction Pooling for Graph Neural Networks” paper, set dropout to 0.2.
Parameters:
- in_channels (int) – Size of each input sample.
- edge_score_method (function, optional) – The function to apply to compute the edge score from raw edge scores. By default, this is the softmax over all incoming edges for each node. This function takes in a raw_edge_score tensor of shape [num_nodes], an edge_index tensor and the number of nodes num_nodes, and produces a new tensor of the same size as raw_edge_score describing normalized edge scores. Included functions are EdgePooling.compute_edge_score_softmax(), EdgePooling.compute_edge_score_tanh(), and EdgePooling.compute_edge_score_sigmoid(). (default: EdgePooling.compute_edge_score_softmax())
- dropout (float, optional) – The probability with which to drop edge scores during training. (default: 0)
- add_to_edge_score (float, optional) – This is added to each computed edge score. Adding this greatly helps with unpool stability. (default: 0.5)
forward
(x, edge_index, batch)[source]¶ Forward computation which computes the raw edge score, normalizes it, and merges the edges.
Parameters: - x (Tensor) – The node features.
- edge_index (LongTensor) – The edge indices.
- batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- Return types:
- x (Tensor) - The pooled node features.
- edge_index (LongTensor) - The coarsened edge indices.
- batch (LongTensor) - The coarsened batch vector.
- unpool_info (unpool_description) - Information that is consumed by EdgePooling.unpool() for unpooling.
unpool(x, unpool_info)[source]¶ Unpools a previous edge pooling step.
For unpooling, x should be of the same shape as those produced by this layer's forward() function. Then, it will produce an unpooled x in addition to edge_index and batch.
Parameters:
- x (Tensor) – The node features.
- unpool_info (unpool_description) – Information that has been produced by EdgePooling.forward().
- Return types:
- x (Tensor) - The unpooled node features.
- edge_index (LongTensor) - The new edge indices.
- batch (LongTensor) - The new batch vector.
unpool_description¶
alias of UnpoolDescription
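A usage sketch tying forward() and unpool() together, following the return values documented above (tensors and sizes are illustrative only):
import torch
from torch_geometric.nn import EdgePooling

pool = EdgePooling(in_channels=16)
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
batch = torch.zeros(4, dtype=torch.long)

# Contract edges; unpool_info can later be used to undo the pooling step.
x_p, edge_index_p, batch_p, unpool_info = pool(x, edge_index, batch)
x_u, edge_index_u, batch_u = pool.unpool(x_p, unpool_info)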
max_pool(cluster, data, transform=None)[source]¶
Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster
. All nodes within the same cluster will be represented as one node. Final node features are defined by the maximum features of all nodes within the same cluster, node positions are averaged and edge indices are defined to be the union of the edge indices of all nodes within the same cluster.Parameters: - cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.
- data (Data) – Graph data object.
- transform (callable, optional) – A function/transform that takes in the coarsened and pooled torch_geometric.data.Data object and returns a transformed version. (default: None)
Return type: torch_geometric.data.Data
avg_pool(cluster, data, transform=None)[source]¶
Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster. Final node features are defined by the average features of all nodes within the same cluster. See torch_geometric.nn.pool.max_pool() for more details.
Parameters:
- cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.
- data (Data) – Graph data object.
- transform (callable, optional) – A function/transform that takes in the
coarsened and pooled
torch_geometric.data.Data
object and returns a transformed version. (default:None
)
Return type:
-
max_pool_x
(cluster, x, batch, size=None)[source]¶ Max-Pools node features according to the clustering defined in
cluster
.Parameters: - cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.
- x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).
- batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- size (int, optional) – The maximum number of clusters in a single
example. This property is useful to obtain a batch-wise dense
representation, e.g. for applying FC layers, but should only be
used if the size of the maximum number of clusters per example is
known in advance. (default:
None
)
Return type: (
Tensor
,LongTensor
) ifsize
isNone
, elseTensor
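A minimal sketch; with size=None the call returns the pooled features together with the coarsened batch vector.
>>> import torch
>>> from torch_geometric.nn import max_pool_x
>>> cluster = torch.tensor([0, 0, 1, 1])
>>> x = torch.randn(4, 8)
>>> batch = torch.zeros(4, dtype=torch.long)
>>> x_pooled, batch_pooled = max_pool_x(cluster, x, batch)   # shapes [2, 8] and [2]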
-
avg_pool_x
(cluster, x, batch, size=None)[source]¶ Average pools node features according to the clustering defined in
cluster
. Seetorch_geometric.nn.pool.max_pool_x()
for more details.Parameters: - cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.
- x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).
- batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- size (int, optional) – The maximum number of clusters in a single
example. (default:
None
)
Return type: (
Tensor
,LongTensor
) ifsize
isNone
, elseTensor
-
graclus
(edge_index, weight=None, num_nodes=None)[source]¶ A greedy clustering algorithm from the “Weighted Graph Cuts without Eigenvectors: A Multilevel Approach” paper, which iteratively picks an unmarked vertex and matches it with the unmarked neighbor that maximizes the edge weight. The GPU algorithm is adapted from the “A GPU Algorithm for Greedy Graph Matching” paper.
Parameters: Return type: LongTensor
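A short sketch: graclus() returns one cluster id per node, which can be passed directly to max_pool(), avg_pool() or max_pool_x() above.
>>> import torch
>>> from torch_geometric.nn import graclus
>>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
...                            [1, 0, 2, 1, 3, 2]])
>>> cluster = graclus(edge_index, num_nodes=4)   # LongTensor of cluster ids, one per node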
-
voxel_grid
(pos, batch, size, start=None, end=None)[source]¶ Voxel grid pooling from, e.g., the “Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs” paper, which overlays a regular grid of user-defined size over a point cloud and clusters all points within the same voxel.
Parameters: - pos (Tensor) – Node position matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}\).
- batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- size (float or [float] or Tensor) – Size of a voxel (in each dimension).
- start (float or [float] or Tensor, optional) – Start coordinates of the
grid (in each dimension). If set to
None
, will be set to the minimum coordinates found inpos
. (default:None
) - end (float or [float] or Tensor, optional) – End coordinates of the grid
(in each dimension). If set to
None
, will be set to the maximum coordinates found inpos
. (default:None
)
Return type: LongTensor
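A short sketch with 2D points: all points falling into the same 1×1 voxel receive the same cluster id.
>>> import torch
>>> from torch_geometric.nn import voxel_grid
>>> pos = torch.Tensor([[0.1, 0.1], [0.4, 0.6], [2.3, 2.1]])
>>> batch = torch.tensor([0, 0, 0])
>>> cluster = voxel_grid(pos, batch, size=1.0)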
-
fps
(x, batch=None, ratio=0.5, random_start=True)[source]¶ A sampling algorithm from the “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” paper, which iteratively samples the point that is most distant from the points sampled so far.
Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
- batch (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each
node to a specific example. (default:
None
) - ratio (float, optional) – Sampling ratio. (default:
0.5
) - random_start (bool, optional) – If set to
False
, use the first node in \(\mathbf{X}\) as the starting node. (default: True)
Return type: LongTensor
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
>>> batch = torch.tensor([0, 0, 0, 0])
>>> index = fps(x, batch, ratio=0.5)
-
knn
(x, y, k, batch_x=None, batch_y=None, cosine=False)[source]¶ Finds for each element in
y
thek
nearest points inx
.Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
- y (Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).
- k (int) – The number of neighbors.
- batch_x (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each
node to a specific example. (default:
None
) - batch_y (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each
node to a specific example. (default:
None
) - cosine (boolean, optional) – If
True
, will use the cosine distance instead of the Euclidean distance to find nearest neighbors. (default:False
)
Return type: LongTensor
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
>>> batch_x = torch.tensor([0, 0, 0, 0])
>>> y = torch.Tensor([[-1, 0], [1, 0]])
>>> batch_y = torch.tensor([0, 0])
>>> assign_index = knn(x, y, 2, batch_x, batch_y)
-
knn_graph
(x, k, batch=None, loop=False, flow='source_to_target', cosine=False)[source]¶ Computes graph edges to the nearest
k
points.Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
- k (int) – The number of neighbors.
- batch (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each
node to a specific example. (default:
None
) - loop (bool, optional) – If
True
, the graph will contain self-loops. (default:False
) - flow (string, optional) – The flow direction when using in combination
with message passing (
"source_to_target"
or"target_to_source"
). (default:"source_to_target"
) - cosine (boolean, optional) – If
True
, will use the cosine distance instead of the Euclidean distance to find nearest neighbors. (default:False
)
Return type: LongTensor
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
>>> batch = torch.tensor([0, 0, 0, 0])
>>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
-
radius
(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32)[source]¶ Finds for each element in
y
all points inx
within distancer
.Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
- y (Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).
- r (float) – The radius.
- batch_x (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each
node to a specific example. (default:
None
) - batch_y (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each
node to a specific example. (default:
None
) - max_num_neighbors (int, optional) – The maximum number of neighbors to
return for each element in
y
. (default:32
)
Return type: LongTensor
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
>>> batch_x = torch.tensor([0, 0, 0, 0])
>>> y = torch.Tensor([[-1, 0], [1, 0]])
>>> batch_y = torch.tensor([0, 0])
>>> assign_index = radius(x, y, 1.5, batch_x, batch_y)
-
radius_graph
(x, r, batch=None, loop=False, max_num_neighbors=32, flow='source_to_target')[source]¶ Computes graph edges to all points within a given distance.
Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
- r (float) – The radius.
- batch (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each
node to a specific example. (default:
None
) - loop (bool, optional) – If
True
, the graph will contain self-loops. (default:False
) - max_num_neighbors (int, optional) – The maximum number of neighbors to
return for each element in
y
. (default:32
) - flow (string, optional) – The flow direction when using in combination
with message passing (
"source_to_target"
or"target_to_source"
). (default:"source_to_target"
)
Return type: LongTensor
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
>>> batch = torch.tensor([0, 0, 0, 0])
>>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
-
nearest
(x, y, batch_x=None, batch_y=None)[source]¶ Clusters points in
x
together which are nearest to a given query point iny
.Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
- y (Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).
- batch_x (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each
node to a specific example. (default:
None
) - batch_y (LongTensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each
node to a specific example. (default:
None
)
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
>>> batch_x = torch.tensor([0, 0, 0, 0])
>>> y = torch.Tensor([[-1, 0], [1, 0]])
>>> batch_y = torch.tensor([0, 0])
>>> cluster = nearest(x, y, batch_x, batch_y)
Dense Pooling Layers¶
-
dense_diff_pool
(x, adj, s, mask=None)[source]¶ Differentiable pooling operator from the “Hierarchical Graph Representation Learning with Differentiable Pooling” paper
\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns pooled node feature matrix, coarsened adjacency matrix and two auxiliary objectives: (1) The link prediction loss
\[\mathcal{L}_{LP} = {\| \mathbf{A} - \mathrm{softmax}(\mathbf{S}) {\mathrm{softmax}(\mathbf{S})}^{\top} \|}_F,\]and (2) the entropy regularization
\[\mathcal{L}_E = \frac{1}{N} \sum_{n=1}^N H(\mathbf{S}_n).\]Parameters: - x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).
- adj (Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).
- s (Tensor) – Assignment tensor \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) with number of clusters \(C\). The softmax does not have to be applied beforehand, since it is executed within this method.
- mask (BoolTensor, optional) – Mask matrix
\(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating
the valid nodes for each graph. (default:
None
)
Return type: (
Tensor
,Tensor
,Tensor
,Tensor
)
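A minimal sketch with random tensors; note that the assignment tensor s is passed in raw form, since the softmax is applied internally.
>>> import torch
>>> from torch_geometric.nn import dense_diff_pool
>>> x = torch.randn(2, 10, 16)    # B=2 graphs, N=10 nodes, F=16 features
>>> adj = torch.rand(2, 10, 10)   # dense adjacency per graph
>>> s = torch.randn(2, 10, 4)     # raw assignments to C=4 clusters
>>> x_p, adj_p, link_loss, ent_loss = dense_diff_pool(x, adj, s)
>>> # x_p: [2, 4, 16], adj_p: [2, 4, 4]; both losses are scalars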
-
dense_mincut_pool
(x, adj, s, mask=None)[source]¶ MinCUT pooling operator from the “Mincut Pooling in Graph Neural Networks” paper
\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{X}\\\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})\end{aligned}\end{align} \]based on dense learned assignments \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\). Returns pooled node feature matrix, coarsened symmetrically normalized adjacency matrix and two auxiliary objectives: (1) The minCUT loss
\[\mathcal{L}_c = - \frac{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{A} \mathbf{S})} {\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{D} \mathbf{S})}\]where \(\mathbf{D}\) is the degree matrix, and (2) the orthogonality loss
\[\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}} {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}} \right\|}_F.\]Parameters: - x (Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\) with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).
- adj (Tensor) – Symmetrically normalized adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\).
- s (Tensor) – Assignment tensor \(\mathbf{S} \in \mathbb{R}^{B \times N \times C}\) with number of clusters \(C\). The softmax does not have to be applied beforehand, since it is executed within this method.
- mask (BoolTensor, optional) – Mask matrix
\(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating
the valid nodes for each graph. (default:
None
)
Return type: (
Tensor
,Tensor
,Tensor
,Tensor
)
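A minimal sketch; since the operator expects a symmetrically normalized adjacency, a random dense adjacency is first normalized as \(\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\).
>>> import torch
>>> from torch_geometric.nn import dense_mincut_pool
>>> adj = torch.rand(2, 10, 10)
>>> adj = (adj + adj.transpose(1, 2)) / 2                            # symmetrize
>>> deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1e-12).pow(-0.5)
>>> adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
>>> x = torch.randn(2, 10, 16)
>>> s = torch.randn(2, 10, 4)
>>> x_p, adj_p, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s)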
Unpooling Layers¶
-
knn_interpolate
(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3)[source]¶ The k-NN interpolation from the “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” paper. For each point \(y\) with position \(\mathbf{p}(y)\), its interpolated features \(\mathbf{f}(y)\) are given by
\[\mathbf{f}(y) = \frac{\sum_{i=1}^k w(x_i) \mathbf{f}(x_i)}{\sum_{i=1}^k w(x_i)} \textrm{, where } w(x_i) = \frac{1}{d(\mathbf{p}(y), \mathbf{p}(x_i))^2}\]and \(\{ x_1, \ldots, x_k \}\) denoting the \(k\) nearest points to \(y\).
Parameters: - x (Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
- pos_x (Tensor) – Node position matrix \(\in \mathbb{R}^{N \times d}\).
- pos_y (Tensor) – Upsampled node position matrix \(\in \mathbb{R}^{M \times d}\).
- batch_x (LongTensor, optional) – Batch vector
\(\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N\), which assigns
each node from \(\mathbf{X}\) to a specific example.
(default:
None
) - batch_y (LongTensor, optional) – Batch vector
\(\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^M\), which assigns
each node from \(\mathbf{Y}\) to a specific example.
(default:
None
) - k (int, optional) – Number of neighbors. (default:
3
)
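A minimal sketch: features known at 4 support points are interpolated onto 10 query positions.
>>> import torch
>>> from torch_geometric.nn import knn_interpolate
>>> x = torch.randn(4, 8)        # features at the known points
>>> pos_x = torch.rand(4, 3)     # positions of the known points
>>> pos_y = torch.rand(10, 3)    # positions to interpolate onto
>>> out = knn_interpolate(x, pos_x, pos_y, k=3)   # shape [10, 8]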
Models¶
-
class
JumpingKnowledge
(mode, channels=None, num_layers=None)[source]¶ The Jumping Knowledge layer aggregation module from the “Representation Learning on Graphs with Jumping Knowledge Networks” paper based on either concatenation (
"cat"
)\[\mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)}\]max pooling (
"max"
)\[\max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right)\]or weighted summation
\[\sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)}\]with attention scores \(\alpha_v^{(t)}\) obtained from a bi-directional LSTM (
"lstm"
).Parameters: - mode (string) – The aggregation scheme to use
(
"cat"
,"max"
or"lstm"
). - channels (int, optional) – The number of channels per representation.
Needs to be only set for LSTM-style aggregation.
(default:
None
) - num_layers (int, optional) – The number of layers to aggregate. Needs to
be only set for LSTM-style aggregation. (default:
None
)
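A minimal sketch, assuming the module is called on the list of per-layer node representations collected during a GNN forward pass.
>>> import torch
>>> from torch_geometric.nn import JumpingKnowledge
>>> xs = [torch.randn(100, 64) for _ in range(3)]              # three layer outputs
>>> jk = JumpingKnowledge(mode='cat')
>>> out = jk(xs)                                               # shape [100, 3 * 64]
>>> jk = JumpingKnowledge(mode='lstm', channels=64, num_layers=3)
>>> out = jk(xs)                                               # shape [100, 64]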
-
class
Node2Vec
(num_nodes, embedding_dim, walk_length, context_size, walks_per_node=1, p=1, q=1, num_negative_samples=None)[source]¶ The Node2Vec model from the “node2vec: Scalable Feature Learning for Networks” paper where random walks of length
walk_length
are sampled in a given graph, and node embeddings are learned via negative sampling optimization.Parameters: - num_nodes (int) – The number of nodes.
- embedding_dim (int) – The size of each embedding vector.
- walk_length (int) – The walk length.
- context_size (int) – The actual context size which is considered for positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes.
- walks_per_node (int, optional) – The number of walks to sample for each
node. (default:
1
) - p (float, optional) – Likelihood of immediately revisiting a node in the
walk. (default:
1
) - q (float, optional) – Control parameter to interpolate between
breadth-first strategy and depth-first strategy (default:
1
) - num_negative_samples (int, optional) – The number of negative samples to
use for each node. If set to
None
, this parameter gets set tocontext_size - 1
. (default:None
)
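A minimal construction sketch (the training loop, which alternates sampled random walks with negative sampling, depends on the library version and is therefore omitted here).
>>> from torch_geometric.nn import Node2Vec
>>> model = Node2Vec(num_nodes=1000, embedding_dim=128, walk_length=20,
...                  context_size=10, walks_per_node=10, p=1, q=1)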
-
class
DeepGraphInfomax
(hidden_channels, encoder, summary, corruption)[source]¶ The Deep Graph Infomax model from the “Deep Graph Infomax” paper based on user-defined encoder and summary model \(\mathcal{E}\) and \(\mathcal{R}\) respectively, and a corruption function \(\mathcal{C}\).
Parameters: - hidden_channels (int) – The latent space dimensionality.
- encoder (Module) – The encoder module \(\mathcal{E}\).
- summary (callable) – The readout function \(\mathcal{R}\).
- corruption (callable) – The corruption function \(\mathcal{C}\).
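A minimal construction sketch: the encoder below is an illustrative single GCNConv layer, the summary follows the mean readout used in the paper, and the corruption permutes node features while keeping the graph structure.
>>> import torch
>>> from torch_geometric.nn import DeepGraphInfomax, GCNConv
>>> class Encoder(torch.nn.Module):
...     def __init__(self, in_channels, hidden_channels):
...         super(Encoder, self).__init__()
...         self.conv = GCNConv(in_channels, hidden_channels)
...     def forward(self, x, edge_index):
...         return torch.relu(self.conv(x, edge_index))
...
>>> model = DeepGraphInfomax(
...     hidden_channels=64,
...     encoder=Encoder(16, 64),
...     summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
...     corruption=lambda x, edge_index: (x[torch.randperm(x.size(0))], edge_index))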
-
discriminate
(z, summary, sigmoid=True)[source]¶ Given the patch-summary pair
z
andsummary
, computes the probability scores assigned to this patch-summary pair.Parameters:
-
class
InnerProductDecoder
[source]¶ The inner product decoder from the “Variational Graph Auto-Encoders” paper
\[\sigma(\mathbf{Z}\mathbf{Z}^{\top})\]where \(\mathbf{Z} \in \mathbb{R}^{N \times d}\) denotes the latent space produced by the encoder.
-
class
GAE
(encoder, decoder=None)[source]¶ The Graph Auto-Encoder model from the “Variational Graph Auto-Encoders” paper based on user-defined encoder and decoder models.
Parameters: - encoder (Module) – The encoder module.
- decoder (Module, optional) – The decoder module. If set to
None
, will default to thetorch_geometric.nn.models.InnerProductDecoder
. (default:None
)
-
recon_loss
(z, pos_edge_index)[source]¶ Given latent variables
z
, computes the binary cross entropy loss for positive edgespos_edge_index
and negative sampled edges.Parameters: - z (Tensor) – The latent space \(\mathbf{Z}\).
- pos_edge_index (LongTensor) – The positive edges to train against.
-
split_edges
(data, val_ratio=0.05, test_ratio=0.1)[source]¶ Splits the edges of a
torch_geometric.data.Data
object into positive and negative train/val/test edges.Parameters:
-
test
(z, pos_edge_index, neg_edge_index)[source]¶ Given latent variables
z
, positive edgespos_edge_index
and negative edgesneg_edge_index
, computes area under the ROC curve (AUC) and average precision (AP) scores.Parameters: - z (Tensor) – The latent space \(\mathbf{Z}\).
- pos_edge_index (LongTensor) – The positive edges to evaluate against.
- neg_edge_index (LongTensor) – The negative edges to evaluate against.
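A minimal sketch with a toy single-layer GCN encoder (illustrative only); the latent variables z are computed by the encoder and scored against the positive edges via recon_loss().
>>> import torch
>>> from torch_geometric.nn import GAE, GCNConv
>>> class Encoder(torch.nn.Module):
...     def __init__(self, in_channels, out_channels):
...         super(Encoder, self).__init__()
...         self.conv = GCNConv(in_channels, out_channels)
...     def forward(self, x, edge_index):
...         return self.conv(x, edge_index)
...
>>> encoder = Encoder(16, 32)
>>> model = GAE(encoder)
>>> x = torch.randn(6, 16)
>>> pos_edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
>>> z = encoder(x, pos_edge_index)
>>> loss = model.recon_loss(z, pos_edge_index)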
-
class
VGAE
(encoder, decoder=None)[source]¶ The Variational Graph Auto-Encoder model from the “Variational Graph Auto-Encoders” paper.
Parameters: - encoder (Module) – The encoder module to compute \(\mu\) and \(\log\sigma^2\).
- decoder (Module, optional) – The decoder module. If set to
None
, will default to thetorch_geometric.nn.models.InnerProductDecoder
. (default:None
)
-
class
ARGA
(encoder, discriminator, decoder=None)[source]¶ The Adversarially Regularized Graph Auto-Encoder model from the “Adversarially Regularized Graph Autoencoder for Graph Embedding” paper.
Parameters: - encoder (Module) – The encoder module.
- discriminator (Module) – The discriminator module.
- decoder (Module, optional) – The decoder module. If set to
None
, will default to thetorch_geometric.nn.models.InnerProductDecoder
. (default:None
)
-
discriminator_loss
(z)[source]¶ Computes the loss of the discriminator.
Parameters: z (Tensor) – The latent space \(\mathbf{Z}\).
-
class
ARGVA
(encoder, discriminator, decoder=None)[source]¶ The Adversarially Regularized Variational Graph Auto-Encoder model from the “Adversarially Regularized Graph Autoencoder for Graph Embedding” paper.
Parameters: - encoder (Module) – The encoder module to compute \(\mu\) and \(\log\sigma^2\).
- discriminator (Module) – The discriminator module.
- decoder (Module, optional) – The decoder module. If set to
None
, will default to thetorch_geometric.nn.models.InnerProductDecoder
. (default:None
)
-
class
SignedGCN
(in_channels, hidden_channels, num_layers, lamb=5, bias=True)[source]¶ The signed graph convolutional network model from the “Signed Graph Convolutional Network” paper. Internally, this module uses the
torch_geometric.nn.conv.SignedConv
operator.Parameters: - in_channels (int) – Size of each input sample.
- hidden_channels (int) – Size of each hidden sample.
- num_layers (int) – Number of layers.
- lamb (float, optional) – Balances the contributions of the overall
objective. (default:
5
) - bias (bool, optional) – If set to
False
, all layers will not learn an additive bias. (default:True
)
-
create_spectral_features
(pos_edge_index, neg_edge_index, num_nodes=None)[source]¶ Creates
in_channels
spectral node features based on positive and negative edges.Parameters:
-
discriminate
(z, edge_index)[source]¶ Given node embeddings
z
, classifies the link relation between node pairsedge_index
to be either positive, negative or non-existent.Parameters: - x (Tensor) – The input node features.
- edge_index (LongTensor) – The edge indices.
-
forward
(x, pos_edge_index, neg_edge_index)[source]¶ Computes node embeddings
z
based on positive edgespos_edge_index
and negative edgesneg_edge_index
.Parameters: - x (Tensor) – The input node features.
- pos_edge_index (LongTensor) – The positive edge indices.
- neg_edge_index (LongTensor) – The negative edge indices.
-
loss
(z, pos_edge_index, neg_edge_index)[source]¶ Computes the overall objective.
Parameters: - z (Tensor) – The node embeddings.
- pos_edge_index (LongTensor) – The positive edge indices.
- neg_edge_index (LongTensor) – The negative edge indices.
-
neg_embedding_loss
(z, neg_edge_index)[source]¶ Computes the triplet loss between negative node pairs and sampled pairs of unconnected nodes.
Parameters: - z (Tensor) – The node embeddings.
- neg_edge_index (LongTensor) – The negative edge indices.
-
nll_loss
(z, pos_edge_index, neg_edge_index)[source]¶ Computes the discriminator loss based on node embeddings
z
, and positive edgespos_edge_index
and negative edgesneg_edge_index
.Parameters: - z (Tensor) – The node embeddings.
- pos_edge_index (LongTensor) – The positive edge indices.
- neg_edge_index (LongTensor) – The negative edge indices.
-
pos_embedding_loss
(z, pos_edge_index)[source]¶ Computes the triplet loss between positive node pairs and sampled pairs of unconnected nodes.
Parameters: - z (Tensor) – The node embeddings.
- pos_edge_index (LongTensor) – The positive edge indices.
-
split_edges
(edge_index, test_ratio=0.2)[source]¶ Splits the edges
edge_index
into train and test edges.Parameters: - edge_index (LongTensor) – The edge indices.
- test_ratio (float, optional) – The ratio of test edges.
(default:
0.2
)
-
test
(z, pos_edge_index, neg_edge_index)[source]¶ Evaluates node embeddings
z
on positive and negative test edges by computing AUC and F1 scores.Parameters: - z (Tensor) – The node embeddings.
- pos_edge_index (LongTensor) – The positive edge indices.
- neg_edge_index (LongTensor) – The negative edge indices.
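A minimal end-to-end sketch with random node features and a toy signed graph; spectral input features can alternatively be obtained via create_spectral_features() when only the signed edges are available.
>>> import torch
>>> from torch_geometric.nn import SignedGCN
>>> model = SignedGCN(in_channels=16, hidden_channels=32, num_layers=2)
>>> x = torch.randn(10, 16)
>>> pos_edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
>>> neg_edge_index = torch.tensor([[5, 6, 7], [6, 7, 8]])
>>> z = model(x, pos_edge_index, neg_edge_index)
>>> loss = model.loss(z, pos_edge_index, neg_edge_index)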
-
class
RENet
(num_nodes, num_rels, hidden_channels, seq_len, num_layers=1, dropout=0.0, bias=True)[source]¶ The Recurrent Event Network model from the “Recurrent Event Network for Reasoning over Temporal Knowledge Graphs” paper
\[f_{\mathbf{\Theta}}(\mathbf{e}_s, \mathbf{e}_r, \mathbf{h}^{(t-1)}(s, r))\]based on a RNN encoder
\[\mathbf{h}^{(t)}(s, r) = \textrm{RNN}(\mathbf{e}_s, \mathbf{e}_r, g(\mathcal{O}^{(t)}_r(s)), \mathbf{h}^{(t-1)}(s, r))\]where \(\mathbf{e}_s\) and \(\mathbf{e}_r\) denote entity and relation embeddings, and \(\mathcal{O}^{(t)}_r(s)\) represents the set of objects that interacted with subject \(s\) under relation \(r\) at timestamp \(t\). This model implements \(g\) as the Mean Aggregator and \(f_{\mathbf{\Theta}}\) as a linear projection.
Parameters: - num_nodes (int) – The number of nodes in the knowledge graph.
- num_rels (int) – The number of relations in the knowledge graph.
- hidden_channels (int) – Hidden size of node and relation embeddings.
- seq_len (int) – The sequence length of past events.
- num_layers (int, optional) – The number of recurrent layers.
(default:
1
) - dropout (float) – If non-zero, introduces a dropout layer before the
final prediction. (default:
0.
) - bias (bool, optional) – If set to
False
, all layers will not learn an additive bias. (default:True
)
-
forward
(data)[source]¶ Given a
data
batch, computes the forward pass.Parameters: data (torch_geometric.data.Data) – The input data, holding subject sub
, relationrel
and objectobj
information with shape[batch_size]
. In addition,data
needs to hold history information for subjects, given by a vector of node indicesh_sub
and their relative timestampsh_sub_t
and batch assignmentsh_sub_batch
. The same information must be given for objects (h_obj
,h_obj_t
,h_obj_batch
).
-
class
GraphUNet
(in_channels, hidden_channels, out_channels, depth, pool_ratios=0.5, sum_res=True, act=<function relu>)[source]¶ The Graph U-Net model from the “Graph U-Nets” paper which implements a U-Net like architecture with graph pooling and unpooling operations.
Parameters: - in_channels (int) – Size of each input sample.
- hidden_channels (int) – Size of each hidden sample.
- out_channels (int) – Size of each output sample.
- depth (int) – The depth of the U-Net architecture.
- pool_ratios (float or [float], optional) – Graph pooling ratio for each
depth. (default:
0.5
) - sum_res (bool, optional) – If set to
False
, will use concatenation for integration of skip connections instead of summation. (default:True
) - act (torch.nn.functional, optional) – The nonlinearity to use.
(default:
torch.nn.functional.relu
)
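A minimal sketch; the forward pass is assumed to take node features and edge indices (and, optionally, a batch vector).
>>> import torch
>>> from torch_geometric.nn import GraphUNet
>>> model = GraphUNet(in_channels=16, hidden_channels=32, out_channels=7, depth=3)
>>> x = torch.randn(50, 16)
>>> edge_index = torch.randint(0, 50, (2, 200))
>>> out = model(x, edge_index)   # node-level output with out_channels features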
DataParallel Layers¶
-
class
DataParallel
(module, device_ids=None, output_device=None)[source]¶ Implements data parallelism at the module level.
This container parallelizes the application of the given
module
by splitting a list oftorch_geometric.data.Data
objects and copying them astorch_geometric.data.Batch
objects to each device. In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backward pass, gradients from each replica are summed into the original module. The batch size should be larger than the number of GPUs used.
The parallelized
module
must have its parameters and buffers ondevice_ids[0]
.Note
You need to use the
torch_geometric.data.DataListLoader
for this module.Parameters: - module (Module) – Module to be parallelized.
- device_ids (list of int or torch.device) – CUDA devices. (default: all devices)
- output_device (int or torch.device) – Device location of output.
(default:
device_ids[0]
)
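A minimal sketch, with dataset and GNN standing in as placeholders for any PyTorch Geometric dataset and user-defined model: the DataListLoader yields plain Python lists of Data objects, which the wrapper scatters across the available GPUs.
>>> from torch_geometric.data import DataListLoader
>>> from torch_geometric.nn import DataParallel
>>> loader = DataListLoader(dataset, batch_size=32, shuffle=True)  # `dataset`: any PyG dataset (placeholder)
>>> model = DataParallel(GNN()).to('cuda:0')  # `GNN`: user-defined module; parameters must live on device_ids[0]
>>> for data_list in loader:                  # data_list is a plain Python list of Data objects
...     out = model(data_list)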