# Source code for torch_geometric.datasets.willow_object_class

import glob
import os
import os.path as osp
from typing import Callable, List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import fs


class WILLOWObjectClass(InMemoryDataset):
    r"""The WILLOW-ObjectClass dataset from the `"Learning Graphs to Match"
    <https://www.di.ens.fr/willow/pdfscurrent/cho2013.pdf>`_ paper,
    containing 10 equal keypoints of at least 40 images in each category.
    The keypoints contain interpolated features from a pre-trained VGG16
    model on ImageNet (:obj:`relu4_2` and :obj:`relu5_1`).

    Args:
        root (str): Root directory where the dataset should be saved.
        category (str): The category of the images (one of :obj:`"Car"`,
            :obj:`"Duck"`, :obj:`"Face"`, :obj:`"Motorbike"`,
            :obj:`"Winebottle"`). The name is matched case-insensitively.
        transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes
            in an :obj:`torch_geometric.data.Data` object and returns a
            transformed version.
            The data object will be transformed before being saved to disk.
            (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        device (str or torch.device, optional): The device to use for
            processing the raw data. If set to :obj:`None`, will utilize
            GPU-processing if available. (default: :obj:`None`)
    """
    url = ('http://www.di.ens.fr/willow/research/graphlearning/'
           'WILLOW-ObjectClass_dataset.zip')

    # Lower-case names of the supported object categories.
    categories = ['face', 'motorbike', 'car', 'duck', 'winebottle']

    # Number of images pushed through VGG16 per forward pass in `process`.
    batch_size = 32

    def __init__(
        self,
        root: str,
        category: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        device: Optional[str] = None,
    ) -> None:
        # Fall back to CUDA when it is available and to the CPU otherwise.
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            else:
                device = 'cpu'

        assert category.lower() in self.categories

        # Both attributes are read by the directory properties and by
        # `process`, so they are set before the parent constructor runs.
        self.category = category
        self.device = device

        super().__init__(
            root,
            transform,
            pre_transform,
            pre_filter,
            force_reload=force_reload,
        )
        self.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Directory holding the raw data, shared by all categories."""
        raw_path = osp.join(self.root, 'raw')
        return raw_path

    @property
    def processed_dir(self) -> str:
        """Category-specific directory holding the processed data."""
        category_dir = self.category.capitalize()
        return osp.join(self.root, category_dir, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        """Capitalized names of all categories."""
        return list(map(str.capitalize, self.categories))

    @property
    def processed_file_names(self) -> str:
        """Name of the single processed file."""
        return 'data.pt'

    def download(self) -> None:
        """Download and unpack the archive, then move it into `raw_dir`."""
        archive = download_url(self.url, self.root)
        extract_zip(archive, self.root)

        # Remove the archive itself and the auxiliary files next to it.
        os.unlink(archive)
        for leftover in ('README', 'demo_showAnno.m'):
            os.unlink(osp.join(self.root, leftover))

        # Replace the current raw directory with the extracted folder.
        fs.rm(self.raw_dir)
        extracted = osp.join(self.root, 'WILLOW-ObjectClass')
        os.rename(extracted, self.raw_dir)

    def process(self) -> None:
        """Build one `Data` object per image of the selected category.

        Keypoints are rescaled to a 256x256 image, and each keypoint gets
        the concatenated VGG16 activations (upsampled to 256x256) found at
        its rounded position as node features. The result is filtered,
        transformed and saved to the first processed path.
        """
        import torchvision.models as models
        import torchvision.transforms as T
        from PIL import Image
        from scipy.io import loadmat

        category_dir = self.category.capitalize()
        png_paths = glob.glob(osp.join(self.raw_dir, category_dir, '*.png'))
        # Strip the ".png" suffix; the stem is shared with the ".mat" file.
        stems = sorted(png_path[:-4] for png_path in png_paths)

        captured = []

        def capture(module: torch.nn.Module, inputs: Tensor,
                    output: Tensor) -> None:
            # Collect the hooked activations on the CPU.
            captured.append(output.to('cpu'))

        vgg16 = models.vgg16(pretrained=True).to(self.device)
        vgg16.eval()
        # Layer 20 is relu4_2 and layer 25 is relu5_1.
        for layer_index in (20, 25):
            vgg16.features[layer_index].register_forward_hook(capture)

        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        transform = T.Compose([T.ToTensor(), normalize])

        data_list = []
        for stem in stems:
            coords = loadmat(f'{stem}.mat')['pts_coord']
            xs, ys = torch.from_numpy(coords).to(torch.float)
            pos = torch.stack([xs, ys], dim=1)

            # The "face" category contains a single image with less than 10
            # keypoints, so we need to skip it.
            if pos.size(0) != 10:
                continue

            with open(f'{stem}.png', 'rb') as f:
                img = Image.open(f).convert('RGB')

            # Rescale keypoints to the 256x256 resolution used below.
            width, height = img.size[0], img.size[1]
            pos[:, 0] = pos[:, 0] * 256.0 / width
            pos[:, 1] = pos[:, 1] * 256.0 / height

            img = img.resize((256, 256), resample=Image.Resampling.BICUBIC)
            img = transform(img)

            data_list.append(Data(img=img, pos=pos, name=stem))

        loader = DataLoader(
            dataset=[sample.img for sample in data_list],  # type: ignore
            batch_size=self.batch_size,
            shuffle=False,
        )
        for batch_index, batch_img in enumerate(loader):
            # Drop the activations recorded for the previous batch.
            captured.clear()

            with torch.no_grad():
                vgg16(batch_img.to(self.device))

            out1 = F.interpolate(captured[0], (256, 256), mode='bilinear',
                                 align_corners=False)
            out2 = F.interpolate(captured[1], (256, 256), mode='bilinear',
                                 align_corners=False)

            # The loader is not shuffled, so batch entry `j` corresponds to
            # the sample at `first + j` in `data_list`.
            first = batch_index * self.batch_size
            batch_samples = data_list[first:first + out1.size(0)]
            for j, sample in enumerate(batch_samples):
                assert sample.pos is not None
                idx = sample.pos.round().long().clamp(0, 255)
                rows, cols = idx[:, 1], idx[:, 0]
                feats = [
                    out[j, :, rows, cols].to('cpu').t()
                    for out in (out1, out2)
                ]
                # The image was only needed to compute the features.
                sample.img = None
                sample.x = torch.cat(feats, dim=-1)

            del out1, out2

        if self.pre_filter is not None:
            data_list = [
                sample for sample in data_list if self.pre_filter(sample)
            ]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(sample) for sample in data_list]

        self.save(data_list, self.processed_paths[0])

    def __repr__(self) -> str:
        """Return the class name, dataset length and category."""
        cls_name = self.__class__.__name__
        return f'{cls_name}({len(self)}, category={self.category})'