Source code for torch_geometric.transforms.grid_sampling

import re

import torch
import torch.nn.functional as F
from torch_scatter import scatter_add, scatter_mean

import torch_geometric
from torch_geometric.transforms import BaseTransform


[docs]class GridSampling(BaseTransform): r"""Clusters points into voxels with size :attr:`size`. Args: size (float or [float] or Tensor): Size of a voxel (in each dimension). start (float or [float] or Tensor, optional): Start coordinates of the grid (in each dimension). If set to :obj:`None`, will be set to the minimum coordinates found in :obj:`data.pos`. (default: :obj:`None`) end (float or [float] or Tensor, optional): End coordinates of the grid (in each dimension). If set to :obj:`None`, will be set to the maximum coordinates found in :obj:`data.pos`. (default: :obj:`None`) """ def __init__(self, size, start=None, end=None): self.size = size self.start = start self.end = end def __call__(self, data): num_nodes = data.num_nodes if 'batch' not in data: batch = data.pos.new_zeros(num_nodes, dtype=torch.long) else: batch = data.batch c = torch_geometric.nn.voxel_grid(data.pos, batch, self.size, self.start, self.end) c, perm = torch_geometric.nn.pool.consecutive.consecutive_cluster(c) for key, item in data: if bool(re.search('edge', key)): raise ValueError( 'GridSampling does not support coarsening of edges') if torch.is_tensor(item) and item.size(0) == num_nodes: if key == 'y': item = F.one_hot(item) item = scatter_add(item, c, dim=0) data[key] = item.argmax(dim=-1) elif key == 'batch': data[key] = item[perm] else: data[key] = scatter_mean(item, c, dim=0) return data def __repr__(self): return '{}(size={})'.format(self.__class__.__name__, self.size)