Source code for torch_geometric.utils.grid

from typing import Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.utils.coalesce import coalesce


[docs]def grid( height: int, width: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tuple[Tensor, Tensor]: r"""Returns the edge indices of a two-dimensional grid graph with height :attr:`height` and width :attr:`width` and its node positions. Args: height (int): The height of the grid. width (int): The width of the grid. dtype (:obj:`torch.dtype`, optional): The desired data type of the returned position tensor. dtype (:obj:`torch.device`, optional): The desired device of the returned tensors. :rtype: (:class:`LongTensor`, :class:`Tensor`) Example: >>> (row, col), pos = grid(height=2, width=2) >>> row tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) >>> col tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]) >>> pos tensor([[0., 1.], [1., 1.], [0., 0.], [1., 0.]]) """ edge_index = grid_index(height, width, device) pos = grid_pos(height, width, dtype, device) return edge_index, pos
def grid_index( height: int, width: int, device: Optional[torch.device] = None, ) -> Tensor: w = width kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1] kernel = torch.tensor(kernel, device=device) row = torch.arange(height * width, dtype=torch.long, device=device) row = row.view(-1, 1).repeat(1, kernel.size(0)) col = row + kernel.view(1, -1) row, col = row.view(height, -1), col.view(height, -1) index = torch.arange(3, row.size(1) - 3, dtype=torch.long, device=device) row, col = row[:, index].view(-1), col[:, index].view(-1) mask = (col >= 0) & (col < height * width) row, col = row[mask], col[mask] edge_index = torch.stack([row, col], dim=0) edge_index = coalesce(edge_index, None, height * width) return edge_index def grid_pos( height: int, width: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Tensor: dtype = torch.float if dtype is None else dtype x = torch.arange(width, dtype=dtype, device=device) y = (height - 1) - torch.arange(height, dtype=dtype, device=device) x = x.repeat(height) y = y.unsqueeze(-1).repeat(1, width).view(-1) return torch.stack([x, y], dim=-1)