Source code for torch_geometric.utils.mask

from typing import Optional

import torch
from torch import Tensor


def index_to_mask(index: Tensor, size: Optional[int] = None) -> Tensor:
    r"""Converts indices to a mask representation.

    Args:
        index (Tensor): The indices. The tensor is flattened before use, so
            any shape is accepted.
        size (int, optional): The size of the mask. If set to :obj:`None`, a
            minimal sized output mask is returned, *i.e.* one of size
            :obj:`index.max() + 1`, or of size :obj:`0` if :obj:`index` is
            empty. (default: :obj:`None`)

    Returns:
        Tensor: A one-dimensional boolean tensor of length :obj:`size`,
        created on the same device as :obj:`index`, that is :obj:`True` at
        every position listed in :obj:`index` and :obj:`False` elsewhere.

    Example:
        >>> index = torch.tensor([1, 3])
        >>> index_to_mask(index)
        tensor([False,  True, False,  True])
        >>> index_to_mask(index, size=5)
        tensor([False,  True, False,  True, False])
    """
    index = index.view(-1)

    if size is None:
        # `Tensor.max()` raises on an empty tensor, so an empty set of
        # indices maps to the empty mask instead.
        size = int(index.max()) + 1 if index.numel() > 0 else 0

    # `new_zeros` keeps the mask on the same device as `index`.
    mask = index.new_zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask