Source code for torch_geometric.transforms.delaunay

from typing import List

import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


class _QhullTransform(BaseTransform):
    r"""Delaunay triangulation computed with :class:`scipy.spatial.Delaunay`
    (Qhull backend).

    Reads the point coordinates from :obj:`data.pos` and stores the
    resulting simplices, transposed so that each column is one simplex, in
    :obj:`data.face`.
    """
    def forward(self, data: Data) -> Data:
        assert data.pos is not None
        # Imported lazily so that a missing SciPy installation only raises
        # an ImportError when this backend is actually executed.
        import scipy.spatial

        points = data.pos

        # SciPy works on NumPy arrays, hence the host copy of the points.
        # 'QJ' is forwarded verbatim to Qhull (documented there as the
        # "joggle input" option).
        triangulation = scipy.spatial.Delaunay(
            points.cpu().numpy(),
            qhull_options='QJ',
        )

        # SciPy documents `simplices` as one simplex per row; `data.face`
        # receives the transposed layout, moved back to the device of
        # `data.pos` and cast to int64.
        simplices = torch.from_numpy(triangulation.simplices)
        face = simplices.t().contiguous()
        data.face = face.to(points.device, torch.long)
        return data


class _ShullTransform(BaseTransform):
    r"""Delaunay triangulation computed with the sweep-hull routine
    :obj:`shull2d` from the optional :obj:`torch_delaunay` package.

    Reads the point coordinates from :obj:`data.pos` and stores the
    transposed result of :obj:`shull2d` in :obj:`data.face`.
    """
    def forward(self, data: Data) -> Data:
        assert data.pos is not None
        # Imported lazily: `torch_delaunay` is optional, so its absence must
        # surface as an ImportError at call time rather than at module import.
        from torch_delaunay.functional import shull2d

        points = data.pos

        # The triangulation runs on a CPU copy of the points.
        triangles = shull2d(points.cpu())

        # Transpose, make contiguous, and move back to the original device.
        face = triangles.t().contiguous()
        data.face = face.to(points.device)
        return data


class _SequentialTransform(BaseTransform):
    r"""Applies the first transform whose dependencies can be imported.

    The transforms are tried in the given order. An :class:`ImportError`
    raised by any transform other than the last one is swallowed and the
    next transform is tried; an :class:`ImportError` raised by the last
    transform is propagated. Any other exception type propagates
    immediately.

    Args:
        transforms (List[BaseTransform]): Non-empty list of candidate
            transforms, ordered by preference.
    """
    def __init__(self, transforms: List[BaseTransform]) -> None:
        assert len(transforms) > 0
        self.transforms = transforms

    def forward(self, data: Data) -> Data:
        final = len(self.transforms) - 1
        for position, candidate in enumerate(self.transforms):
            try:
                result = candidate.forward(data)
            except ImportError:
                if position < final:
                    # A later candidate may still work; try it.
                    continue
                # No candidates left: surface the last failure.
                raise
            else:
                return result
        # Only reachable when `self.transforms` is empty.
        return data


@functional_transform('delaunay')
class Delaunay(BaseTransform):
    r"""Computes the delaunay triangulation of a set of points
    (functional name: :obj:`delaunay`).

    Inputs with fewer than four points are handled without running a
    triangulation backend:

    * fewer than two points: :obj:`data.edge_index` is set to an empty
      :obj:`[2, 0]` tensor,
    * exactly two points: :obj:`data.edge_index` connects them in both
      directions,
    * exactly three points: :obj:`data.face` holds the single triangle
      :obj:`(0, 1, 2)`.

    For four or more points, :obj:`data.face` is filled by the first
    available backend.

    .. hint::
        Consider installing the
        `torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package
        to speed up computation.
    """
    def __init__(self) -> None:
        # Backends in order of preference: the sweep-hull implementation is
        # tried first, with the Qhull implementation as the fallback.
        self._transform = _SequentialTransform([
            _ShullTransform(),
            _QhullTransform(),
        ])

    def forward(self, data: Data) -> Data:
        assert data.pos is not None
        device = data.pos.device
        num_points = data.pos.size(0)

        if num_points < 2:
            # Nothing to connect.
            data.edge_index = torch.empty(2, 0, dtype=torch.long,
                                          device=device)
        elif num_points == 2:
            # A single edge between the two points, in both directions.
            data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)
        elif num_points == 3:
            # Exactly one triangle, stored as a single column.
            data.face = torch.tensor([[0], [1], [2]], device=device)
        else:
            data = self._transform.forward(data)

        return data