# Source code for torch_geometric.datasets.shapenet

import os
import os.path as osp
import glob

import torch
from torch_geometric.data import (Data, InMemoryDataset, download_url,
                                  extract_zip)
from torch_geometric.read import read_txt_array


class ShapeNet(InMemoryDataset):
    r"""The ShapeNet part level segmentation dataset from the `"A Scalable
    Active Framework for Region Annotation in 3D Shape Collections"
    <http://web.stanford.edu/~ericyi/papers/part_annotation_16_small.pdf>`_
    paper, containing about 17,000 3D shape point clouds from 16 shape
    categories.
    Each category is annotated with 2 to 6 parts.

    Args:
        root (string): Root directory where the dataset should be saved.
        categories (string or [string], optional): The category of the CAD
            models (one or a combination of :obj:`"Airplane"`, :obj:`"Bag"`,
            :obj:`"Cap"`, :obj:`"Car"`, :obj:`"Chair"`, :obj:`"Earphone"`,
            :obj:`"Guitar"`, :obj:`"Knife"`, :obj:`"Lamp"`, :obj:`"Laptop"`,
            :obj:`"Motorbike"`, :obj:`"Mug"`, :obj:`"Pistol"`,
            :obj:`"Rocket"`, :obj:`"Skateboard"`, :obj:`"Table"`).
            Can be explicitly set to :obj:`None` to load all categories.
            (default: :obj:`None`)
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
    """

    url = 'https://shapenet.cs.stanford.edu/iccv17/partseg'

    # Maps human-readable category names to ShapeNet synset directory ids.
    category_ids = {
        'Airplane': '02691156',
        'Bag': '02773838',
        'Cap': '02954340',
        'Car': '02958343',
        'Chair': '03001627',
        'Earphone': '03261776',
        'Guitar': '03467517',
        'Knife': '03624134',
        'Lamp': '03636649',
        'Laptop': '03642806',
        'Motorbike': '03790512',
        'Mug': '03797390',
        'Pistol': '03948459',
        'Rocket': '04099429',
        'Skateboard': '04225987',
        'Table': '04379243',
    }

    def __init__(self, root, categories=None, train=True, transform=None,
                 pre_transform=None, pre_filter=None):
        if categories is None:
            categories = list(self.category_ids.keys())
        if isinstance(categories, str):
            categories = [categories]
        assert all(category in self.category_ids for category in categories)
        self.categories = categories
        super(ShapeNet, self).__init__(root, transform, pre_transform,
                                       pre_filter)
        # Index 0 holds the (train + val) split, index 1 the test split.
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)
        # Per-category mask over the global part-label space (see
        # :meth:`process_raw_path`).
        self.y_mask = torch.load(self.processed_paths[2])

    @property
    def raw_file_names(self):
        return [
            'train_data', 'train_label', 'val_data', 'val_label', 'test_data',
            'test_label'
        ]

    @property
    def processed_file_names(self):
        # Encode the selected categories into the file names so that datasets
        # built from different category subsets do not collide on disk.
        cats = '_'.join([cat[:3].lower() for cat in self.categories])
        return [
            '{}_{}.pt'.format(cats, s) for s in ['training', 'test', 'y_mask']
        ]

    def download(self):
        for name in self.raw_file_names:
            url = '{}/{}.zip'.format(self.url, name)
            path = download_url(url, self.raw_dir)
            extract_zip(path, self.raw_dir)
            os.unlink(path)

    def process_raw_path(self, data_path, label_path):
        r"""Reads one raw split and returns ``(data_list, y_mask)``, where
        part labels of all selected categories are re-indexed into one
        contiguous global label space and ``y_mask[i, j]`` marks whether
        global label ``j`` belongs to category ``i``."""
        y_offset = 0
        data_list = []
        cat_ys = []
        for cat_idx, cat in enumerate(self.categories):
            idx = self.category_ids[cat]
            point_paths = sorted(glob.glob(osp.join(data_path, idx, '*.pts')))
            y_paths = sorted(glob.glob(osp.join(label_path, idx, '*.seg')))

            points = [read_txt_array(path) for path in point_paths]
            ys = [read_txt_array(path, dtype=torch.long) for path in y_paths]
            lens = [y.size(0) for y in ys]

            # Re-index this category's labels contiguously via the inverse of
            # `unique`, then shift by the running offset so labels of
            # different categories never overlap.
            y = torch.cat(ys).unique(return_inverse=True)[1] + y_offset
            cat_ys.append(y.unique())
            y_offset = y.max().item() + 1
            ys = y.split(lens)

            for (pos, y) in zip(points, ys):
                data = Data(y=y, pos=pos, category=cat_idx)
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)

        y_mask = torch.zeros((len(self.categories), y_offset),
                             dtype=torch.uint8)
        for i in range(len(cat_ys)):
            y_mask[i, cat_ys[i]] = 1

        return data_list, y_mask

    def process(self):
        train_data_list, y_mask = self.process_raw_path(*self.raw_paths[0:2])
        val_data_list, _ = self.process_raw_path(*self.raw_paths[2:4])
        test_data_list, _ = self.process_raw_path(*self.raw_paths[4:6])

        # Train and validation splits are merged into a single training set.
        data = self.collate(train_data_list + val_data_list)
        torch.save(data, self.processed_paths[0])
        torch.save(self.collate(test_data_list), self.processed_paths[1])
        # The mask from the training split defines the global label space;
        # the masks of the other splits are discarded.
        torch.save(y_mask, self.processed_paths[2])

    def __repr__(self):
        return '{}({}, categories={})'.format(self.__class__.__name__,
                                              len(self), self.categories)