torch_geometric.transforms.FeaturePropagation
- class FeaturePropagation(missing_mask: Tensor, num_iterations: int = 40)
Bases: BaseTransform
The feature propagation operator from the “On the Unreasonable Effectiveness of Feature propagation in Learning on Graphs with Missing Node Features” paper (functional name: feature_propagation):
\[
\begin{aligned}
\mathbf{X}^{(0)} &= (1 - \mathbf{M}) \cdot \mathbf{X} \\
\mathbf{X}^{(\ell + 1)} &= \mathbf{X}^{(0)} + \mathbf{M} \cdot \left(\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{X}^{(\ell)}\right)
\end{aligned}
\]
where \(\mathbf{A}\) is the adjacency matrix, \(\mathbf{D}\) is the diagonal degree matrix, and \(\cdot\) denotes element-wise multiplication. Missing node features are thus inferred from known features via propagation, while known features stay fixed at every iteration (their entries in \(\mathbf{M}\) are 0).
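As a small worked example (our own illustration, not taken from the paper), take the path graph with edges {0, 1} and {1, 2}, a single feature per node, known values x_0 = 1 and x_2 = 3, and x_1 missing. The degrees are 1, 2, 1, so every nonzero entry of the normalized adjacency is 1/√2:

```latex
\begin{aligned}
\mathbf{M} &= (0, 1, 0)^\top, \qquad \mathbf{X}^{(0)} = (1, 0, 3)^\top \\
\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}
  &= \begin{pmatrix} 0 & 1/\sqrt{2} & 0 \\ 1/\sqrt{2} & 0 & 1/\sqrt{2} \\ 0 & 1/\sqrt{2} & 0 \end{pmatrix} \\
\mathbf{X}^{(1)} &= \mathbf{X}^{(0)} + \mathbf{M} \cdot \left(\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{X}^{(0)}\right) \\
  &= (1, 0, 3)^\top + (0, 1, 0)^\top \cdot \left(0, \tfrac{1 + 3}{\sqrt{2}}, 0\right)^\top
   = (1, 2\sqrt{2}, 3)^\top
\end{aligned}
```

Node 1 depends only on the two fixed known values, so every later iteration reproduces the same vector and the missing entry settles at 2√2 ≈ 2.83. Note that this is not the neighbor mean 2: the symmetric normalization weights each neighbor by 1/√(d_i d_j), and these weights need not sum to one.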
import torch
from torch_geometric.transforms import FeaturePropagation

# `data` is a Data object in which missing entries of data.x are NaN
transform = FeaturePropagation(missing_mask=torch.isnan(data.x))
data = transform(data)
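To see what the transform computes, the update rule can be reproduced outside PyG. The following is a minimal NumPy sketch, not PyG's implementation. It builds a dense adjacency matrix, which only makes sense for small graphs. The helper name `feature_propagation_dense` and its arguments are hypothetical. The sketch assumes an undirected graph given as an edge list containing both directions of each edge, with missing entries marked by NaN as in the example above.

```python
import numpy as np


def feature_propagation_dense(x, edge_index, missing_mask, num_iterations=40):
    """Hypothetical helper (illustration only): fill missing entries of `x` by
    iterating X^(l+1) = X^(0) + M * (D^{-1/2} A D^{-1/2} X^(l)).

    x:            array of shape [N, F]; missing entries may hold NaN.
    edge_index:   int array of shape [2, E]; list both directions of each edge.
    missing_mask: bool array of shape [N, F]; True marks a missing entry.
    """
    x = np.asarray(x, dtype=float)
    missing_mask = np.asarray(missing_mask, dtype=bool)
    n = x.shape[0]

    # Dense adjacency matrix A (assumption: acceptable for small graphs only).
    adj = np.zeros((n, n))
    adj[edge_index[0], edge_index[1]] = 1.0

    # Symmetric normalization D^{-1/2} A D^{-1/2}; isolated nodes get weight 0.
    deg = adj.sum(axis=1)
    inv_sqrt = np.zeros(n)
    inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    adj_norm = inv_sqrt[:, None] * adj * inv_sqrt[None, :]

    # X^(0) = (1 - M) * X; np.where also replaces the NaNs with 0.
    x0 = np.where(missing_mask, 0.0, x)

    out = x0
    for _ in range(num_iterations):
        # Known entries stay equal to X^(0); missing ones take propagated values.
        out = x0 + missing_mask * (adj_norm @ out)
    return out


# Path graph 0 - 1 - 2 with one feature; node 1's feature is missing.
x = np.array([[1.0], [np.nan], [3.0]])
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
out = feature_propagation_dense(x, edge_index, np.isnan(x))
print(out[:, 0])  # node 1 becomes 2 * sqrt(2), about 2.828
```

The dense matrix keeps the sketch short and easy to check by hand. A practical implementation would normally keep the adjacency sparse, because the dense version needs O(N²) memory.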
- Parameters:
missing_mask (torch.Tensor) – Mask matrix \(\mathbf{M} \in \{0, 1\}^{N \times F}\) indicating missing node features, where 1 marks a missing entry (e.g., the boolean tensor returned by torch.isnan(data.x)).
num_iterations (int, optional) – The number of propagation iterations. (default: 40)