# Source code for torch_geometric.datasets.coma

import os.path as osp
from glob import glob
from typing import Callable, List, Optional

import torch

from torch_geometric.data import InMemoryDataset, extract_zip
from torch_geometric.io import read_ply

class CoMA(InMemoryDataset):
    r"""The CoMA 3D faces dataset from the `"Generating 3D faces using
    Convolutional Mesh Autoencoders" <https://arxiv.org/abs/1807.10267>`_
    paper, containing 20,466 meshes of extreme expressions captured over 12
    different subjects.

    .. note::

        Data objects hold mesh faces instead of edge indices.
        To convert the mesh to a graph, use the
        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
        To convert the mesh to a point cloud, use the
        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
        sample a fixed number of points on the mesh faces according to their
        face area.

    Args:
        root (str): Root directory where the dataset should be saved.
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - 20,465
          - 5,023
          - 29,990
          - 3
          - 12
    """

    # Landing page from which the raw archive has to be fetched manually.
    url = 'https://coma.is.tue.mpg.de/'

    # Expression categories; the position in this list is used as the class
    # label (`data.y`) of every mesh found in the matching sub-folder.
    categories = [
        'bareteeth',
        'cheeks_in',
        'eyebrow',
        'high_smile',
        'lips_back',
        'lips_up',
        'mouth_down',
        'mouth_extreme',
        'mouth_middle',
        'mouth_open',
        'mouth_side',
        'mouth_up',
    ]

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        # processed_paths[0] holds the training split, [1] the test split.
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.load(path)

    @property
    def raw_file_names(self) -> str:
        """Name of the archive expected inside :obj:`self.raw_dir`."""
        return 'COMA_data.zip'

    @property
    def processed_file_names(self) -> List[str]:
        """File names of the processed training and test splits (in order)."""
        return ['training.pt', 'test.pt']

    def download(self) -> None:
        """Always raise: the archive has to be downloaded by hand.

        Raises:
            RuntimeError: Unconditionally, pointing the user at
                :obj:`self.url` and :obj:`self.raw_dir`.
        """
        raise RuntimeError(
            f"Dataset not found. Please download '{self.raw_file_names}' from "
            f"'{self.url}' and move it to '{self.raw_dir}'")

    def process(self) -> None:
        """Read all ``.ply`` meshes, split them and save both splits.

        Subject folders (``FaceTalk_*``) are looked up in
        :obj:`self.raw_dir`; if none exist yet the raw archive is extracted
        first. Within every ``<subject>/<category>`` folder the sorted files
        are split by position: indices with ``j % 100 < 90`` go to the
        training list, the remaining ones to the test list.
        """
        pattern = osp.join(self.raw_dir, 'FaceTalk_*')
        folders = sorted(glob(pattern))
        if not folders:
            # Nothing extracted yet: unpack the archive and look again.
            extract_zip(self.raw_paths[0], self.raw_dir, log=False)
            folders = sorted(glob(pattern))

        train_data_list, test_data_list = [], []
        for folder in folders:
            for i, category in enumerate(self.categories):
                files = sorted(glob(osp.join(folder, category, '*.ply')))
                for j, f in enumerate(files):
                    data = read_ply(f)
                    data.y = torch.tensor([i], dtype=torch.long)

                    # Filtering happens before the pre-transform; a filtered
                    # mesh still consumes its index `j` for the split rule.
                    if (self.pre_filter is not None
                            and not self.pre_filter(data)):
                        continue
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)

                    if (j % 100) < 90:
                        train_data_list.append(data)
                    else:
                        test_data_list.append(data)

        # Persist each split to the path that `__init__` later loads from.
        self.save(train_data_list, self.processed_paths[0])
        self.save(test_data_list, self.processed_paths[1])