Source code for torch_geometric.datasets.movie_lens

import os
import os.path as osp
from typing import Callable, List, Optional

import torch

from torch_geometric.data import (
    HeteroData,
    InMemoryDataset,
    download_url,
    extract_zip,
)


class MovieLens(InMemoryDataset):
    r"""The MovieLens heterogeneous rating dataset, assembled by GroupLens
    Research from the `MovieLens web site <https://movielens.org>`_.

    The graph contains two node types, :obj:`"movie"` and :obj:`"user"`.
    The ratings that users gave to movies are exposed as ground-truth labels
    on the :obj:`("user", "rates", "movie")` edges.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes
            in an :obj:`torch_geometric.data.HeteroData` object and returns
            a transformed version.
            The data object will be transformed before being saved to disk.
            (default: :obj:`None`)
        model_name (str, optional): Name of the model used to turn movie
            titles into node features. The model comes from the
            `Huggingface SentenceTransformer
            <https://huggingface.co/sentence-transformers>`_.
            (default: :obj:`"all-MiniLM-L6-v2"`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """

    url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        model_name: Optional[str] = 'all-MiniLM-L6-v2',
        force_reload: bool = False,
    ) -> None:
        # Assigned before the base constructor runs because
        # `processed_file_names` and `process` both read it (presumably the
        # base class consults them while initializing -- confirm there).
        self.model_name = model_name
        super().__init__(
            root,
            transform,
            pre_transform,
            force_reload=force_reload,
        )
        processed_path = self.processed_paths[0]
        self.load(processed_path, data_cls=HeteroData)

    @property
    def raw_file_names(self) -> List[str]:
        """Relative paths of the raw CSV files inside the raw directory."""
        folder = 'ml-latest-small'
        return [osp.join(folder, name) for name in ('movies.csv', 'ratings.csv')]

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file; it is specific to the chosen model."""
        file_name = f'data_{self.model_name}.pt'
        return file_name

    def download(self) -> None:
        """Fetch the zip archive, unpack it into the raw directory and
        delete the archive afterwards.
        """
        archive = download_url(self.url, self.raw_dir)
        extract_zip(archive, self.raw_dir)
        os.remove(archive)

    def process(self) -> None:
        """Build the :class:`HeteroData` graph from the raw CSV files and
        save it to the processed path.
        """
        # Imported lazily so that they are only needed when processing.
        import pandas as pd
        from sentence_transformers import SentenceTransformer

        movies_path, ratings_path = self.raw_paths[0], self.raw_paths[1]
        data = HeteroData()

        # --- Movie nodes ----------------------------------------------------
        movies = pd.read_csv(movies_path, index_col='movieId')
        # Raw movie id -> consecutive node index.
        movie_index = dict(zip(movies.index, range(len(movies.index))))

        # One indicator column per genre ('|' separates genres in the file).
        genre_flags = movies['genres'].str.get_dummies('|').values
        genre_feats = torch.from_numpy(genre_flags).to(torch.float)

        encoder = SentenceTransformer(self.model_name)
        with torch.no_grad():
            title_feats = encoder.encode(
                movies['title'].values,
                show_progress_bar=True,
                convert_to_tensor=True,
            )
            title_feats = title_feats.cpu()

        # Node features: title embedding followed by the genre indicators.
        movie_store = data['movie']
        movie_store.x = torch.cat([title_feats, genre_feats], dim=-1)

        # --- User nodes and rating edges --------------------------------------
        ratings = pd.read_csv(ratings_path)
        # Raw user id -> consecutive node index, in order of first appearance.
        unique_users = ratings['userId'].unique()
        user_index = dict(zip(unique_users, range(len(unique_users))))

        user_store = data['user']
        user_store.num_nodes = len(user_index)

        sources = [user_index[raw_id] for raw_id in ratings['userId']]
        targets = [movie_index[raw_id] for raw_id in ratings['movieId']]

        # NOTE(review): the cast to long drops any fractional part of a
        # rating -- confirm that the ratings are integral or that truncation
        # is intended.
        labels = torch.from_numpy(ratings['rating'].values).to(torch.long)
        stamps = torch.from_numpy(ratings['timestamp'].values).to(torch.long)

        rates_store = data['user', 'rates', 'movie']
        rates_store.edge_index = torch.tensor([sources, targets])
        rates_store.edge_label = labels
        rates_store.time = stamps

        hook = self.pre_transform
        data = data if hook is None else hook(data)

        self.save([data], self.processed_paths[0])