torch_geometric.nn.models.CorrectAndSmooth

class CorrectAndSmooth(num_correction_layers: int, correction_alpha: float, num_smoothing_layers: int, smoothing_alpha: float, autoscale: bool = True, scale: float = 1.0)

Bases: Module

The correct and smooth (C&S) post-processing model from the “Combining Label Propagation And Simple Models Out-performs Graph Neural Networks” paper, where soft predictions \(\mathbf{Z}\) (obtained from a simple base predictor) are first corrected by propagating their residual errors with respect to the ground-truth training labels \(\mathbf{Y}\).

\[\begin{split}\mathbf{e}^{(0)}_i &= \begin{cases} \mathbf{y}_i - \mathbf{z}_i, & \text{if }i \text{ is training node,}\\ \mathbf{0}, & \text{else} \end{cases}\end{split}\]
\[\begin{aligned}\mathbf{E}^{(\ell)} &= \alpha_1 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{E}^{(\ell - 1)} + (1 - \alpha_1) \mathbf{E}^{(0)}\\\mathbf{\hat{Z}} &= \mathbf{Z} + \gamma \cdot \mathbf{E}^{(L_1)},\end{aligned}\]

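Here, \(\mathbf{A}\) denotes the adjacency matrix, \(\mathbf{D}\) its diagonal degree matrix, and \(\gamma\) the scaling factor (either fixed or automatically determined). The corrected predictions \(\mathbf{\hat{Z}}\) are then smoothed over the graph via label propagation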

\[\begin{split}\mathbf{\hat{z}}^{(0)}_i &= \begin{cases} \mathbf{y}_i, & \text{if }i\text{ is training node,}\\ \mathbf{\hat{z}}_i, & \text{else} \end{cases}\end{split}\]
\[\mathbf{\hat{Z}}^{(\ell)} = \alpha_2 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{\hat{Z}}^{(\ell - 1)} + (1 - \alpha_2) \mathbf{\hat{Z}}^{(0)}\]

to obtain the final prediction \(\mathbf{\hat{Z}}^{(L_2)}\).
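The two stages above can be sketched in plain Python to make the propagation concrete. This is a minimal sketch, not the PyG implementation: the helpers `normalized_edges`, `propagate`, and `correct_and_smooth` are hypothetical, lists of lists stand in for tensors, the graph is assumed undirected with each edge listed once and no self-loops, \(\gamma\) is fixed (no autoscaling or clamping), and the residual term of each propagation is anchored to the initial matrix (\(\mathbf{E}^{(0)}\) or \(\mathbf{\hat{Z}}^{(0)}\)), as in the C&S paper's formulation.

```python
import math

# Hypothetical helpers for illustration; lists of lists stand in for tensors.


def normalized_edges(edges, num_nodes):
    """Return (src, dst, weight) triples encoding S = D^{-1/2} A D^{-1/2}.

    Assumes `edges` lists each undirected edge once, with no self-loops.
    """
    deg = [0] * num_nodes
    for u, v in edges:
        deg[u] += 1
        deg[v] += 1
    triples = []
    for u, v in edges:
        w = 1.0 / math.sqrt(deg[u] * deg[v])
        triples.append((u, v, w))  # message u -> v
        triples.append((v, u, w))  # message v -> u
    return triples


def propagate(x0, triples, alpha, num_layers):
    """Iterate X <- alpha * S X + (1 - alpha) * X0 for `num_layers` steps."""
    num_cols = len(x0[0])
    x = [row[:] for row in x0]
    for _ in range(num_layers):
        agg = [[0.0] * num_cols for _ in x0]  # S X, accumulated edge by edge
        for src, dst, w in triples:
            for c in range(num_cols):
                agg[dst][c] += w * x[src][c]
        x = [
            [alpha * a + (1 - alpha) * r for a, r in zip(agg_row, row0)]
            for agg_row, row0 in zip(agg, x0)
        ]
    return x


def correct_and_smooth(z, y_train, train_idx, edges,
                       num_correction_layers, correction_alpha,
                       num_smoothing_layers, smoothing_alpha, scale=1.0):
    n, num_classes = len(z), len(z[0])
    s = normalized_edges(edges, n)

    # Correct: E^(0) holds Y - Z on training nodes and zeros elsewhere.
    e0 = [[0.0] * num_classes for _ in range(n)]
    for i, y_row in zip(train_idx, y_train):
        e0[i] = [y - p for y, p in zip(y_row, z[i])]
    e = propagate(e0, s, correction_alpha, num_correction_layers)
    z_hat = [[p + scale * err for p, err in zip(z_row, e_row)]
             for z_row, e_row in zip(z, e)]  # fixed gamma = scale

    # Smooth: reset training nodes to their labels, then propagate again.
    g0 = [row[:] for row in z_hat]
    for i, y_row in zip(train_idx, y_train):
        g0[i] = list(y_row)
    return propagate(g0, s, smoothing_alpha, num_smoothing_layers)
```

Setting both alpha values to 0 makes each propagation return its initial matrix unchanged, so the output is simply \(\mathbf{Z}\) with the training rows replaced by their labels, which is a convenient sanity check.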

Note

For an example of using the C&S model, see examples/correct_and_smooth.py.

Parameters:
  • num_correction_layers (int) – The number of propagations \(L_1\).

  • correction_alpha (float) – The \(\alpha_1\) coefficient.

  • num_smoothing_layers (int) – The number of propagations \(L_2\).

  • smoothing_alpha (float) – The \(\alpha_2\) coefficient.

  • autoscale (bool, optional) – If set to True, the scaling factor \(\gamma\) is determined automatically. (default: True)

  • scale (float, optional) – The fixed scaling factor \(\gamma\), used only when autoscale=False. (default: 1.0)
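To illustrate the difference between autoscale=True and a fixed scale, the sketch below follows the autoscale rule described in the C&S paper, as we understand it: each node's propagated error is rescaled so that its L1 norm equals \(\sigma\), the mean L1 norm of the training residuals in \(\mathbf{E}^{(0)}\). Whether PyG computes \(\gamma\) exactly this way (including any clamping or fallback values) is an assumption; `apply_correction` is a hypothetical helper, and leaving all-zero rows unchanged is an added assumption.

```python
# Hypothetical helper modelled on the C&S paper's autoscale rule.


def apply_correction(z, e_prop, e0, train_idx, autoscale=True, scale=1.0):
    """Return Z_hat = Z + gamma * E^(L1), with gamma fixed or chosen per node.

    z:      base predictions Z (one row per node)
    e_prop: propagated errors E^(L1)
    e0:     initial errors E^(0) (nonzero only on training nodes)
    """
    if not autoscale:
        # Fixed gamma shared by all nodes.
        return [[p + scale * err for p, err in zip(z_row, e_row)]
                for z_row, e_row in zip(z, e_prop)]

    # sigma: mean L1 norm of the training residuals.
    sigma = sum(sum(abs(v) for v in e0[i]) for i in train_idx) / len(train_idx)
    out = []
    for z_row, e_row in zip(z, e_prop):
        norm = sum(abs(v) for v in e_row)
        # Rescale so the correction's L1 norm equals sigma; all-zero rows
        # are left unchanged (an assumption, not taken from the source).
        gamma = sigma / norm if norm > 0 else 0.0
        out.append([p + gamma * err for p, err in zip(z_row, e_row)])
    return out
```

With autoscaling, every node with a nonzero propagated error receives a correction of the same L1 size \(\sigma\), so nodes far from the training set are not under-corrected merely because the propagated error decayed along the way.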

forward(y_soft: Tensor, *args) → Tensor

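Applies correct() followed by smooth(). The remaining positional arguments *args (y_true, mask, edge_index, and optionally edge_weight) are passed to both.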

correct(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) → Tensor

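Applies the correction step: builds the residual errors \(\mathbf{E}^{(0)}\) on training nodes, propagates them for \(L_1\) steps, and returns the corrected predictions \(\mathbf{\hat{Z}}\).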

Parameters:
  • y_soft (torch.Tensor) – The soft predictions \(\mathbf{Z}\) obtained from a simple base predictor.

  • y_true (torch.Tensor) – The ground-truth label information \(\mathbf{Y}\) of training nodes (one row per training node, in the order selected by mask).

  • mask (torch.Tensor) – A mask or index tensor denoting which nodes were used for training.

  • edge_index (torch.Tensor or SparseTensor) – The edge connectivity.

  • edge_weight (torch.Tensor, optional) – The edge weights. (default: None)
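Because mask may be either a boolean mask or an index tensor, and y_true only covers training nodes, it can help to see how \(\mathbf{E}^{(0)}\) lines up with these inputs. The sketch below is a plain-Python illustration, with lists standing in for tensors; `to_index_list` and `initial_error` are hypothetical helpers, and it assumes the rows of y_true appear in the same order as the training nodes selected by mask.

```python
# Hypothetical helpers; lists stand in for tensors.


def to_index_list(mask):
    """Accept a boolean mask (one flag per node) or a list of node indices."""
    if all(isinstance(m, bool) for m in mask):
        return [i for i, flag in enumerate(mask) if flag]
    return list(mask)


def initial_error(y_soft, y_true, mask):
    """Build E^(0): y_true minus y_soft on training nodes, zero elsewhere."""
    train_idx = to_index_list(mask)
    # y_true holds one row per training node, in the same order as train_idx.
    if len(y_true) != len(train_idx):
        raise ValueError("y_true must have one row per training node")
    num_classes = len(y_soft[0])
    e0 = [[0.0] * num_classes for _ in y_soft]
    for i, y_row in zip(train_idx, y_true):
        e0[i] = [y - z for y, z in zip(y_row, y_soft[i])]
    return e0
```

For example, the mask [True, False, True] and the index list [0, 2] select the same two training nodes and therefore yield the same \(\mathbf{E}^{(0)}\).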

smooth(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) → Tensor

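Applies the smoothing step: resets the training nodes to their ground-truth labels \(\mathbf{Y}\), propagates for \(L_2\) steps, and returns the final predictions \(\mathbf{\hat{Z}}^{(L_2)}\).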

Parameters:
  • y_soft (torch.Tensor) – The corrected predictions \(\mathbf{\hat{Z}}\) obtained from correct().

  • y_true (torch.Tensor) – The ground-truth label information \(\mathbf{Y}\) of training nodes (one row per training node, in the order selected by mask).

  • mask (torch.Tensor) – A mask or index tensor denoting which nodes were used for training.

  • edge_index (torch.Tensor or SparseTensor) – The edge connectivity.

  • edge_weight (torch.Tensor, optional) – The edge weights. (default: None)