torch_geometric.transforms.FeaturePropagation

class FeaturePropagation(missing_mask: Tensor, num_iterations: int = 40)

Bases: BaseTransform

The feature propagation operator from the “On the Unreasonable Effectiveness of Feature propagation in Learning on Graphs with Missing Node Features” paper (functional name: feature_propagation).

\[
\begin{aligned}
\mathbf{X}^{(0)} &= (1 - \mathbf{M}) \cdot \mathbf{X} \\
\mathbf{X}^{(\ell + 1)} &= \mathbf{X}^{(0)} + \mathbf{M} \cdot \left(\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{X}^{(\ell)}\right)
\end{aligned}
\]

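where \(\mathbf{A}\) denotes the adjacency matrix, \(\mathbf{D}\) its diagonal degree matrix, and \(\cdot\) element-wise multiplication. Missing node features are inferred from known features via propagation, while known features (where \(\mathbf{M}\) is 0) are reset to their original values after every iteration.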
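To make the update rule concrete, here is a minimal dense sketch of the two equations in plain PyTorch. It is not the PyG implementation: the function name feature_propagation_dense is hypothetical, it builds an \(N \times N\) adjacency matrix (so it only suits small graphs), and it assumes float features, a boolean missing_mask, and an undirected edge_index that lists both edge directions and has no self-loops.

```python
import torch


def feature_propagation_dense(
    x: torch.Tensor,
    edge_index: torch.Tensor,
    missing_mask: torch.Tensor,
    num_iterations: int = 40,
) -> torch.Tensor:
    """Hypothetical dense sketch of the feature propagation update."""
    num_nodes = x.size(0)

    # Dense adjacency A from a [2, E] edge_index; assumes an undirected
    # graph (both edge directions listed) without self-loops.
    adj = torch.zeros(num_nodes, num_nodes, dtype=x.dtype)
    adj[edge_index[0], edge_index[1]] = 1.0

    # Symmetric normalization D^{-1/2} A D^{-1/2}; isolated nodes get 0.
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.0
    adj_norm = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    # M as a float matrix: 1.0 where a feature is missing, 0.0 where known.
    m = missing_mask.to(x.dtype)

    # X^(0) = (1 - M) * X. torch.where is used instead of multiplying so that
    # NaN placeholders in missing entries become 0 instead of staying NaN.
    x0 = torch.where(missing_mask, torch.zeros_like(x), x)

    out = x0
    for _ in range(num_iterations):
        # X^(l+1) = X^(0) + M * (D^{-1/2} A D^{-1/2} X^(l)): known entries are
        # restored from X^(0), missing entries take the propagated values.
        out = x0 + m * (adj_norm @ out)
    return out
```

On the path graph 0-1-2 with known scalar features 1 and 3 at the endpoints and node 1 missing, this sketch fills node 1 with (1 + 3)/√2 ≈ 2.83 rather than the plain neighbor mean 2, because the symmetric normalization weights each neighbor pair by 1/√(deg_i · deg_j).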

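import torch
from torch_geometric.transforms import FeaturePropagation

# `data` is assumed to be a torch_geometric.data.Data object whose unknown
# entries in `data.x` are stored as NaN.
transform = FeaturePropagation(missing_mask=torch.isnan(data.x))
data = transform(data)

Parameters: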
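  • missing_mask (torch.Tensor) – Mask matrix \(\mathbf{M} \in \{0, 1\}^{N \times F}\) indicating missing node features, where an entry of 1 (True) marks a missing feature. It has the same shape as the node feature matrix data.x.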

  • num_iterations (int, optional) – The number of propagation steps, i.e. how many times the update for \(\mathbf{X}^{(\ell + 1)}\) is applied. (default: 40)
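When experimenting with the transform, a synthetic missing_mask is often needed. The helper below is a hypothetical sketch (random_missing_mask is not part of PyG): it returns a boolean \(N \times F\) mask that hides each entry independently with probability missing_rate or, with node_wise=True, hides entire rows so that some nodes have no known features at all.

```python
import torch


def random_missing_mask(
    num_nodes: int,
    num_features: int,
    missing_rate: float,
    node_wise: bool = False,
    seed: int = 0,
) -> torch.Tensor:
    """Hypothetical helper: boolean [num_nodes, num_features] mask, True = missing."""
    if not 0.0 <= missing_rate <= 1.0:
        raise ValueError("missing_rate must lie in [0, 1]")
    generator = torch.Generator().manual_seed(seed)
    if node_wise:
        # Drop whole rows: every feature of a selected node is marked missing.
        rows = torch.rand(num_nodes, 1, generator=generator) < missing_rate
        return rows.expand(num_nodes, num_features).clone()
    # Drop each (node, feature) entry independently with prob. `missing_rate`.
    return torch.rand(num_nodes, num_features, generator=generator) < missing_rate
```

For example, FeaturePropagation(missing_mask=random_missing_mask(data.num_nodes, data.num_features, 0.9)) would treat roughly 90% of the entries of data.x as missing. Passing a fixed seed keeps the mask reproducible across runs.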