# Source code for torch_geometric.transforms.grid_sampling

import re
from typing import List, Optional, Union

import torch
from torch import Tensor

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import one_hot, scatter


@functional_transform('grid_sampling')
class GridSampling(BaseTransform):
    r"""Clusters points into fixed-sized voxels
    (functional name: :obj:`grid_sampling`).
    Each cluster returned is a new point based on the mean of all points
    inside the given cluster.

    Node-level tensor attributes (tensors with at least one dimension whose
    first dimension equals :obj:`data.num_nodes`) are pooled per voxel:

    * :obj:`y` is one-hot encoded, summed per voxel, and replaced by the
      arg-max, i.e. the most frequent label in the voxel,
    * :obj:`batch` is indexed with the permutation returned by
      :meth:`consecutive_cluster`,
    * every other node-level tensor is averaged over the voxel.

    All other attributes (non-tensors, 0-dimensional tensors and tensors
    whose first dimension does not match the number of nodes) are left
    untouched. The transform modifies :obj:`data` in place and returns it.

    Args:
        size (float or [float] or Tensor): Size of a voxel (in each dimension).
        start (float or [float] or Tensor, optional): Start coordinates of the
            grid (in each dimension). If set to :obj:`None`, will be set to the
            minimum coordinates found in :obj:`data.pos`.
            (default: :obj:`None`)
        end (float or [float] or Tensor, optional): End coordinates of the grid
            (in each dimension). If set to :obj:`None`, will be set to the
            maximum coordinates found in :obj:`data.pos`.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        size: Union[float, List[float], Tensor],
        start: Optional[Union[float, List[float], Tensor]] = None,
        end: Optional[Union[float, List[float], Tensor]] = None,
    ) -> None:
        self.size = size
        self.start = start
        self.end = end

    def forward(self, data: Data) -> Data:
        """Coarsen ``data`` into one point per occupied voxel.

        Raises:
            ValueError: If any attribute name contains ``"edge"``; edge
                coarsening is not supported. This check runs before
                ``data`` is modified, so a rejected input is left intact.
        """
        # Read the node count before any attribute is pooled: afterwards
        # ``data.num_nodes`` would reflect the coarsened graph.
        num_nodes = data.num_nodes
        assert data.pos is not None

        # Snapshot the attributes, since ``data`` is mutated below.
        items = list(data.items())

        # Validate up front so an unsupported input does not leave ``data``
        # half-coarsened (e.g. ``pos`` already pooled) when we raise.
        for key, _ in items:
            if 'edge' in key:
                raise ValueError(f"'{self.__class__.__name__}' does not "
                                 f"support coarsening of edges")

        cluster = torch_geometric.nn.voxel_grid(data.pos, self.size,
                                                data.batch, self.start,
                                                self.end)
        cluster, perm = torch_geometric.nn.pool.consecutive.consecutive_cluster(
            cluster)

        for key, item in items:
            # 0-dim tensors have no first dimension (``size(0)`` would raise
            # IndexError), so they can never be node-level attributes.
            if not torch.is_tensor(item) or item.dim() == 0:
                continue
            if item.size(0) != num_nodes:
                continue

            if key == 'y':
                # Majority vote: per-voxel label counts, then arg-max.
                counts = scatter(one_hot(item), cluster, dim=0, reduce='sum')
                data[key] = counts.argmax(dim=-1)
            elif key == 'batch':
                data[key] = item[perm]
            else:
                data[key] = scatter(item, cluster, dim=0, reduce='mean')

        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(size={self.size})'