# Source code for torch_geometric.datasets.shapenet

import json
import os
import os.path as osp
import shutil

import torch

from torch_geometric.data import (Data, InMemoryDataset, download_url,
                                  extract_zip)
from torch_geometric.io import read_txt_array

class ShapeNet(InMemoryDataset):
    r"""The ShapeNet part level segmentation dataset from the `"A Scalable
    Active Framework for Region Annotation in 3D Shape Collections"
    <http://web.stanford.edu/~ericyi/papers/part_annotation_16_small.pdf>`_
    paper, containing about 17,000 3D shape point clouds from 16 shape
    categories.
    Each category is annotated with 2 to 6 parts.

    Args:
        root (string): Root directory where the dataset should be saved.
        categories (string or [string], optional): The category of the CAD
            models (one or a combination of :obj:`"Airplane"`, :obj:`"Bag"`,
            :obj:`"Cap"`, :obj:`"Car"`, :obj:`"Chair"`, :obj:`"Earphone"`,
            :obj:`"Guitar"`, :obj:`"Knife"`, :obj:`"Lamp"`, :obj:`"Laptop"`,
            :obj:`"Motorbike"`, :obj:`"Mug"`, :obj:`"Pistol"`,
            :obj:`"Rocket"`, :obj:`"Skateboard"`, :obj:`"Table"`).
            Can be explicitly set to :obj:`None` to load all categories.
            (default: :obj:`None`)
        include_normals (bool, optional): If set to :obj:`False`, will not
            include normal vectors as input features. (default: :obj:`True`)
        split (string, optional): If :obj:`"train"`, loads the training
            dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"trainval"`, loads the training and validation dataset.
            If :obj:`"test"`, loads the test dataset.
            (default: :obj:`"trainval"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
    """

    # NOTE(review): the URL was lost in extraction; this is the upstream
    # ShapeNet part-annotation benchmark archive used by this dataset —
    # confirm against the released torch_geometric version in use.
    url = ('https://shapenet.cs.stanford.edu/media/'
           'shapenetcore_partanno_segmentation_benchmark_v0_normal.zip')

    # Human-readable category name -> ShapeNet synset directory id.
    category_ids = {
        'Airplane': '02691156',
        'Bag': '02773838',
        'Cap': '02954340',
        'Car': '02958343',
        'Chair': '03001627',
        'Earphone': '03261776',
        'Guitar': '03467517',
        'Knife': '03624134',
        'Lamp': '03636649',
        'Laptop': '03642806',
        'Motorbike': '03790512',
        'Mug': '03797390',
        'Pistol': '03948459',
        'Rocket': '04099429',
        'Skateboard': '04225987',
        'Table': '04379243',
    }

    # Category name -> list of global part-segmentation label indices
    # (50 labels in total across all 16 categories).
    seg_classes = {
        'Airplane': [0, 1, 2, 3],
        'Bag': [4, 5],
        'Cap': [6, 7],
        'Car': [8, 9, 10, 11],
        'Chair': [12, 13, 14, 15],
        'Earphone': [16, 17, 18],
        'Guitar': [19, 20, 21],
        'Knife': [22, 23],
        'Lamp': [24, 25, 26, 27],
        'Laptop': [28, 29],
        'Motorbike': [30, 31, 32, 33, 34, 35],
        'Mug': [36, 37],
        'Pistol': [38, 39, 40],
        'Rocket': [41, 42, 43],
        'Skateboard': [44, 45, 46],
        'Table': [47, 48, 49],
    }

    def __init__(self, root, categories=None, include_normals=True,
                 split='trainval', transform=None, pre_transform=None,
                 pre_filter=None):
        if categories is None:
            categories = list(self.category_ids.keys())
        if isinstance(categories, str):
            categories = [categories]
        assert all(category in self.category_ids for category in categories)
        self.categories = categories
        super(ShapeNet, self).__init__(root, transform, pre_transform,
                                       pre_filter)

        if split == 'train':
            path = self.processed_paths[0]
        elif split == 'val':
            path = self.processed_paths[1]
        elif split == 'test':
            path = self.processed_paths[2]
        elif split == 'trainval':
            path = self.processed_paths[3]
        else:
            raise ValueError((f'Split {split} found, but expected either '
                              'train, val, trainval or test'))

        self.data, self.slices = torch.load(path)
        # Normals are stored as node features `x`; drop them on request.
        self.data.x = self.data.x if include_normals else None

        # Boolean mask of shape [num_categories, 50] marking which global
        # part labels belong to each category.
        self.y_mask = torch.zeros((len(self.seg_classes.keys()), 50),
                                  dtype=torch.bool)
        for i, labels in enumerate(self.seg_classes.values()):
            self.y_mask[i, labels] = 1

    @property
    def num_classes(self):
        # Total number of part labels (second dimension of the mask).
        return self.y_mask.size(-1)

    @property
    def raw_file_names(self):
        # One directory per synset id, plus the official split file lists.
        return list(self.category_ids.values()) + ['train_test_split']

    @property
    def processed_file_names(self):
        # Processed file names encode the selected categories (first three
        # lowercase letters of each) so different selections do not collide.
        cats = '_'.join([cat[:3].lower() for cat in self.categories])
        return [
            '{}_{}.pt'.format(cats, split)
            for split in ['train', 'val', 'test', 'trainval']
        ]

    def download(self):
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        # Replace any stale raw dir with the freshly extracted archive.
        shutil.rmtree(self.raw_dir)
        name = self.url.split('/')[-1].split('.')[0]
        os.rename(osp.join(self.root, name), self.raw_dir)

    def process_filenames(self, filenames):
        """Read the given raw `.txt` files into a list of `Data` objects,
        keeping only files belonging to the selected categories."""
        data_list = []
        categories_ids = [self.category_ids[cat] for cat in self.categories]
        cat_idx = {categories_ids[i]: i for i in range(len(categories_ids))}

        for name in filenames:
            cat = name.split(osp.sep)[0]
            if cat not in categories_ids:
                continue

            # Each row: x, y, z, nx, ny, nz, part_label.
            data = read_txt_array(osp.join(self.raw_dir, name))
            pos = data[:, :3]
            x = data[:, 3:6]  # Normal vectors as input features.
            y = data[:, -1].type(torch.long)
            data = Data(pos=pos, x=x, y=y, category=cat_idx[cat])

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        return data_list

    def process(self):
        trainval = []
        for i, split in enumerate(['train', 'val', 'test']):
            path = osp.join(self.raw_dir, 'train_test_split',
                            f'shuffled_{split}_file_list.json')
            with open(path, 'r') as f:
                filenames = [
                    osp.sep.join(name.split('/')[1:]) + '.txt'
                    for name in json.load(f)
                ]  # Removing first directory.
            data_list = self.process_filenames(filenames)
            if split == 'train' or split == 'val':
                trainval += data_list
            torch.save(self.collate(data_list), self.processed_paths[i])
        # The 'trainval' split is the concatenation of train and val.
        torch.save(self.collate(trainval), self.processed_paths[3])

    def __repr__(self):
        return '{}({}, categories={})'.format(self.__class__.__name__,
                                              len(self), self.categories)