# Source code for torch_geometric.nn.models.mask_label

import torch
from torch import Tensor

class MaskLabel(torch.nn.Module):
    r"""The label embedding and masking layer from the "Masked Label
    Prediction: Unified Message Passing Model for Semi-Supervised
    Classification" paper.

    Here, node labels :obj:`y` are merged to the initial node features
    :obj:`x` for a subset of their nodes according to :obj:`mask`.

    .. note::

        For an example of using :class:`MaskLabel`, see the ``examples/``
        directory of the PyTorch Geometric repository.

    Args:
        num_classes (int): The number of classes.
        out_channels (int): Size of each output sample.
        method (str, optional): If set to :obj:`"add"`, label embeddings are
            added to the input. If set to :obj:`"concat"`, label embeddings
            are concatenated. In case :obj:`method="add"`, then
            :obj:`out_channels` needs to be identical to the input
            dimensionality of node features. (default: :obj:`"add"`)

    Raises:
        ValueError: If :obj:`method` is neither :obj:`"add"` nor
            :obj:`"concat"`.
    """
    def __init__(self, num_classes: int, out_channels: int,
                 method: str = "add"):
        super().__init__()

        if method not in ["add", "concat"]:
            raise ValueError(
                f"'method' must be either 'add' or 'concat' (got '{method}')")
        self.method = method

        # One learnable `out_channels`-sized vector per class label.
        self.emb = torch.nn.Embedding(num_classes, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.emb.reset_parameters()

    def forward(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
        r"""Merges the embeddings of the labels selected by :obj:`mask`
        into the node features.

        Args:
            x (torch.Tensor): The node features. Not modified in-place.
            y (torch.Tensor): The node labels, used as indices into the
                label embedding table.
            mask (torch.Tensor): The mask selecting (along the first
                dimension of :obj:`x` and :obj:`y`) the nodes whose labels
                are merged into the features.

        Returns:
            torch.Tensor: For :obj:`method="concat"`, :obj:`x` concatenated
            along the last dimension with the label embeddings (all-zero
            rows for nodes not selected by :obj:`mask`). For
            :obj:`method="add"`, a copy of :obj:`x` with the label
            embeddings added to the rows selected by :obj:`mask`.
        """
        if self.method == "concat":
            # Nodes outside of `mask` keep an all-zero label embedding.
            out = x.new_zeros(y.size(0), self.emb.weight.size(-1))
            out[mask] = self.emb(y[mask])
            return torch.cat([x, out], dim=-1)
        else:
            # Clone first so that the caller's tensor is left untouched by
            # the in-place addition below.
            x = torch.clone(x)
            x[mask] += self.emb(y[mask])
            return x

    @staticmethod
    def ratio_mask(mask: Tensor, ratio: float) -> Tensor:
        r"""Randomly re-masks :obj:`mask`: every :obj:`True` entry is kept
        with probability :obj:`ratio` and set to :obj:`False` otherwise.
        Entries that are already :obj:`False` stay :obj:`False`.
        Does not operate in-place.

        Args:
            mask (torch.Tensor): The mask to re-mask.
            ratio (float): The ratio of entries to keep.

        Returns:
            torch.Tensor: A new mask of the same shape as :obj:`mask`.
        """
        n = int(mask.sum())
        out = mask.clone()
        # Independent uniform draw per `True` entry; the number of entries
        # kept is therefore `ratio * n` only in expectation.
        out[mask] = torch.rand(n, device=mask.device) < ratio
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'