from typing import Callable, Optional
from torch_geometric.typing import Adj, OptTensor
import torch
from torch import Tensor
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
class LabelPropagation(MessagePassing):
    r"""The label propagation operator from the `"Learning from Labeled and
    Unlabeled Data with Label Propagation"
    <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>`_ paper

    .. math::
        \mathbf{Y}^{\prime} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2} \mathbf{Y} + (1 - \alpha) \mathbf{Y},

    where unlabeled data is inferred by labeled data via propagation.

    Args:
        num_layers (int): The number of propagations.
        alpha (float): The :math:`\alpha` coefficient.
    """
    def __init__(self, num_layers: int, alpha: float):
        super(LabelPropagation, self).__init__(aggr='add')
        self.num_layers = num_layers
        self.alpha = alpha

    @torch.no_grad()
    def forward(
        self, y: Tensor, edge_index: Adj, mask: Optional[Tensor] = None,
        edge_weight: OptTensor = None,
        post_step: Callable = lambda y: y.clamp_(0., 1.)
    ) -> Tensor:
        r"""Propagates the labels :obj:`y` over the graph for
        :obj:`num_layers` steps, with gradient tracking disabled.

        Args:
            y (Tensor): The labels. A tensor of dtype :obj:`torch.long` is
                flattened, one-hot encoded and cast to float first; any
                other dtype is used as given.
            edge_index (Tensor or SparseTensor): The graph connectivity.
                It is passed through :obj:`gcn_norm` (without adding
                self-loops) when it is a :obj:`SparseTensor` that holds no
                values, or a :obj:`Tensor` given without
                :obj:`edge_weight`. Otherwise it is used unchanged.
            mask (Tensor, optional): Index or mask selecting the rows of
                :obj:`y` to keep as initial labels; every other row starts
                at zero. If :obj:`None`, all rows of :obj:`y` are used.
                (default: :obj:`None`)
            edge_weight (Tensor, optional): Edge weights multiplied into
                the messages when :obj:`edge_index` is a :obj:`Tensor`.
                (default: :obj:`None`)
            post_step (Callable): Function applied to the output after
                every propagation step; its return value becomes the new
                output. The default clamps the values to
                :math:`[0, 1]` in-place.

        Returns:
            Tensor: The labels after the last propagation step.
        """
        if y.dtype == torch.long:
            # Integer class indices -> dense float one-hot rows.
            y = F.one_hot(y.view(-1)).to(torch.float)

        out = y
        if mask is not None:
            # Start from zeros and copy over only the selected labels.
            out = torch.zeros_like(y)
            out[mask] = y[mask]

        # Normalize only when the caller supplied no explicit weights.
        if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
            edge_index = gcn_norm(edge_index, add_self_loops=False)
        elif isinstance(edge_index, Tensor) and edge_weight is None:
            edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
                                               add_self_loops=False)

        # The residual term is computed once from the initial labels and is
        # re-added unchanged after every propagation step.
        res = (1 - self.alpha) * out
        for _ in range(self.num_layers):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            out = self.propagate(edge_index, x=out, edge_weight=edge_weight,
                                 size=None)
            out.mul_(self.alpha).add_(res)
            out = post_step(out)

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        """Returns the source features :obj:`x_j`, scaled per edge by
        :obj:`edge_weight` when it is given."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Computes the sparse-dense product of :obj:`adj_t` and :obj:`x`,
        reduced with this layer's aggregation scheme."""
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return '{}(num_layers={}, alpha={})'.format(self.__class__.__name__,
                                                    self.num_layers,
                                                    self.alpha)