# Source code for torch_geometric.datasets.movie_lens_1m

import os
import os.path as osp
from typing import Callable, List, Optional

import torch

from torch_geometric.data import (
    HeteroData,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import fs

# Column names for the raw MovieLens 1M files. The files are read with
# `header=None`, so these lists are passed as `names=` to `pd.read_csv`
# and must match the on-disk column order.

# Columns of `movies.dat`.
MOVIE_HEADERS = [
    "movieId",
    "title",
    "genres",
]

# Columns of `users.dat`.
USER_HEADERS = [
    "userId",
    "gender",
    "age",
    "occupation",
    "zipCode",
]

# Columns of `ratings.dat`.
RATING_HEADERS = [
    "userId",
    "movieId",
    "rating",
    "timestamp",
]


class MovieLens1M(InMemoryDataset):
    r"""The MovieLens 1M heterogeneous rating dataset, assembled by GroupLens
    Research from the `MovieLens web site <https://movielens.org>`__,
    consisting of movies (3,883 nodes) and users (6,040 nodes) with
    approximately 1 million ratings between them.
    User ratings for movies are available as ground truth labels.
    Features of users and movies are encoded according to the `"Inductive
    Matrix Completion Based on Graph Neural Networks"
    <https://arxiv.org/abs/1904.12058>`__ paper.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 20 10 10 10
        :header-rows: 1

        * - Node/Edge Type
          - #nodes/#edges
          - #features
          - #tasks
        * - Movie
          - 3,883
          - 18
          -
        * - User
          - 6,040
          - 30
          -
        * - User-Movie
          - 1,000,209
          - 1
          - 1
    """
    # Location of the zipped raw dataset fetched by `download()`.
    url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        # Load the processed heterogeneous graph written by `process()`.
        self.load(self.processed_paths[0], data_cls=HeteroData)

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the raw files, in the order `process()` indexes them
        through `self.raw_paths` (movies, users, ratings).
        """
        return ['movies.dat', 'users.dat', 'ratings.dat']

    @property
    def processed_file_names(self) -> str:
        """Name of the single processed file holding the saved graph."""
        return 'data.pt'

    def download(self) -> None:
        """Fetch the zip archive from `url` and install it as `raw_dir`.

        The archive is downloaded and extracted into `self.root`, the zip
        file is deleted, and the extracted ``ml-1m`` folder is renamed to
        `self.raw_dir`.
        """
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.remove(path)
        folder = osp.join(self.root, 'ml-1m')
        # Clear `raw_dir` first so the rename below can take its place.
        fs.rm(self.raw_dir)
        os.rename(folder, self.raw_dir)

    def process(self) -> None:
        """Build the `HeteroData` graph from the raw files and save it.

        Movies get their ``'|'``-separated genres encoded as indicator
        columns; users get the concatenation of indicator encodings of age,
        gender and occupation. Every rating becomes a
        ``('user', 'rates', 'movie')`` edge carrying `rating` and `time`
        attributes, mirrored as a ``('movie', 'rated_by', 'user')`` edge.
        The optional `pre_transform` is applied before saving.
        """
        # Imported here rather than at module level, so pandas is only
        # needed when this method actually runs.
        import pandas as pd

        data = HeteroData()

        # Process movie data:
        # The multi-character '::' separator is handled by pandas' Python
        # parsing engine, hence `engine='python'`.
        df = pd.read_csv(
            self.raw_paths[0],
            sep='::',
            header=None,
            index_col='movieId',
            names=MOVIE_HEADERS,
            encoding='ISO-8859-1',
            engine='python',
        )
        # Map raw movie ids to consecutive node indices (file order).
        movie_mapping = {idx: i for i, idx in enumerate(df.index)}

        genres = df['genres'].str.get_dummies('|').values
        genres = torch.from_numpy(genres).to(torch.float)

        data['movie'].x = genres

        # Process user data:
        df = pd.read_csv(
            self.raw_paths[1],
            sep='::',
            header=None,
            index_col='userId',
            names=USER_HEADERS,
            dtype='str',
            encoding='ISO-8859-1',
            engine='python',
        )
        # Map raw user ids to consecutive node indices (file order).
        # NOTE(review): these keys are later looked up with the `userId`
        # values parsed from ratings.dat (read without `dtype='str'`), so
        # this relies on the index here comparing equal to those values --
        # confirm against the pandas version in use.
        user_mapping = {idx: i for i, idx in enumerate(df.index)}

        age = df['age'].str.get_dummies().values
        age = torch.from_numpy(age).to(torch.float)

        gender = df['gender'].str.get_dummies().values
        gender = torch.from_numpy(gender).to(torch.float)

        occupation = df['occupation'].str.get_dummies().values
        occupation = torch.from_numpy(occupation).to(torch.float)

        data['user'].x = torch.cat([age, gender, occupation], dim=-1)

        # Process rating data:
        df = pd.read_csv(
            self.raw_paths[2],
            sep='::',
            header=None,
            names=RATING_HEADERS,
            encoding='ISO-8859-1',
            engine='python',
        )

        # Translate raw ids into node indices; row 0 holds users (source),
        # row 1 holds movies (destination).
        src = [user_mapping[idx] for idx in df['userId']]
        dst = [movie_mapping[idx] for idx in df['movieId']]
        edge_index = torch.tensor([src, dst])
        data['user', 'rates', 'movie'].edge_index = edge_index

        rating = torch.from_numpy(df['rating'].values).to(torch.long)
        data['user', 'rates', 'movie'].rating = rating

        time = torch.from_numpy(df['timestamp'].values)
        data['user', 'rates', 'movie'].time = time

        # Reverse relation: same edges with source and destination swapped,
        # sharing the same rating and time tensors.
        data['movie', 'rated_by', 'user'].edge_index = edge_index.flip([0])
        data['movie', 'rated_by', 'user'].rating = rating
        data['movie', 'rated_by', 'user'].time = time

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])