torch_geometric.nn.models.CorrectAndSmooth
- class CorrectAndSmooth(num_correction_layers: int, correction_alpha: float, num_smoothing_layers: int, smoothing_alpha: float, autoscale: bool = True, scale: float = 1.0)[source]
Bases:
Module
The correct and smooth (C&S) post-processing model from the “Combining Label Propagation And Simple Models Out-performs Graph Neural Networks” paper, where soft predictions \(\mathbf{Z}\) (obtained from a simple base predictor) are first corrected based on ground-truth training label information \(\mathbf{Y}\) and residual propagation.
\[\mathbf{e}^{(0)}_i = \begin{cases} \mathbf{y}_i - \mathbf{z}_i, & \text{if }i \text{ is a training node,}\\ \mathbf{0}, & \text{otherwise,} \end{cases}\]
\[\mathbf{E}^{(\ell)} = \alpha_1 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{E}^{(\ell - 1)} + (1 - \alpha_1) \mathbf{E}^{(\ell - 1)},\]
\[\mathbf{\hat{Z}} = \mathbf{Z} + \gamma \cdot \mathbf{E}^{(L_1)},\]
where \(\gamma\) denotes the scaling factor (either fixed or automatically determined). The corrected predictions are then smoothed over the graph via label propagation
\[\mathbf{\hat{z}}^{(0)}_i = \begin{cases} \mathbf{y}_i, & \text{if }i\text{ is a training node,}\\ \mathbf{\hat{z}}_i, & \text{otherwise,} \end{cases}\]
\[\mathbf{\hat{Z}}^{(\ell)} = \alpha_2 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{\hat{Z}}^{(\ell - 1)} + (1 - \alpha_2) \mathbf{\hat{Z}}^{(\ell - 1)},\]
to obtain the final prediction \(\mathbf{\hat{Z}}^{(L_2)}\).
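To make the propagation update concrete, here is a minimal plain-torch sketch of a single step \(\alpha \mathbf{D}^{-1/2}\mathbf{A}\mathbf{D}^{-1/2}\mathbf{E}^{(\ell-1)} + (1 - \alpha)\mathbf{E}^{(\ell-1)}\) on a toy graph. The graph, node count, and \(\alpha\) value are illustrative assumptions, not taken from the class itself, and a dense adjacency is used for clarity where the real model propagates sparsely:

```python
import torch

# Toy undirected path graph on 4 nodes: edges 0-1, 1-2, 2-3 (both directions).
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
num_nodes, alpha = 4, 0.8

# Dense adjacency A and its symmetric normalization D^{-1/2} A D^{-1/2}.
A = torch.zeros(num_nodes, num_nodes)
A[edge_index[0], edge_index[1]] = 1.0
deg_inv_sqrt = A.sum(dim=1).pow(-0.5)
DAD = deg_inv_sqrt.view(-1, 1) * A * deg_inv_sqrt.view(1, -1)

# Residual matrix E^(0): non-zero only at the (hypothetical) training node 0.
E = torch.zeros(num_nodes, 3)
E[0] = torch.tensor([0.2, -0.1, -0.1])

# One propagation step: E^(l) = alpha * DAD @ E^(l-1) + (1 - alpha) * E^(l-1).
E = alpha * DAD @ E + (1 - alpha) * E
```

After one step, the residual at node 0 has shrunk by the factor \(1 - \alpha\) and part of it has flowed to its neighbor, node 1; nodes 2 and 3 are still untouched.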
Note
For an example of using the C&S model, see examples/correct_and_smooth.py.
- Parameters:
num_correction_layers (int) – The number of propagations \(L_1\).
correction_alpha (float) – The \(\alpha_1\) coefficient.
num_smoothing_layers (int) – The number of propagations \(L_2\).
smoothing_alpha (float) – The \(\alpha_2\) coefficient.
autoscale (bool, optional) – If set to True, will automatically determine the scaling factor \(\gamma\). (default: True)
scale (float, optional) – The scaling factor \(\gamma\), in case autoscale = False. (default: 1.0)
- correct(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor [source]
Corrects the soft predictions \(\mathbf{Z}\) via residual propagation over the graph.
- Parameters:
y_soft (torch.Tensor) – The soft predictions \(\mathbf{Z}\) obtained from a simple base predictor.
y_true (torch.Tensor) – The ground-truth label information \(\mathbf{Y}\) of training nodes.
mask (torch.Tensor) – A mask or index tensor denoting which nodes were used for training.
edge_index (torch.Tensor or SparseTensor) – The edge connectivity.
edge_weight (torch.Tensor, optional) – The edge weights. (default: None)
- smooth(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor [source]
Smooths the corrected predictions via label propagation over the graph.
- Parameters:
y_soft (torch.Tensor) – The corrected predictions \(\mathbf{\hat{Z}}\) obtained from correct().
y_true (torch.Tensor) – The ground-truth label information \(\mathbf{Y}\) of training nodes.
mask (torch.Tensor) – A mask or index tensor denoting which nodes were used for training.
edge_index (torch.Tensor or SparseTensor) – The edge connectivity.
edge_weight (torch.Tensor, optional) – The edge weights. (default: None)