# Source code for torch_geometric.transforms.feature_propagation

from torch import Tensor

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import is_torch_sparse_tensor, to_torch_csc_tensor

@functional_transform('feature_propagation')
class FeaturePropagation(BaseTransform):
r"""The feature propagation operator from the `"On the Unreasonable
Effectiveness of Feature Propagation in Learning on Graphs with Missing
Node Features" <https://arxiv.org/abs/2111.12128>`_ paper
(functional name: :obj:`feature_propagation`).

.. math::
\mathbf{X}^{(0)} &= (1 - \mathbf{M}) \cdot \mathbf{X}

\mathbf{X}^{(\ell + 1)} &= \mathbf{X}^{(0)} + \mathbf{M} \cdot
(\mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{X}^{(\ell)})

where missing node features are inferred by known features via propagation.

.. code-block:: python

from torch_geometric.transforms import FeaturePropagation

transform = FeaturePropagation(missing_mask=torch.isnan(data.x))
data = transform(data)

Args:
missing_mask (torch.Tensor): The mask
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{N \times F}` indicating missing
node features.
num_iterations (int, optional): The number of propagations.
(default: :obj:`40`)
"""
def __init__(self, missing_mask: Tensor, num_iterations: int = 40) -> None:
    """Store the propagation configuration.

    Args:
        missing_mask: Mask indicating which node-feature entries are
            missing and must be inferred by propagation.
            Presumably of shape ``[num_nodes, num_features]`` — confirm
            against the paper/caller; ``forward`` is not fully visible here.
        num_iterations: Number of propagation iterations to run.
    """
    # BUG FIX: the mask was accepted but silently discarded, leaving
    # `forward` with no way to know which features to reconstruct.
    self.missing_mask = missing_mask
    self.num_iterations = num_iterations

def forward(self, data: Data) -> Data:
assert data.x is not None
assert data.edge_index is not None or data.adj_t is not None

gcn_norm = torch_geometric.nn.conv.gcn_conv.gcn_norm

if data.edge_index is not None:
edge_weight = data.edge_attr
if 'edge_weight' in data:
edge_weight = data.edge_weight
edge_index=data.edge_index,
edge_attr=edge_weight,
size=data.size(0),
).t()
else:

x = data.x.clone()

out = x
for _ in range(self.num_iterations):