Source code for torch_geometric.utils.mask

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.typing import TensorFrame

def mask_select(src: Tensor, dim: int, mask: Tensor) -> Tensor:
    r"""Selects the entries of :obj:`src` along dimension :obj:`dim` for
    which the boolean :obj:`mask` is set, and returns them as a new tensor.

    Args:
        src (torch.Tensor): The tensor to select from.
        dim (int): The dimension along which the mask is applied.
        mask (torch.BoolTensor): A 1-D binary mask whose length must equal
            the size of :obj:`src` in dimension :obj:`dim`.
    """
    assert mask.dim() == 1

    # `TensorFrame` inputs are only dispatched outside of TorchScript, and
    # only row-wise (dim 0) masking is supported for them.
    if not torch.jit.is_scripting():
        if isinstance(src, TensorFrame):
            assert dim == 0 and src.num_rows == mask.numel()
            return src[mask]

    assert src.size(dim) == mask.numel()

    # Map a negative `dim` onto its non-negative equivalent.
    if dim < 0:
        dim = dim + src.dim()
    assert not (dim < 0 or dim >= src.dim())

    if dim == 0:
        return src[mask]

    # Boolean indexing along the leading axis is much faster than
    # broadcasting the mask and calling `masked_select`, so swap `dim` to
    # the front, mask there, and swap it back afterwards.
    return src.transpose(0, dim)[mask].transpose(0, dim)


def index_to_mask(index: Tensor, size: Optional[int] = None) -> Tensor:
    r"""Converts indices to a mask representation.

    Args:
        index (Tensor): The indices. Tensors of any shape are accepted and
            are flattened before use.
        size (int, optional): The size of the mask. If set to :obj:`None`, a
            minimal sized output mask is returned (an empty mask if
            :obj:`index` is empty). (default: :obj:`None`)

    Returns:
        A boolean tensor, created via :obj:`index.new_zeros` (so on the same
        device as :obj:`index`), that is :obj:`True` exactly at the given
        indices.

    Example:
        >>> index = torch.tensor([1, 3, 5])
        >>> index_to_mask(index)
        tensor([False,  True, False,  True, False,  True])

        >>> index_to_mask(index, size=7)
        tensor([False,  True, False,  True, False,  True, False])
    """
    # `reshape` (unlike `view`) also accepts non-contiguous inputs such as
    # transposed index tensors; it only copies when a view is impossible.
    index = index.reshape(-1)

    if size is None:
        # `max()` is undefined on an empty tensor, so fall back to size 0.
        num_entries = int(index.max()) + 1 if index.numel() > 0 else 0
    else:
        num_entries = size

    mask = index.new_zeros(num_entries, dtype=torch.bool)
    mask[index] = True
    return mask


def mask_to_index(mask: Tensor) -> Tensor:
    r"""Converts a mask to an index representation.

    Args:
        mask (Tensor): The mask.

    Returns:
        For a 1-D :obj:`mask`, a 1-D tensor holding the positions at which
        :obj:`mask` is non-zero.

    Example:
        >>> mask = torch.tensor([False, True, False])
        >>> mask_to_index(mask)
        tensor([1])
    """
    # `nonzero(as_tuple=False)` yields one row per hit, shape [num_hits, ndim];
    # flattening collapses that into a single index vector.
    nonzero_positions = mask.nonzero(as_tuple=False)
    return nonzero_positions.flatten()