import torch
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_undirected
[docs]@functional_transform('face_to_edge')
class FaceToEdge(BaseTransform):
r"""Converts mesh faces :obj:`[3, num_faces]` to edge indices
:obj:`[2, num_edges]` (functional name: :obj:`face_to_edge`).
Args:
remove_faces (bool, optional): If set to :obj:`False`, the face tensor
will not be removed.
"""
def __init__(self, remove_faces: bool = True):
self.remove_faces = remove_faces
def forward(self, data: Data) -> Data:
if hasattr(data, 'face'):
face = data.face
edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)
data.edge_index = edge_index
if self.remove_faces:
data.face = None
return data