From 76d07307d97756ccdf63d067529ab5c58d3b9f67 Mon Sep 17 00:00:00 2001 From: Roman Legonkov Date: Sun, 22 Dec 2024 20:35:53 +0300 Subject: [PATCH 1/4] started work with custom model tutorial --- examples/10_custom_model_creation.ipynb | 487 ++++++++++++++++++++++++ 1 file changed, 487 insertions(+) create mode 100644 examples/10_custom_model_creation.ipynb diff --git a/examples/10_custom_model_creation.ipynb b/examples/10_custom_model_creation.ipynb new file mode 100644 index 00000000..a2c4faf8 --- /dev/null +++ b/examples/10_custom_model_creation.ipynb @@ -0,0 +1,487 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "19d09e64-aa80-47e1-9c8b-a5a24564bee7", + "metadata": {}, + "source": [ + "# Example of building custom model with ModelBase class\n", + "\n", + "- Building custom model\n", + "- Visual recommendations checking" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "dce07a5b-2716-41c5-8358-63f590dd69f0", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from rectools.models.base import ModelBase, ModelConfig\n", + "from rectools import Columns\n", + "from rectools.dataset import Dataset\n", + "from scipy.sparse import csr_matrix\n", + "from sklearn.neighbors import NearestNeighbors\n", + "import typing as tp\n", + "import typing_extensions as tpe\n", + "from rectools.models.base import InternalIdsArray\n", + "from rectools.types import *\n", + "from tqdm import tqdm\n", + "from rectools.models.base import Scores\n" + ] + }, + { + "cell_type": "markdown", + "id": "c05f581d-c60b-40c1-8e52-368e1d97ab36", + "metadata": {}, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7814e510-3179-44f0-b116-a30b670fa72e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Archive: ml-1m.zip\n", + " creating: ml-1m/\n", + " inflating: ml-1m/movies.dat \n", + " inflating: ml-1m/ratings.dat \n", + " 
inflating: ml-1m/README \n", + " inflating: ml-1m/users.dat \n", + "CPU times: user 48.4 ms, sys: 23.4 ms, total: 71.9 ms\n", + "Wall time: 5.12 s\n" + ] + } + ], + "source": [ + "%%time\n", + "!wget -q https://files.grouplens.org/datasets/movielens/ml-1m.zip -O ml-1m.zip\n", + "!unzip -o ml-1m.zip\n", + "!rm ml-1m.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "884ebb14-fcae-4bbc-9221-299b613489f0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1000209, 4)\n", + "CPU times: user 2.95 s, sys: 102 ms, total: 3.05 s\n", + "Wall time: 3.03 s\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idweightdatetime
0111935978300760
116613978302109
219143978301968
3134084978300275
4123555978824291
\n", + "
" + ], + "text/plain": [ + " user_id item_id weight datetime\n", + "0 1 1193 5 978300760\n", + "1 1 661 3 978302109\n", + "2 1 914 3 978301968\n", + "3 1 3408 4 978300275\n", + "4 1 2355 5 978824291" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "ratings = pd.read_csv(\n", + " \"ml-1m/ratings.dat\", \n", + " sep=\"::\",\n", + " engine=\"python\", # Because of 2-chars separators\n", + " header=None,\n", + " names=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],\n", + ")\n", + "print(ratings.shape)\n", + "ratings.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "332cd6dd-d993-47b9-baec-c601b89fba12", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(3883, 3)\n", + "CPU times: user 11.4 ms, sys: 81 μs, total: 11.4 ms\n", + "Wall time: 10.8 ms\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idtitlegenres
01Toy Story (1995)Animation|Children's|Comedy
12Jumanji (1995)Adventure|Children's|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama
45Father of the Bride Part II (1995)Comedy
\n", + "
" + ], + "text/plain": [ + " item_id title genres\n", + "0 1 Toy Story (1995) Animation|Children's|Comedy\n", + "1 2 Jumanji (1995) Adventure|Children's|Fantasy\n", + "2 3 Grumpier Old Men (1995) Comedy|Romance\n", + "3 4 Waiting to Exhale (1995) Comedy|Drama\n", + "4 5 Father of the Bride Part II (1995) Comedy" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "movies = pd.read_csv(\n", + " \"ml-1m/movies.dat\", \n", + " sep=\"::\",\n", + " engine=\"python\", # Because of 2-chars separators\n", + " header=None,\n", + " names=[Columns.Item, \"title\", \"genres\"],\n", + " encoding_errors=\"ignore\",\n", + ")\n", + "print(movies.shape)\n", + "movies.head()" + ] + }, + { + "cell_type": "markdown", + "id": "3795f0e2-ac3e-4c89-a901-2988522d0629", + "metadata": {}, + "source": [ + "## Build model" + ] + }, + { + "cell_type": "markdown", + "id": "d3f41d9e-b6aa-4a66-b681-2b504c80fbd0", + "metadata": {}, + "source": [ + "### Write a model config inherited from `ModelConfig`" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "95f01cdf-b126-4035-a5b6-867d45aeaa99", + "metadata": {}, + "outputs": [], + "source": [ + "class KnnModelConfig(ModelConfig):\n", + " \"\"\"Config for `KNN` model.\"\"\"\n", + "\n", + " # KNN algorithm hyperparams\n", + " metric: tp.Optional[str] = None\n", + " algorithm: tp.Optional[str] = None\n", + " n_neighbors: tp.Optional[int] = None\n", + " n_jobs: tp.Optional[int] = None" + ] + }, + { + "cell_type": "markdown", + "id": "9cf2011c-6b5a-4707-a640-333664cc85b5", + "metadata": {}, + "source": [ + "### Write a model logic in class inherited from `ModelBase`" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "4f2cb8cf-17c3-4f0e-a979-4411a35af283", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "class KnnModel(ModelBase[KnnModelConfig]):\n", + " \n", + " recommends_for_warm = True\n", + " recommends_for_cold = False\n", + 
" config_class = KnnModelConfig\n", + "\n", + " def __init__(self,\n", + " metric: tp.Optional[str] = None,\n", + " algorithm: tp.Optional[str] = None,\n", + " n_neighbors: tp.Optional[int] = None,\n", + " n_jobs: tp.Optional[int] = None,\n", + " verbose: int = 0):\n", + " super().__init__(verbose=verbose)\n", + " self.metric = metric\n", + " self.algorithm = algorithm\n", + " self.n_neighbors = n_neighbors\n", + " self.n_jobs = n_jobs\n", + " self.knn_model = NearestNeighbors(metric = self.metric,\n", + " algorithm = self.algorithm,\n", + " n_neighbors = self.n_neighbors,\n", + " n_jobs = self.n_jobs)\n", + " self.all_item_ids: np.ndarray\n", + " self.ui_csr: csr_matrix\n", + "\n", + " def _get_config(self) -> KnnModelConfig:\n", + " return KnnModelConfig(metric=self.metric, algorithm=self.algorithm, n_neighbors=self.n_neighbors, verbose=self.verbose)\n", + "\n", + " @classmethod\n", + " def _from_config(cls, config: KnnModelConfig) -> tpe.Self:\n", + " return cls(metric=config.metric, algorithm=config.algorithm, n_neighbors=config.n_neighbors, verbose=config.verbose)\n", + "\n", + " def _fit(self, dataset: Dataset) -> None: # type: ignore\n", + " self.all_item_ids = dataset.item_id_map.internal_ids\n", + " self.ui_csr = dataset.get_user_item_matrix(include_weights=False, dtype=np.float64)\n", + " self.knn_model.fit(self.ui_csr)\n", + "\n", + " def _recommend_i2i(\n", + " self,\n", + " target_ids: InternalIdsArray,\n", + " dataset: Dataset,\n", + " k: int,\n", + " sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],\n", + " ) -> tp.Tuple[InternalIds, InternalIds, Scores]:\n", + " sorted_item_ids_to_recommend = dataset.get_user_item_matrix(include_weights=False,\n", + " dtype=np.float64)[sorted_item_ids_to_recommend] if sorted_item_ids_to_recommend is not None else self.all_item_ids\n", + "\n", + " all_target_ids = []\n", + " all_reco_ids: tp.List[np.ndarray] = []\n", + " all_scores: tp.List[np.ndarray] = []\n", + " for target_id in tqdm(target_ids, 
disable=self.verbose == 0):\n", + " reco_scores, reco_ids = self.knn_model.kneighbors(self.ui_csr[target_id], n_neighbors = k + 1)\n", + " all_target_ids.extend([target_id] * len(reco_ids))\n", + " all_reco_ids.append(reco_ids)\n", + " all_scores.append(reco_scores)\n", + "\n", + " all_reco_ids_arr = np.concatenate(all_reco_ids)\n", + "\n", + " if sorted_item_ids_to_recommend is not None:\n", + " all_reco_ids_arr = sorted_item_ids_to_recommend[all_reco_ids_arr]\n", + "\n", + " return all_target_ids, all_reco_ids_arr, np.concatenate(all_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "e8757369-bec2-48f2-a876-a359d444d706", + "metadata": {}, + "outputs": [], + "source": [ + "model = KnnModel(metric=\"cosine\", algorithm=\"brute\", n_neighbors=20, n_jobs=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "53750dfe-9b46-4b8d-9bfe-e0c9caa41770", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset.construct(ratings)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "0c279b6f-2287-4198-b958-f00729fabadd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<__main__.KnnModel at 0x797862117af0>" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "5c373d0f-bf04-4256-81c6-35f1f124f118", + "metadata": {}, + "outputs": [ + { + "ename": "IndexError", + "evalue": "Index dimension must be 1 or 2", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[46], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m 
\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_recommend_i2i\u001b[49m\u001b[43m(\u001b[49m\u001b[43mInternalIdsArray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[40], line 53\u001b[0m, in \u001b[0;36mKnnModel._recommend_i2i\u001b[0;34m(self, target_ids, dataset, k, sorted_item_ids_to_recommend)\u001b[0m\n\u001b[1;32m 51\u001b[0m all_scores: tp\u001b[38;5;241m.\u001b[39mList[np\u001b[38;5;241m.\u001b[39mndarray] \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m target_id \u001b[38;5;129;01min\u001b[39;00m tqdm(target_ids, disable\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m):\n\u001b[0;32m---> 53\u001b[0m reco_scores, reco_ids \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mknn_model\u001b[38;5;241m.\u001b[39mkneighbors(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mui_csr\u001b[49m\u001b[43m[\u001b[49m\u001b[43mtarget_id\u001b[49m\u001b[43m]\u001b[49m, n_neighbors \u001b[38;5;241m=\u001b[39m k \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 54\u001b[0m all_target_ids\u001b[38;5;241m.\u001b[39mextend([target_id] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mlen\u001b[39m(reco_ids))\n\u001b[1;32m 55\u001b[0m all_reco_ids\u001b[38;5;241m.\u001b[39mappend(reco_ids)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/scipy/sparse/_index.py:46\u001b[0m, in \u001b[0;36mIndexMixin.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 45\u001b[0m 
\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key):\n\u001b[0;32m---> 46\u001b[0m row, col \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_indices\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;66;03m# Dispatch to specialized methods.\u001b[39;00m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(row, INT_TYPES):\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/scipy/sparse/_index.py:158\u001b[0m, in \u001b[0;36mIndexMixin._validate_indices\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 156\u001b[0m row \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m M\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(row, \u001b[38;5;28mslice\u001b[39m):\n\u001b[0;32m--> 158\u001b[0m row \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_asindices\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mM\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m isintlike(col):\n\u001b[1;32m 161\u001b[0m col \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(col)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/scipy/sparse/_index.py:182\u001b[0m, in \u001b[0;36mIndexMixin._asindices\u001b[0;34m(self, idx, length)\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minvalid index\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x\u001b[38;5;241m.\u001b[39mndim 
\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m):\n\u001b[0;32m--> 182\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mIndex dimension must be 1 or 2\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x\u001b[38;5;241m.\u001b[39msize \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", + "\u001b[0;31mIndexError\u001b[0m: Index dimension must be 1 or 2" + ] + } + ], + "source": [ + "model._recommend_i2i(InternalIdsArray([1]), dataset, 10, None)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "new", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From f21cda21237ba7535a06630a1347e3876298fd66 Mon Sep 17 00:00:00 2001 From: Roman Legonkov Date: Mon, 23 Dec 2024 22:15:21 +0300 Subject: [PATCH 2/4] finished item_to_item recommendations, added visualize sample --- examples/10_custom_model_creation.ipynb | 171 ++++++++++++++++-------- 1 file changed, 118 insertions(+), 53 deletions(-) diff --git a/examples/10_custom_model_creation.ipynb b/examples/10_custom_model_creation.ipynb index a2c4faf8..879676ae 100644 --- a/examples/10_custom_model_creation.ipynb +++ b/examples/10_custom_model_creation.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 275, "id": "dce07a5b-2716-41c5-8358-63f590dd69f0", "metadata": {}, "outputs": [], @@ -30,7 +30,8 @@ "from rectools.models.base import InternalIdsArray\n", "from rectools.types import *\n", 
"from tqdm import tqdm\n", - "from rectools.models.base import Scores\n" + "from rectools.models.base import Scores\n", + "from rectools.visuals.visual_app import ItemToItemVisualApp" ] }, { @@ -43,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "7814e510-3179-44f0-b116-a30b670fa72e", "metadata": {}, "outputs": [ @@ -52,13 +53,12 @@ "output_type": "stream", "text": [ "Archive: ml-1m.zip\n", - " creating: ml-1m/\n", " inflating: ml-1m/movies.dat \n", " inflating: ml-1m/ratings.dat \n", " inflating: ml-1m/README \n", " inflating: ml-1m/users.dat \n", - "CPU times: user 48.4 ms, sys: 23.4 ms, total: 71.9 ms\n", - "Wall time: 5.12 s\n" + "CPU times: user 46 ms, sys: 22.8 ms, total: 68.8 ms\n", + "Wall time: 5.5 s\n" ] } ], @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 2, "id": "884ebb14-fcae-4bbc-9221-299b613489f0", "metadata": {}, "outputs": [ @@ -80,8 +80,8 @@ "output_type": "stream", "text": [ "(1000209, 4)\n", - "CPU times: user 2.95 s, sys: 102 ms, total: 3.05 s\n", - "Wall time: 3.03 s\n" + "CPU times: user 1.99 s, sys: 99.9 ms, total: 2.09 s\n", + "Wall time: 2.09 s\n" ] }, { @@ -160,7 +160,7 @@ "4 1 2355 5 978824291" ] }, - "execution_count": 37, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -180,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 3, "id": "332cd6dd-d993-47b9-baec-c601b89fba12", "metadata": {}, "outputs": [ @@ -189,8 +189,8 @@ "output_type": "stream", "text": [ "(3883, 3)\n", - "CPU times: user 11.4 ms, sys: 81 μs, total: 11.4 ms\n", - "Wall time: 10.8 ms\n" + "CPU times: user 5.53 ms, sys: 400 μs, total: 5.93 ms\n", + "Wall time: 5.36 ms\n" ] }, { @@ -263,7 +263,7 @@ "4 5 Father of the Bride Part II (1995) Comedy" ] }, - "execution_count": 38, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -300,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 39, + 
"execution_count": 4, "id": "95f01cdf-b126-4035-a5b6-867d45aeaa99", "metadata": {}, "outputs": [], @@ -325,19 +325,20 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 266, "id": "4f2cb8cf-17c3-4f0e-a979-4411a35af283", "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", "class KnnModel(ModelBase[KnnModelConfig]):\n", - " \n", - " recommends_for_warm = True\n", + " # There is a sample recsys model inherited from ModelBase\n", + " # Define is able to make cold and warm recommendations\n", + " # Set config class to defined above\n", + " recommends_for_warm = False\n", " recommends_for_cold = False\n", " config_class = KnnModelConfig\n", "\n", + " # Set all hyperparams in __init__\n", " def __init__(self,\n", " metric: tp.Optional[str] = None,\n", " algorithm: tp.Optional[str] = None,\n", @@ -356,25 +357,36 @@ " self.all_item_ids: np.ndarray\n", " self.ui_csr: csr_matrix\n", "\n", + " # Method used to save hyperparams in config\n", " def _get_config(self) -> KnnModelConfig:\n", " return KnnModelConfig(metric=self.metric, algorithm=self.algorithm, n_neighbors=self.n_neighbors, verbose=self.verbose)\n", "\n", + " # Method used to load model params from config\n", " @classmethod\n", " def _from_config(cls, config: KnnModelConfig) -> tpe.Self:\n", " return cls(metric=config.metric, algorithm=config.algorithm, n_neighbors=config.n_neighbors, verbose=config.verbose)\n", "\n", + " # Method used to fit model, there is a sklearn KNN wrapper, so we need to fit KNN model with dataset csr matrix\n", " def _fit(self, dataset: Dataset) -> None: # type: ignore\n", " self.all_item_ids = dataset.item_id_map.internal_ids\n", " self.ui_csr = dataset.get_user_item_matrix(include_weights=False, dtype=np.float64)\n", " self.knn_model.fit(self.ui_csr)\n", "\n", - " def _recommend_i2i(\n", - " self,\n", - " target_ids: InternalIdsArray,\n", - " dataset: Dataset,\n", - " k: int,\n", - " sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],\n", - " ) -> 
tp.Tuple[InternalIds, InternalIds, Scores]:\n", + " # Method used to make item-item recommendations, not for direct invokation, used in recommend_to_items method of base class\n", + " # Params:\n", + " # target_ids - InternalIdsArray of item ids for which predictions need to be made\n", + " # dataset - instance of Dataset class\n", + " # k - maximum count of top rated elements presented in recommendations\n", + " # sorted_item_ids_to_recommend - optional InternalIdsArray of item ids from which predictions are made\n", + " # Returns:\n", + " # Equaly sized arrays of target ids, predictions ids, scores\n", + " # in this method you need to ensure, that your realization handles all parameters correctly i.e. \n", + " # it can limit k predictions and limit the set of allowed items.\n", + " def _recommend_i2i(self,\n", + " target_ids: InternalIdsArray,\n", + " dataset: Dataset,\n", + " k: int,\n", + " sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray]) -> tp.Tuple[InternalIds, InternalIds, Scores]:\n", " sorted_item_ids_to_recommend = dataset.get_user_item_matrix(include_weights=False,\n", " dtype=np.float64)[sorted_item_ids_to_recommend] if sorted_item_ids_to_recommend is not None else self.all_item_ids\n", "\n", @@ -383,21 +395,31 @@ " all_scores: tp.List[np.ndarray] = []\n", " for target_id in tqdm(target_ids, disable=self.verbose == 0):\n", " reco_scores, reco_ids = self.knn_model.kneighbors(self.ui_csr[target_id], n_neighbors = k + 1)\n", - " all_target_ids.extend([target_id] * len(reco_ids))\n", - " all_reco_ids.append(reco_ids)\n", - " all_scores.append(reco_scores)\n", + " all_target_ids.extend([target_id] * len(reco_ids.tolist()[0]))\n", + " all_reco_ids.extend(reco_ids.tolist())\n", + " all_scores.extend(reco_scores.tolist())\n", "\n", + " all_target_ids = np.array(all_target_ids) \n", " all_reco_ids_arr = np.concatenate(all_reco_ids)\n", + " all_reco_scores_array = np.concatenate(all_scores)\n", + " valid_indices = all_reco_ids_arr < 
len(sorted_item_ids_to_recommend)\n", "\n", + " all_reco_ids_arr = all_reco_ids_arr[valid_indices]\n", + " all_target_ids = all_target_ids[valid_indices]\n", + " all_reco_scores_array = all_reco_scores_array[valid_indices]\n", + " \n", " if sorted_item_ids_to_recommend is not None:\n", - " all_reco_ids_arr = sorted_item_ids_to_recommend[all_reco_ids_arr]\n", + " items_indeces = np.isin(all_reco_ids_arr, sorted_item_ids_to_recommend)\n", + " all_reco_ids_arr = all_reco_ids_arr[items_indeces]\n", + " all_target_ids = all_target_ids[items_indeces]\n", + " all_reco_scores_array = all_reco_scores_array[items_indeces]\n", "\n", - " return all_target_ids, all_reco_ids_arr, np.concatenate(all_scores)" + " return all_target_ids, all_reco_ids_arr, all_reco_scores_array" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 267, "id": "e8757369-bec2-48f2-a876-a359d444d706", "metadata": {}, "outputs": [], @@ -407,7 +429,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 268, "id": "53750dfe-9b46-4b8d-9bfe-e0c9caa41770", "metadata": {}, "outputs": [], @@ -417,17 +439,17 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 269, "id": "0c279b6f-2287-4198-b958-f00729fabadd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "<__main__.KnnModel at 0x797862117af0>" + "<__main__.KnnModel at 0x7db3e9fa4bb0>" ] }, - "execution_count": 43, + "execution_count": 269, "metadata": {}, "output_type": "execute_result" } @@ -436,36 +458,79 @@ "model.fit(dataset)" ] }, + { + "cell_type": "markdown", + "id": "ac06ff1f-d1c2-4fcd-b328-70197e2b02a5", + "metadata": {}, + "source": [ + "## Use model to recommend similar items" + ] + }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 283, "id": "5c373d0f-bf04-4256-81c6-35f1f124f118", "metadata": {}, + "outputs": [], + "source": [ + "reco = model.recommend_to_items([1,7,6,2,3,5], dataset, 10)\n", + "reco[Columns.Model] = \"KnnCustomModel\"" + ] + }, + { + 
"cell_type": "code", + "execution_count": 290, + "id": "f7451d5d-5689-40ff-8f1f-fd9c163c6833", + "metadata": {}, + "outputs": [], + "source": [ + "selected_items = {\"item_one\": 3}\n", + "formatters = {\"item_id\": lambda x: f\"{x}\"}" + ] + }, + { + "cell_type": "code", + "execution_count": 291, + "id": "d34cd895-5957-4176-9e59-dc2b74f1079d", + "metadata": {}, "outputs": [ { - "ename": "IndexError", - "evalue": "Index dimension must be 1 or 2", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[46], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_recommend_i2i\u001b[49m\u001b[43m(\u001b[49m\u001b[43mInternalIdsArray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[40], line 53\u001b[0m, in \u001b[0;36mKnnModel._recommend_i2i\u001b[0;34m(self, target_ids, dataset, k, sorted_item_ids_to_recommend)\u001b[0m\n\u001b[1;32m 51\u001b[0m all_scores: tp\u001b[38;5;241m.\u001b[39mList[np\u001b[38;5;241m.\u001b[39mndarray] \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m target_id \u001b[38;5;129;01min\u001b[39;00m tqdm(target_ids, disable\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m):\n\u001b[0;32m---> 53\u001b[0m reco_scores, reco_ids \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mknn_model\u001b[38;5;241m.\u001b[39mkneighbors(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mui_csr\u001b[49m\u001b[43m[\u001b[49m\u001b[43mtarget_id\u001b[49m\u001b[43m]\u001b[49m, n_neighbors \u001b[38;5;241m=\u001b[39m k \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 54\u001b[0m all_target_ids\u001b[38;5;241m.\u001b[39mextend([target_id] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mlen\u001b[39m(reco_ids))\n\u001b[1;32m 55\u001b[0m all_reco_ids\u001b[38;5;241m.\u001b[39mappend(reco_ids)\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/scipy/sparse/_index.py:46\u001b[0m, in \u001b[0;36mIndexMixin.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key):\n\u001b[0;32m---> 46\u001b[0m row, col \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_indices\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;66;03m# Dispatch to specialized methods.\u001b[39;00m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(row, INT_TYPES):\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/scipy/sparse/_index.py:158\u001b[0m, in \u001b[0;36mIndexMixin._validate_indices\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 156\u001b[0m row \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m M\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(row, \u001b[38;5;28mslice\u001b[39m):\n\u001b[0;32m--> 158\u001b[0m row \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_asindices\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mM\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m isintlike(col):\n\u001b[1;32m 161\u001b[0m col \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(col)\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/scipy/sparse/_index.py:182\u001b[0m, in \u001b[0;36mIndexMixin._asindices\u001b[0;34m(self, idx, length)\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minvalid index\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m):\n\u001b[0;32m--> 182\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mIndex dimension must be 1 or 2\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x\u001b[38;5;241m.\u001b[39msize \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", - "\u001b[0;31mIndexError\u001b[0m: Index dimension must be 1 or 2" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5cb24fd43314999966c70919be0aa5c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(ToggleButtons(button_style='warning', description='Target:', options=('item_one',), value='item…" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - 
"model._recommend_i2i(InternalIdsArray([1]), dataset, 10, None)" + "app = ItemToItemVisualApp.construct(\n", + " reco=reco,\n", + " item_data=movies,\n", + " selected_items=selected_items,\n", + " formatters=formatters,\n", + " auto_display=True\n", + ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4825b91-eaa0-405e-9b9e-2523fc6ba66b", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "new", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, From 3d9f5b5e4446c3c6a7676d9d312f870e88633d72 Mon Sep 17 00:00:00 2001 From: Roman Legonkov Date: Mon, 23 Dec 2024 22:22:08 +0300 Subject: [PATCH 3/4] changelog update --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac9a48cf..f38d4099 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## [Unreleased] + +### Added +- Tutorial for making custom model inherited from ModelBase ([#236](https://github.com/MobileTeleSystems/RecTools/pull/236)) ## [0.9.0] - 11.12.2024 From 8458e0f3f93905a875b58cdf4ab606d612fe2a22 Mon Sep 17 00:00:00 2001 From: Roman Legonkov Date: Tue, 24 Dec 2024 22:19:47 +0300 Subject: [PATCH 4/4] added u2i recommendation implimentation --- examples/10_custom_model_creation.ipynb | 424 +++++++++++++++++++++--- 1 file changed, 384 insertions(+), 40 deletions(-) diff --git a/examples/10_custom_model_creation.ipynb b/examples/10_custom_model_creation.ipynb index 879676ae..7ad943f4 100644 --- a/examples/10_custom_model_creation.ipynb +++ b/examples/10_custom_model_creation.ipynb @@ -13,14 +13,14 @@ }, { "cell_type": "code", - "execution_count": 275, + "execution_count": 40, "id": "dce07a5b-2716-41c5-8358-63f590dd69f0", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", - "from rectools.models.base import ModelBase, ModelConfig\n", + "from rectools.models.base import ModelBase, ModelConfig, Scores, ScoresArray\n", "from rectools import Columns\n", "from rectools.dataset import Dataset\n", "from scipy.sparse import csr_matrix\n", @@ -30,8 +30,10 @@ "from rectools.models.base import InternalIdsArray\n", "from rectools.types import *\n", "from tqdm import tqdm\n", - "from rectools.models.base import Scores\n", - "from rectools.visuals.visual_app import ItemToItemVisualApp" + "from rectools.utils import fast_isin_for_sorted_test_elements\n", + "from rectools.models.utils import get_viewed_item_ids\n", + "from rectools.visuals.visual_app import ItemToItemVisualApp, VisualApp\n", + "import random" ] }, { @@ -57,8 +59,8 @@ " inflating: ml-1m/ratings.dat \n", " inflating: ml-1m/README \n", " inflating: ml-1m/users.dat \n", - "CPU times: user 46 ms, sys: 22.8 ms, total: 68.8 ms\n", - "Wall time: 5.5 s\n" + "CPU times: user 27.2 ms, sys: 23 ms, total: 50.2 ms\n", + "Wall time: 3.21 s\n" ] } ], @@ -71,7 +73,7 
@@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "884ebb14-fcae-4bbc-9221-299b613489f0", "metadata": {}, "outputs": [ @@ -80,8 +82,8 @@ "output_type": "stream", "text": [ "(1000209, 4)\n", - "CPU times: user 1.99 s, sys: 99.9 ms, total: 2.09 s\n", - "Wall time: 2.09 s\n" + "CPU times: user 2.04 s, sys: 93.3 ms, total: 2.13 s\n", + "Wall time: 2.13 s\n" ] }, { @@ -160,7 +162,7 @@ "4 1 2355 5 978824291" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -180,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "332cd6dd-d993-47b9-baec-c601b89fba12", "metadata": {}, "outputs": [ @@ -189,8 +191,8 @@ "output_type": "stream", "text": [ "(3883, 3)\n", - "CPU times: user 5.53 ms, sys: 400 μs, total: 5.93 ms\n", - "Wall time: 5.36 ms\n" + "CPU times: user 4.71 ms, sys: 792 μs, total: 5.5 ms\n", + "Wall time: 4.99 ms\n" ] }, { @@ -263,7 +265,7 @@ "4 5 Father of the Bride Part II (1995) Comedy" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -300,19 +302,63 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "95f01cdf-b126-4035-a5b6-867d45aeaa99", "metadata": {}, "outputs": [], "source": [ - "class KnnModelConfig(ModelConfig):\n", + "class MixedKnnRandomModelConfig(ModelConfig):\n", " \"\"\"Config for `KNN` model.\"\"\"\n", "\n", " # KNN algorithm hyperparams\n", " metric: tp.Optional[str] = None\n", " algorithm: tp.Optional[str] = None\n", " n_neighbors: tp.Optional[int] = None\n", - " n_jobs: tp.Optional[int] = None" + " n_jobs: tp.Optional[int] = None\n", + " random_state: tp.Optional[int] = None" + ] + }, + { + "cell_type": "markdown", + "id": "6eecbbd5", + "metadata": {}, + "source": [ + "### Define a `_RandomSampler` and `_RandomGen` class for random recommendations" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "60255323", + 
"metadata": {}, + "outputs": [], + "source": [ + "class _RandomGen:\n", + " def __init__(self, random_state: tp.Optional[int] = None) -> None:\n", + " self.python_gen = random.Random(random_state) # nosec\n", + " self.np_gen = np.random.default_rng(random_state)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e3283dfd", + "metadata": {}, + "outputs": [], + "source": [ + "class _RandomSampler:\n", + " def __init__(self, values: np.ndarray, random_gen: _RandomGen) -> None:\n", + " self.python_gen = random_gen.python_gen\n", + " self.np_gen = random_gen.np_gen\n", + " self.values = values\n", + " self.values_list = list(values) # for random.sample\n", + "\n", + " def sample(self, n: int) -> np.ndarray:\n", + " if n < 25: # Empiric value, for optimization\n", + " sampled = np.asarray(self.python_gen.sample(self.values_list, n))\n", + " else:\n", + " sampled = self.np_gen.choice(self.values, n, replace=False)\n", + " return sampled\n" ] }, { @@ -325,18 +371,19 @@ }, { "cell_type": "code", - "execution_count": 266, + "execution_count": 28, "id": "4f2cb8cf-17c3-4f0e-a979-4411a35af283", "metadata": {}, "outputs": [], "source": [ - "class KnnModel(ModelBase[KnnModelConfig]):\n", - " # There is a sample recsys model inherited from ModelBase\n", + "class MixedKnnRandomModel(ModelBase[MixedKnnRandomModelConfig]):\n", + " # There is a sample recsys model inherited from ModelBase, model is mixed KNN wrapper (i2i) and random model (u2i)\n", + " # You can mix other models as well.\n", " # Define is able to make cold and warm recommendations\n", " # Set config class to defined above\n", " recommends_for_warm = False\n", " recommends_for_cold = False\n", - " config_class = KnnModelConfig\n", + " config_class = MixedKnnRandomModelConfig\n", "\n", " # Set all hyperparams in __init__\n", " def __init__(self,\n", @@ -344,6 +391,7 @@ " algorithm: tp.Optional[str] = None,\n", " n_neighbors: tp.Optional[int] = None,\n", " n_jobs: tp.Optional[int] = None,\n", + " 
random_state: tp.Optional[int] = None,\n", " verbose: int = 0):\n", " super().__init__(verbose=verbose)\n", " self.metric = metric\n", @@ -354,17 +402,19 @@ " algorithm = self.algorithm,\n", " n_neighbors = self.n_neighbors,\n", " n_jobs = self.n_jobs)\n", + " self.random_state = random_state\n", + " self.random_gen = _RandomGen(random_state)\n", " self.all_item_ids: np.ndarray\n", " self.ui_csr: csr_matrix\n", "\n", " # Method used to save hyperparams in config\n", - " def _get_config(self) -> KnnModelConfig:\n", - " return KnnModelConfig(metric=self.metric, algorithm=self.algorithm, n_neighbors=self.n_neighbors, verbose=self.verbose)\n", + " def _get_config(self) -> MixedKnnRandomModelConfig:\n", + " return MixedKnnRandomModelConfig(metric=self.metric, algorithm=self.algorithm, n_neighbors=self.n_neighbors, random_state=self.random_state, verbose=self.verbose)\n", "\n", " # Method used to load model params from config\n", " @classmethod\n", - " def _from_config(cls, config: KnnModelConfig) -> tpe.Self:\n", - " return cls(metric=config.metric, algorithm=config.algorithm, n_neighbors=config.n_neighbors, verbose=config.verbose)\n", + " def _from_config(cls, config: MixedKnnRandomModelConfig) -> tpe.Self:\n", + " return cls(metric=config.metric, algorithm=config.algorithm, n_neighbors=config.n_neighbors, random_state=config.random_state, verbose=config.verbose)\n", "\n", " # Method used to fit model, there is a sklearn KNN wrapper, so we need to fit KNN model with dataset csr matrix\n", " def _fit(self, dataset: Dataset) -> None: # type: ignore\n", @@ -414,22 +464,70 @@ " all_target_ids = all_target_ids[items_indeces]\n", " all_reco_scores_array = all_reco_scores_array[items_indeces]\n", "\n", - " return all_target_ids, all_reco_ids_arr, all_reco_scores_array" + " return all_target_ids, all_reco_ids_arr, all_reco_scores_array\n", + " \n", + " # Method used to make user-item recommendations; not for direct invocation, it is used by the recommend method of the base class\n", + " # 
Params:\n", + " # user_ids - InternalIdsArray of user ids for which predictions need to be made\n", + " # dataset - instance of Dataset class\n", + " # k - maximum number of top-rated items presented in the recommendations\n", + " # sorted_item_ids_to_recommend - optional InternalIdsArray of item ids from which predictions are made\n", + " # Returns:\n", + " # Equally sized arrays of target ids, prediction ids and scores\n", + " # In this method you need to ensure that your implementation handles all parameters correctly, i.e.\n", + " # it can limit k predictions and limit the set of allowed items.\n", + " def _recommend_u2i(\n", + " self,\n", + " user_ids: InternalIdsArray,\n", + " dataset: Dataset,\n", + " k: int,\n", + " filter_viewed: bool,\n", + " sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],\n", + " ) -> tp.Tuple[InternalIds, InternalIds, Scores]:\n", + " if filter_viewed:\n", + " user_items = dataset.get_user_item_matrix(include_weights=False)\n", + "\n", + " item_ids = sorted_item_ids_to_recommend if sorted_item_ids_to_recommend is not None else self.all_item_ids\n", + " sampler = _RandomSampler(item_ids, self.random_gen)\n", + "\n", + " all_user_ids = []\n", + " all_reco_ids: tp.List[InternalId] = []\n", + " all_scores: tp.List[float] = []\n", + " for user_id in tqdm(user_ids, disable=self.verbose == 0):\n", + " if filter_viewed:\n", + " viewed_ids = get_viewed_item_ids(user_items, user_id) # sorted\n", + " n_reco = k + viewed_ids.size\n", + " else:\n", + " n_reco = k\n", + "\n", + " n_reco = min(n_reco, item_ids.size)\n", + " reco_ids = sampler.sample(n_reco)\n", + "\n", + " if filter_viewed:\n", + " reco_ids = reco_ids[fast_isin_for_sorted_test_elements(reco_ids, viewed_ids, invert=True)][:k]\n", + "\n", + " reco_scores = np.arange(reco_ids.size, 0, -1)\n", + "\n", + " all_user_ids.extend([user_id] * len(reco_ids))\n", + " all_reco_ids.extend(reco_ids.tolist())\n", + " all_scores.extend(reco_scores.tolist())\n", + "\n", + " return all_user_ids, 
all_reco_ids, all_scores" ] }, { "cell_type": "code", - "execution_count": 267, + "execution_count": 29, "id": "e8757369-bec2-48f2-a876-a359d444d706", "metadata": {}, "outputs": [], "source": [ - "model = KnnModel(metric=\"cosine\", algorithm=\"brute\", n_neighbors=20, n_jobs=-1)" + "model = MixedKnnRandomModel(metric=\"cosine\", algorithm=\"brute\", n_neighbors=20, n_jobs=-1, random_state=20)" ] }, { "cell_type": "code", - "execution_count": 268, + "execution_count": 30, "id": "53750dfe-9b46-4b8d-9bfe-e0c9caa41770", "metadata": {}, "outputs": [], @@ -439,17 +537,17 @@ }, { "cell_type": "code", - "execution_count": 269, + "execution_count": 31, "id": "0c279b6f-2287-4198-b958-f00729fabadd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "<__main__.KnnModel at 0x7db3e9fa4bb0>" + "<__main__.MixedKnnRandomModel at 0x7b349d4f63e0>" ] }, - "execution_count": 269, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -468,7 +566,7 @@ }, { "cell_type": "code", - "execution_count": 283, + "execution_count": 36, "id": "5c373d0f-bf04-4256-81c6-35f1f124f118", "metadata": {}, "outputs": [], @@ -479,7 +577,104 @@ }, { "cell_type": "code", - "execution_count": 290, + "execution_count": 37, + "id": "591c59da", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
target_item_iditem_idscorerankmodel
0118950.7497831KnnCustomModel
1119890.7503852KnnCustomModel
214580.7599023KnnCustomModel
3119060.7599024KnnCustomModel
418770.7682705KnnCustomModel
\n", + "
" + ], + "text/plain": [ + " target_item_id item_id score rank model\n", + "0 1 1895 0.749783 1 KnnCustomModel\n", + "1 1 1989 0.750385 2 KnnCustomModel\n", + "2 1 458 0.759902 3 KnnCustomModel\n", + "3 1 1906 0.759902 4 KnnCustomModel\n", + "4 1 877 0.768270 5 KnnCustomModel" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reco.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 38, "id": "f7451d5d-5689-40ff-8f1f-fd9c163c6833", "metadata": {}, "outputs": [], @@ -490,14 +685,14 @@ }, { "cell_type": "code", - "execution_count": 291, + "execution_count": 42, "id": "d34cd895-5957-4176-9e59-dc2b74f1079d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c5cb24fd43314999966c70919be0aa5c", + "model_id": "57cba3f32d84411b8fa97416e828e1b8", "version_major": 2, "version_minor": 0 }, @@ -519,18 +714,167 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "f225a5c2", + "metadata": {}, + "source": [ + "# Use model to recommend movies for specific users" + ] + }, { "cell_type": "code", - "execution_count": null, - "id": "b4825b91-eaa0-405e-9b9e-2523fc6ba66b", + "execution_count": 44, + "id": "850b4dcb", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "reco = model.recommend([1,7,6,2,3,5], dataset, 10, filter_viewed=True)\n", + "reco[Columns.Model] = \"KnnCustomModel\"" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "d03c5383", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerankmodel
01213810.01KnnCustomModel
11609.02KnnCustomModel
215728.03KnnCustomModel
3131557.04KnnCustomModel
4117606.05KnnCustomModel
\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank model\n", + "0 1 2138 10.0 1 KnnCustomModel\n", + "1 1 60 9.0 2 KnnCustomModel\n", + "2 1 572 8.0 3 KnnCustomModel\n", + "3 1 3155 7.0 4 KnnCustomModel\n", + "4 1 1760 6.0 5 KnnCustomModel" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reco.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "b4825b91-eaa0-405e-9b9e-2523fc6ba66b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "467aa87b78d347288a7c093029d9fb82", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(ToggleButtons(button_style='warning', description='Target:', options=('user_one',), value='user…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "selected_users = {\"user_one\": 3}\n", + "app = VisualApp.construct(\n", + " reco=reco,\n", + " interactions=ratings,\n", + " item_data=movies,\n", + " selected_users=selected_users,\n", + " formatters=formatters \n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bfc37d9a", + "metadata": {}, + "source": [ + "# Conclusion\n", + "You can create custom models with any requirements by inheriting from BaseModel and implementing necessary methods." + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "new", "language": "python", "name": "python3" },