from typing import Optional
import torch
from torch import Tensor
def index_to_mask(index: Tensor, size: Optional[int] = None) -> Tensor:
    r"""Converts indices to a mask representation.

    Args:
        index (Tensor): The indices. Any shape is accepted; the tensor is
            flattened before use.
        size (int, optional): The size of the mask. If set to :obj:`None`, a
            minimal sized output mask is returned (an empty mask when
            :obj:`index` is empty).

    Returns:
        Tensor: A :obj:`torch.bool` tensor of length :obj:`size`, on the same
        device as :obj:`index`, that is :obj:`True` exactly at the positions
        listed in :obj:`index`.
    """
    # ``reshape`` (unlike ``view``) also works for non-contiguous inputs.
    index = index.reshape(-1)
    if size is None:
        # ``Tensor.max()`` raises on an empty tensor, so an empty index set
        # explicitly yields a zero-length mask.
        size = int(index.max()) + 1 if index.numel() > 0 else 0
    # ``new_zeros`` keeps the device of ``index``.
    mask = index.new_zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask
def mask_to_index(mask: Tensor) -> Tensor:
    r"""Converts a mask to an index representation.

    Args:
        mask (Tensor): The mask. Every entry that is non-zero (or
            :obj:`True`) is reported.

    Returns:
        Tensor: A 1-D :obj:`torch.long` tensor holding the positions of the
        non-zero entries of :obj:`mask`, in row-major order. For a mask with
        more than one dimension, the coordinates of each non-zero entry are
        concatenated one after another.
    """
    # ``nonzero`` yields a (num_nonzero, mask.dim()) matrix; flattening it
    # gives a plain index vector for the common 1-D case.
    positions = torch.nonzero(mask, as_tuple=False)
    return positions.flatten()