Source code for torch_geometric.datasets.hm

from typing import Callable, List, Optional

import torch

from torch_geometric.data import HeteroData, InMemoryDataset


class HM(InMemoryDataset):
    r"""The heterogeneous H&M dataset from the `Kaggle H&M Personalized Fashion
    Recommendations
    <https://www.kaggle.com/competitions/h-and-m-personalized-fashion-recommendations>`_
    challenge.
    The task is to develop product recommendations based on data from previous
    transactions, as well as from customer and product meta data.

    Args:
        root (str): Root directory where the dataset should be saved.
        use_all_tables_as_node_types (bool, optional): If set to :obj:`True`,
            will use the transaction table as a distinct node type.
            (default: :obj:`False`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    url = ('https://www.kaggle.com/competitions/'
           'h-and-m-personalized-fashion-recommendations/data')

    def __init__(
        self,
        root: str,
        use_all_tables_as_node_types: bool = False,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # Must be set before `super().__init__()`: `processed_file_names`
        # reads this flag to pick the processed file.
        self.use_all_tables_as_node_types = use_all_tables_as_node_types
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0], data_cls=HeteroData)

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the raw files: customers, articles, transactions."""
        return [
            'customers.csv.zip',
            'articles.csv.zip',
            'transactions_train.csv.zip',
        ]

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file, which depends on the graph layout."""
        if self.use_all_tables_as_node_types:
            return 'data.pt'
        else:
            return 'data_merged.pt'

    def download(self) -> None:
        """Raise an error asking the user to fetch the raw files manually.

        Raises:
            RuntimeError: Always. The message names the expected raw files,
                the competition URL and the target directory.
        """
        raise RuntimeError(
            f"Dataset not found. Please download {self.raw_file_names} from "
            f"'{self.url}' and move it to '{self.raw_dir}'")

    def process(self) -> None:
        """Build the :obj:`HeteroData` graph from the raw CSV tables and save
        it to :obj:`self.processed_paths[0]`.

        Customers and articles always become node types with one-hot encoded
        features. Transactions become either their own node type, linked to
        customers and articles, or direct customer-article edges carrying
        the transaction features, depending on
        :obj:`use_all_tables_as_node_types`.
        """
        import pandas as pd

        data = HeteroData()

        # Process customer data ###############################################
        df = pd.read_csv(self.raw_paths[0], index_col='customer_id')
        # Map raw customer identifiers to consecutive node indices.
        customer_map = {idx: i for i, idx in enumerate(df.index)}

        xs = []
        for name in [
                'Active',
                'FN',
                'club_member_status',
                'fashion_news_frequency',
        ]:
            x = pd.get_dummies(df[name]).values
            xs.append(torch.from_numpy(x).to(torch.float))

        x = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1)
        # Impute missing ages with the mean of the observed ones. The mean is
        # passed as a Python float, which is what `nan_to_num` documents.
        x = x.nan_to_num(nan=float(x.nanmean()))
        # Scale ages by the maximum age.
        xs.append(x / x.max())

        data['customer'].x = torch.cat(xs, dim=-1)

        # Process article data ################################################
        df = pd.read_csv(self.raw_paths[1], index_col='article_id')
        # Map raw article identifiers to consecutive node indices.
        article_map = {idx: i for i, idx in enumerate(df.index)}

        xs = []
        for name in [  # We drop a few columns here that are high cardinality.
                # 'product_code',  # Drop.
                # 'prod_name',  # Drop.
                'product_type_no',
                'product_type_name',
                'product_group_name',
                'graphical_appearance_no',
                'graphical_appearance_name',
                'colour_group_code',
                'colour_group_name',
                'perceived_colour_value_id',
                'perceived_colour_value_name',
                'perceived_colour_master_id',
                'perceived_colour_master_name',
                # 'department_no',  # Drop.
                # 'department_name',  # Drop.
                'index_code',
                'index_name',
                'index_group_no',
                'index_group_name',
                'section_no',
                'section_name',
                'garment_group_no',
                'garment_group_name',
                # 'detail_desc',  # Drop.
        ]:
            x = pd.get_dummies(df[name]).values
            xs.append(torch.from_numpy(x).to(torch.float))

        data['article'].x = torch.cat(xs, dim=-1)

        # Process transaction data ############################################
        df = pd.read_csv(self.raw_paths[2], parse_dates=['t_dat'])

        # Transaction features: one-hot sales channel followed by the price.
        x1 = pd.get_dummies(df['sales_channel_id']).values
        x1 = torch.from_numpy(x1).to(torch.float)
        x2 = torch.from_numpy(df['price'].values).to(torch.float).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)

        # Convert timestamps to whole days since the Unix epoch. Casting via
        # `datetime64[D]` does not depend on the resolution pandas parsed the
        # column with, and the explicit `int64` avoids the platform-dependent
        # width of NumPy's default `int`.
        days = df['t_dat'].values.astype('datetime64[D]').astype('int64')
        time = torch.from_numpy(days)

        src = torch.tensor([customer_map[idx] for idx in df['customer_id']])
        dst = torch.tensor([article_map[idx] for idx in df['article_id']])

        if self.use_all_tables_as_node_types:
            # Every transaction row becomes a node, connected to its customer
            # and its article in both directions.
            data['transaction'].x = x
            data['transaction'].time = time

            edge_index = torch.stack([src, torch.arange(len(df))], dim=0)
            data['customer', 'to', 'transaction'].edge_index = edge_index
            edge_index = edge_index.flip([0])
            data['transaction', 'rev_to', 'customer'].edge_index = edge_index

            edge_index = torch.stack([dst, torch.arange(len(df))], dim=0)
            data['article', 'to', 'transaction'].edge_index = edge_index
            edge_index = edge_index.flip([0])
            data['transaction', 'rev_to', 'article'].edge_index = edge_index
        else:
            # Every transaction row becomes a customer-article edge (plus its
            # reverse) carrying the timestamp and the transaction features.
            edge_index = torch.stack([src, dst], dim=0)
            data['customer', 'to', 'article'].edge_index = edge_index
            data['customer', 'to', 'article'].time = time
            data['customer', 'to', 'article'].edge_attr = x

            edge_index = edge_index.flip([0])
            data['article', 'rev_to', 'customer'].edge_index = edge_index
            data['article', 'rev_to', 'customer'].time = time
            data['article', 'rev_to', 'customer'].edge_attr = x

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])