torch_geometric.nn.conv.WLConvContinuous
- class WLConvContinuous(**kwargs)
Bases: MessagePassing
The Weisfeiler-Lehman operator from the “Wasserstein Weisfeiler-Lehman Graph Kernels” paper. Refinement is done through a degree-scaled mean aggregation, which works on nodes with continuous attributes:
\[\mathbf{x}^{\prime}_i = \frac{1}{2}\big(\mathbf{x}_i + \frac{1}{\textrm{deg}(i)} \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j \big)\]
where \(e_{j,i}\) denotes the edge weight from source node \(j\) to target node \(i\) (default: 1).
- Parameters
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
- Shapes:
input: node features \((|\mathcal{V}|, F)\) or \(((|\mathcal{V}_s|, F), (|\mathcal{V}_t|, F))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)
output: node features \((|\mathcal{V}|, F)\) or \((|\mathcal{V}_t|, F)\) if bipartite
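The update rule and the shapes above can be checked against a small framework-free sketch. It is not the library's implementation: it uses plain Python lists instead of tensors, and the function name `wl_continuous_step` and its argument layout are hypothetical. It also makes two labeled assumptions: \(\textrm{deg}(i)\) is read as the number of incoming edges of node \(i\) (with the default unit weights this equals the sum of incoming edge weights), and a node without incoming edges gets a zero neighborhood term instead of a division by zero.

```python
from typing import List, Optional, Sequence

Vector = List[float]


def wl_continuous_step(
    x_src: Sequence[Sequence[float]],
    edge_index: Sequence[Sequence[int]],
    edge_weight: Optional[Sequence[float]] = None,
    x_dst: Optional[Sequence[Sequence[float]]] = None,
) -> List[Vector]:
    """One refinement step: x'_i = 1/2 * (x_i + 1/deg(i) * sum_j e_ji * x_j).

    Hypothetical helper, not part of torch_geometric.
    x_src:       source node features, shape (|V_s|, F)
    edge_index:  [sources, targets], shape (2, |E|)
    edge_weight: per-edge weights e_{j,i}, shape (|E|,); defaults to 1
    x_dst:       target node features, shape (|V_t|, F); defaults to x_src
    """
    if x_dst is None:
        x_dst = x_src  # homogeneous (non-bipartite) graph
    src, dst = edge_index
    if edge_weight is None:
        edge_weight = [1.0] * len(src)  # default edge weight of 1

    num_targets = len(x_dst)
    num_feats = len(x_dst[0]) if num_targets else 0
    agg = [[0.0] * num_feats for _ in range(num_targets)]
    deg = [0] * num_targets

    # Accumulate e_{j,i} * x_j over every incoming edge j -> i.
    for j, i, w in zip(src, dst, edge_weight):
        deg[i] += 1  # assumption: deg(i) counts incoming edges
        for f in range(num_feats):
            agg[i][f] += w * x_src[j][f]

    out: List[Vector] = []
    for i in range(num_targets):
        # Assumption: nodes with no incoming edges get a zero neighborhood term.
        scale = 1.0 / deg[i] if deg[i] > 0 else 0.0
        out.append([0.5 * (x_dst[i][f] + scale * agg[i][f]) for f in range(num_feats)])
    return out


# Undirected path graph 0 - 1 - 2, stored as two directed edges per link.
x = [[1.0], [2.0], [4.0]]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
print(wl_continuous_step(x, edge_index))  # [[1.5], [2.25], [3.0]]
```

For node 1, the neighbors 0 and 2 average to \((1 + 4)/2 = 2.5\), and averaging that with its own feature gives \(\frac{1}{2}(2 + 2.5) = 2.25\). Passing `x_dst` separately mirrors the bipartite case, where the output has one row per target node.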
- forward(x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Tensor, edge_weight: Optional[Tensor] = None, size: Optional[Tuple[int, int]] = None) → Tensor
Runs the forward pass of the module.
- reset_parameters()
Resets all learnable parameters of the module.