Source code for torch_geometric.nn.pool.voxel_grid

from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.utils.repeat import repeat

try:
    from torch_cluster import grid_cluster
except ImportError:
    grid_cluster = None


def _as_grid_vector(
    value: Union[float, List[float], Tensor],
    dim: int,
    like: Tensor,
) -> Tensor:
    """Coerce a per-dimension grid argument and fit it to :obj:`dim`.

    A non-tensor :obj:`value` is converted to a tensor using the dtype and
    device of :obj:`like`; a tensor :obj:`value` is used as given. The
    result is then passed through :func:`repeat` with length :obj:`dim`.
    """
    if not isinstance(value, Tensor):
        value = torch.tensor(value, dtype=like.dtype, device=like.device)
    return repeat(value, dim)


def voxel_grid(
    pos: Tensor,
    size: Union[float, List[float], Tensor],
    batch: Optional[Tensor] = None,
    start: Optional[Union[float, List[float], Tensor]] = None,
    end: Optional[Union[float, List[float], Tensor]] = None,
) -> Tensor:
    r"""Voxel grid pooling from the, *e.g.*, `Dynamic Edge-Conditioned Filters
    in Convolutional Networks on Graphs <https://arxiv.org/abs/1704.02901>`_
    paper, which overlays a regular grid of user-defined size over a point
    cloud and clusters all points within the same voxel.

    Args:
        pos (torch.Tensor): Node position matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}`.
            A one-dimensional tensor is treated as :math:`N` points with a
            single coordinate each.
        size (float or [float] or Tensor): Size of a voxel (in each
            dimension).
        batch (torch.Tensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots,B-1\}}^N`, which assigns each
            node to a specific example. If set to :obj:`None`, all nodes are
            assigned to example :obj:`0`. (default: :obj:`None`)
        start (float or [float] or Tensor, optional): Start coordinates of the
            grid (in each dimension). If set to :obj:`None`, will be set to the
            minimum coordinates found in :attr:`pos`. (default: :obj:`None`)
        end (float or [float] or Tensor, optional): End coordinates of the grid
            (in each dimension). If set to :obj:`None`, will be set to the
            maximum coordinates found in :attr:`pos`. (default: :obj:`None`)

    Raises:
        ImportError: If :obj:`torch-cluster` could not be imported.

    :rtype: :class:`torch.Tensor`
    """
    if grid_cluster is None:
        raise ImportError('`voxel_grid` requires `torch-cluster`.')

    # Treat a 1-D position vector as N points in a single spatial dimension.
    pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
    dim = pos.size(1)

    if batch is None:
        batch = pos.new_zeros(pos.size(0), dtype=torch.long)

    # The batch index is appended as one more coordinate column, so the grid
    # handed to `grid_cluster` has `dim + 1` axes.
    pos = torch.cat([pos, batch.view(-1, 1).to(pos.dtype)], dim=-1)

    size = _as_grid_vector(size, dim, pos)
    # Voxel size 1 along the batch axis -- presumably so that every example
    # occupies its own slice of the grid.
    size = torch.cat([size, size.new_ones(1)])

    if start is not None:
        start = _as_grid_vector(start, dim, pos)
        # The batch axis of the grid starts at index 0.
        start = torch.cat([start, start.new_zeros(1)])

    if end is not None:
        end = _as_grid_vector(end, dim, pos)
        # The batch axis of the grid ends at the largest batch index present.
        end = torch.cat([end, batch.max().unsqueeze(0)])

    return grid_cluster(pos, size, start, end)