torch_geometric.nn.conv.GravNetConv
- class GravNetConv(in_channels: int, out_channels: int, space_dimensions: int, propagate_dimensions: int, k: int, num_workers: Optional[int] = None, **kwargs)
Bases: MessagePassing
The GravNet operator from the “Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks” paper, where the graph is dynamically constructed using nearest neighbors. The neighbors are constructed in a learnable low-dimensional projection of the feature space. A second projection of the input feature space is then propagated from the neighbors to each vertex using distance weights that are derived by applying a Gaussian function to the distances.
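The steps described above can be sketched in plain Python to make the arithmetic visible. This is a minimal illustration, not the library's implementation: the projection matrices `W_s` and `W_h`, the helper names, and the `scale` factor inside the Gaussian are hypothetical, and the mean-plus-max aggregation is an assumption based on the paper's description. The final output transformation is omitted.

```python
import math


def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def gravnet_aggregate(X, W_s, W_h, k, scale=1.0):
    """Illustrative GravNet-style aggregation (hypothetical helper).

    X is a list of node feature vectors. W_s and W_h stand in for the two
    learnable projections. `scale` is an assumed constant in the Gaussian.
    Returns, per node, the weighted mean and max of its neighbors' features.
    """
    S = [matvec(W_s, x) for x in X]  # learned coordinates (space_dimensions)
    H = [matvec(W_h, x) for x in X]  # features to propagate (propagate_dimensions)
    out = []
    for i in range(len(X)):
        # Squared Euclidean distance from node i to every node in coordinate space.
        d2 = [sum((a - b) ** 2 for a, b in zip(S[i], S[j])) for j in range(len(X))]
        # Indices of the k nearest nodes; node i itself is included (distance 0).
        nbrs = sorted(range(len(X)), key=lambda j: d2[j])[:k]
        # Scale each neighbor's features by a Gaussian of its distance.
        weighted = [[math.exp(-scale * d2[j]) * h for h in H[j]] for j in nbrs]
        mean = [sum(col) / len(nbrs) for col in zip(*weighted)]
        largest = [max(col) for col in zip(*weighted)]
        out.append(mean + largest)  # concatenate the two aggregates
    return out


# Three nodes on a line, one input feature each, identity-like projections.
X = [[0.0], [1.0], [3.0]]
out = gravnet_aggregate(X, W_s=[[1.0]], W_h=[[2.0]], k=2)
print(out)
```

Because the neighbor search runs in the projected coordinates `S` rather than in the raw features, changing `W_s` changes which nodes count as neighbors, which is what makes the graph construction learnable.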
- Parameters:
in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.
out_channels (int) – The number of output channels.
space_dimensions (int) – The dimensionality of the space used to construct the neighbors; referred to as \(S\) in the paper.
propagate_dimensions (int) – The number of features to be propagated between the vertices; referred to as \(F_{\textrm{LR}}\) in the paper.
k (int) – The number of nearest neighbors.
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V}_s|, F_{in}), (|\mathcal{V}_t|, F_{in}))\) if bipartite, batch vector \((|\mathcal{V}|)\) or \(((|\mathcal{V}_s|), (|\mathcal{V}_t|))\) if bipartite (optional)
output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite