Source code for torch_geometric.datasets.dynamic_faust

from itertools import product

import h5py
import torch
from torch_geometric.data import InMemoryDataset, Data


[docs]class DynamicFAUST(InMemoryDataset): r"""The dynamic FAUST humans dataset from the `"Dynamic FAUST: Registering Human Bodies in Motion" <http://files.is.tue.mpg.de/black/papers/dfaust2017.pdf>`_ paper. .. note:: Data objects hold mesh faces instead of edge indices. To convert the mesh to a graph, use the :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`. To convert the mesh to a point cloud, use the :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to sample a fixed number of points on the mesh faces according to their face area. Args: root (string): Root directory where the dataset should be saved. subjects (list, optional): List of subjects to include in the dataset. Can include the subjects :obj:`"50002"`, :obj:`"50004"`, :obj:`"50007"`, :obj:`"50009"`, :obj:`"50020"`, :obj:`"50021"`, :obj:`"50022"`, :obj:`"50025"`, :obj:`"50026"`, :obj:`"50027"`. If set to :obj:`None`, the dataset will contain all subjects. (default: :obj:`None`) categories (list, optional): List of categories to include in the dataset. Can include the categories :obj:`"chicken_wings"`, :obj:`"hips"`, :obj:`"jiggle_on_toes"`, :obj:`"jumping_jacks"`, :obj:`"knees"`, :obj:`"light_hopping_loose"`, :obj:`"light_hopping_stiff"`, :obj:`"one_leg_jump"`, :obj:`"one_leg_loose"`, :obj:`"personal_move"`, :obj:`"punching"`, :obj:`"running_on_spot"`, :obj:`"running_on_spot_bugfix"`, :obj:`"shake_arms"`, :obj:`"shake_hips"`, :obj:`"shoulders"`. If set to :obj:`None`, the dataset will contain all categories. (default: :obj:`None`) transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) """ url = 'http://dfaust.is.tue.mpg.de/' subjects = [ '50002', '50004', '50007', '50009', '50020', '50021', '50022', '50025', '50026', '50027' ] categories = [ 'chicken_wings', 'hips', 'jiggle_on_toes', 'jumping_jacks', 'knees', 'light_hopping_loose', 'light_hopping_stiff', 'one_leg_jump', 'one_leg_loose', 'personal_move', 'punching', 'running_on_spot', 'running_on_spot_bugfix', 'shake_arms', 'shake_hips', 'shake_shoulders' ] def __init__(self, root, subjects=None, categories=None, transform=None, pre_transform=None, pre_filter=None): subjects = self.subjects if subjects is None else subjects subjects = [sid.lower() for sid in subjects] for sid in subjects: assert sid in self.subjects self.subjects = subjects categories = self.categories if categories is None else categories categories = [cat.lower() for cat in categories] for cat in categories: assert cat in self.categories self.categories = categories super(DynamicFAUST, self).__init__(root, transform, pre_transform, pre_filter) self.data, self.slices = torch.load(self.processed_paths[0]) @property def raw_file_names(self): return ['registrations_m.hdf5', 'registrations_f.hdf5'] @property def processed_file_names(self): sids = '_'.join([sid[-2:] for sid in self.subjects]) cats = '_'.join([ ''.join([w[0] for w in cat.split('_')]) for cat in self.categories ]) return '{}_{}.pt'.format(sids, cats) def download(self): raise RuntimeError( 'Dataset not found. Please download male registrations ' '(registrations_m.hdf5) and female registrations ' '(registrations_f.hdf5) from {} and ' 'move it to {}'.format(self.url, self.raw_dir)) def process(self): fm = h5py.File(self.raw_paths[0], 'r') ff = h5py.File(self.raw_paths[1], 'r') face = torch.from_numpy(fm['faces'][()]).to(torch.long) face = face.t().contiguous() data_list = [] for (sid, cat) in product(self.subjects, self.categories): idx = '{}_{}'.format(sid, cat) if idx in fm: pos = torch.from_numpy(fm[idx][()]) elif idx in ff: pos = torch.from_numpy(ff[idx][()]) else: continue pos = pos.permute(2, 0, 1).contiguous() data_list.append(Data(pos=pos, face=face)) if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] torch.save(self.collate(data_list), self.processed_paths[0])