diff --git a/learning_resources/etl/loaders_test.py b/learning_resources/etl/loaders_test.py
index 91c16bdb8c..72b074754a 100644
--- a/learning_resources/etl/loaders_test.py
+++ b/learning_resources/etl/loaders_test.py
@@ -131,6 +131,15 @@ def mock_duplicates(mocker):
     )
 
 
+@pytest.fixture
+def mock_get_similar_topics_qdrant(mocker):
+    """Mock get_similar_topics_qdrant to return canned topic names"""
+    mocker.patch(
+        "learning_resources_search.plugins.get_similar_topics_qdrant",
+        return_value=["topic1", "topic2"],
+    )
+
+
 @pytest.fixture(autouse=True)
 def mock_upsert_tasks(mocker):
     """Mock out the upsert task helpers"""
@@ -1465,9 +1473,10 @@ def test_load_video(mocker, video_exists, is_published, pass_topics):
         assert getattr(result, key) == value, f"Property {key} should equal {value}"
 
 
-def test_load_videos():
+def test_load_videos(mocker, mock_get_similar_topics_qdrant):
     """Verify that load_videos loads a list of videos"""
     assert Video.objects.count() == 0
+
     video_resources = [video.learning_resource for video in VideoFactory.build_batch(5)]
     videos_data = [
         {
@@ -1486,13 +1495,14 @@ def test_load_videos():
 
 
 @pytest.mark.parametrize("playlist_exists", [True, False])
-def test_load_playlist(mocker, playlist_exists):
+def test_load_playlist(mocker, playlist_exists, mock_get_similar_topics_qdrant):
     """Test load_playlist"""
    expected_topics = [{"name": "Biology"}, {"name": "Physics"}]
     [
         LearningResourceTopicFactory.create(name=topic["name"])
         for topic in expected_topics
     ]
+
     mock_most_common_topics = mocker.patch(
         "learning_resources.etl.loaders.most_common_topics",
         return_value=expected_topics,
@@ -1904,7 +1914,7 @@ def test_course_with_unpublished_force_ingest_is_test_mode():
 
 
 @pytest.mark.django_db
-def test_load_articles(mocker, climate_platform):
+def test_load_articles(mocker, climate_platform, mock_get_similar_topics_qdrant):
     articles_data = [
         {
             "title": "test",
diff --git a/learning_resources_search/api.py b/learning_resources_search/api.py
index 6ca4c3422b..b6108f04fa 100644
--- a/learning_resources_search/api.py
+++ b/learning_resources_search/api.py
@@ -35,7 +35,11 @@
     adjust_search_for_percolator,
     document_percolated_actions,
 )
-from vector_search.constants import RESOURCES_COLLECTION_NAME
+from vector_search.constants import (
+    RESOURCES_COLLECTION_NAME,
+    TOPICS_COLLECTION_NAME,
+)
+from vector_search.encoders.utils import dense_encoder
 
 log = logging.getLogger(__name__)
 
@@ -830,6 +834,53 @@ def user_subscribed_to_query(
     )
 
 
+def get_similar_topics_qdrant(
+    resource: LearningResource, value_doc: dict, num_topics: int
+) -> list[str]:
+    """
+    Get a list of similar topics based on vector similarity
+
+    Args:
+        resource (LearningResource):
+            the learning resource to find similar topics for
+        value_doc (dict):
+            a document representing the data fields we want to search with
+        num_topics (int):
+            number of topics to return
+    Returns:
+        list of str:
+            list of topic values
+    """
+    from vector_search.utils import qdrant_client, vector_point_id
+
+    encoder = dense_encoder()
+    client = qdrant_client()
+
+    response = client.retrieve(
+        collection_name=RESOURCES_COLLECTION_NAME,
+        ids=[vector_point_id(resource.readable_id)],
+        with_vectors=True,
+    )
+
+    if response and len(response) > 0:
+        embeddings = response[0].vector.get(encoder.model_short_name())
+    else:
+        embedding_context = "\n".join(
+            [value_doc[key] for key in value_doc if value_doc[key] is not None]
+        )
+        embeddings = encoder.embed(embedding_context)
+
+    return [
+        hit["name"]
+        for hit in _qdrant_similar_results(
+            input_query=embeddings,
+            num_resources=num_topics,
+            collection_name=TOPICS_COLLECTION_NAME,
+            score_threshold=0.2,
+        )
+    ]
+
+
 def get_similar_topics(
     value_doc: dict, num_topics: int, min_term_freq: int, min_doc_freq: int
 ) -> list[str]:
@@ -909,7 +958,12 @@ def get_similar_resources(
     )
 
 
-def _qdrant_similar_results(doc, num_resources):
+def _qdrant_similar_results(
+    input_query,
+    num_resources=6,
+    collection_name=RESOURCES_COLLECTION_NAME,
+    score_threshold=0,
+):
     """
     Get similar resources from qdrant
 
@@ -924,9 +978,7 @@
         list of serialized resources
     """
     from vector_search.utils import (
-        dense_encoder,
         qdrant_client,
-        vector_point_id,
     )
 
     encoder = dense_encoder()
@@ -934,10 +986,11 @@
     return [
         hit.payload
         for hit in client.query_points(
-            collection_name=RESOURCES_COLLECTION_NAME,
-            query=vector_point_id(doc["readable_id"]),
+            collection_name=collection_name,
+            query=input_query,
             limit=num_resources,
             using=encoder.model_short_name(),
+            score_threshold=score_threshold,
         ).points
     ]
 
@@ -956,7 +1009,12 @@ def get_similar_resources_qdrant(value_doc: dict, num_resources: int):
         list of str:
             list of learning resources
     """
-    hits = _qdrant_similar_results(value_doc, num_resources)
+    from vector_search.utils import vector_point_id
+
+    hits = _qdrant_similar_results(
+        input_query=vector_point_id(value_doc["readable_id"]),
+        num_resources=num_resources,
+    )
     return (
         LearningResource.objects.for_search_serialization()
         .filter(
diff --git a/learning_resources_search/api_test.py b/learning_resources_search/api_test.py
index c60ac3e0ee..25f5117ef2 100644
--- a/learning_resources_search/api_test.py
+++ b/learning_resources_search/api_test.py
@@ -1,6 +1,6 @@
 """Search API function tests"""
 
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
 
 import pytest
 from freezegun import freeze_time
@@ -21,6 +21,7 @@
     generate_sort_clause,
     generate_suggest_clause,
     get_similar_topics,
+    get_similar_topics_qdrant,
     percolate_matches_for_document,
     relevant_indexes,
 )
@@ -3266,3 +3267,39 @@ def test_dev_mode(dev_mode):
         assert construct_search(search_params).to_dict().get("explain")
     else:
         assert construct_search(search_params).to_dict().get("explain") is None
+
+
+@pytest.mark.django_db
+def test_get_similar_topics_qdrant_uses_cached_embedding(mocker):
+    """
+    Test that get_similar_topics_qdrant uses a cached embedding when available
+    """
+    resource = MagicMock()
+    resource.readable_id = "test-resource"
+    value_doc = {"title": "Test Title", "description": "Test Description"}
+    num_topics = 3
+
+    mock_encoder = mocker.patch("learning_resources_search.api.dense_encoder")
+    encoder_instance = mock_encoder.return_value
+    encoder_instance.model_short_name.return_value = "test-model"
+    encoder_instance.embed.return_value = [0.1, 0.2, 0.3]
+
+    mock_client = mocker.patch("vector_search.utils.qdrant_client")
+    client_instance = mock_client.return_value
+
+    # Simulate a cached embedding in the response
+    client_instance.retrieve.return_value = [
+        MagicMock(vector={"test-model": [0.9, 0.8, 0.7]})
+    ]
+
+    mocker.patch(
+        "learning_resources_search.api._qdrant_similar_results",
+        return_value=[{"name": "topic1"}, {"name": "topic2"}],
+    )
+
+    result = get_similar_topics_qdrant(resource, value_doc, num_topics)
+
+    # Assert that embed was NOT called (cached embedding used)
+    encoder_instance.embed.assert_not_called()
+    # Assert that the result is as expected
+    assert result == ["topic1", "topic2"]
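Reviewer note: the cached-embedding path exercised by the test above hinges on Qdrant returning named vectors from client.retrieve(). A minimal standalone sketch of that round trip, built only from helpers this diff introduces or relies on; the function name cached_embedding is hypothetical and not part of the change:

# Hypothetical helper, not part of this diff: fetch the stored dense vector
# for an indexed resource, or None if the resource has not been embedded yet.
from vector_search.constants import RESOURCES_COLLECTION_NAME
from vector_search.encoders.utils import dense_encoder
from vector_search.utils import qdrant_client, vector_point_id


def cached_embedding(readable_id: str):
    """Return the cached dense vector for a resource, or None."""
    encoder = dense_encoder()
    client = qdrant_client()
    points = client.retrieve(
        collection_name=RESOURCES_COLLECTION_NAME,
        ids=[vector_point_id(readable_id)],
        with_vectors=True,
    )
    # Each point's .vector is a dict of named vectors keyed by the encoder's
    # short model name, which is how get_similar_topics_qdrant reads it back.
    return points[0].vector.get(encoder.model_short_name()) if points else None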
diff --git a/learning_resources_search/plugins.py b/learning_resources_search/plugins.py
index 8eca2f981a..f938d316e1 100644
--- a/learning_resources_search/plugins.py
+++ b/learning_resources_search/plugins.py
@@ -7,7 +7,7 @@
 from django.conf import settings as django_settings
 
 from learning_resources_search import tasks
-from learning_resources_search.api import get_similar_topics
+from learning_resources_search.api import get_similar_topics_qdrant
 from learning_resources_search.constants import (
     COURSE_TYPE,
     PERCOLATE_INDEX_TYPE,
@@ -125,11 +125,10 @@ def resource_similar_topics(self, resource) -> list[dict]:
             "full_description": resource.full_description,
         }
 
-        topic_names = get_similar_topics(
+        topic_names = get_similar_topics_qdrant(
+            resource,
             text_doc,
             settings.OPEN_VIDEO_MAX_TOPICS,
-            settings.OPEN_VIDEO_MIN_TERM_FREQ,
-            settings.OPEN_VIDEO_MIN_DOC_FREQ,
         )
         return [{"name": topic_name} for topic_name in topic_names]
diff --git a/learning_resources_search/plugins_test.py b/learning_resources_search/plugins_test.py
index 4bb24b8ca9..8eacaa65d0 100644
--- a/learning_resources_search/plugins_test.py
+++ b/learning_resources_search/plugins_test.py
@@ -128,19 +128,18 @@ def test_resource_similar_topics(mocker, settings):
     """The plugin function should return expected topics for a resource"""
     expected_topics = ["topic1", "topic2"]
     mock_similar_topics = mocker.patch(
-        "learning_resources_search.plugins.get_similar_topics",
+        "learning_resources_search.plugins.get_similar_topics_qdrant",
         return_value=expected_topics,
     )
     resource = LearningResourceFactory.create()
     topics = SearchIndexPlugin().resource_similar_topics(resource)
     assert topics == [{"name": topic} for topic in expected_topics]
     mock_similar_topics.assert_called_once_with(
+        resource,
         {
             "title": resource.title,
             "description": resource.description,
             "full_description": resource.full_description,
         },
         settings.OPEN_VIDEO_MAX_TOPICS,
-        settings.OPEN_VIDEO_MIN_TERM_FREQ,
-        settings.OPEN_VIDEO_MIN_DOC_FREQ,
     )
diff --git a/main/settings_course_etl.py b/main/settings_course_etl.py
index a5bf02954e..6f70c4d6c4 100644
--- a/main/settings_course_etl.py
+++ b/main/settings_course_etl.py
@@ -96,7 +96,7 @@
 # course catalog video etl settings
 OPEN_VIDEO_DATA_BRANCH = get_string("OPEN_VIDEO_DATA_BRANCH", "master")
 OPEN_VIDEO_USER_LIST_OWNER = get_string("OPEN_VIDEO_USER_LIST_OWNER", None)
-OPEN_VIDEO_MAX_TOPICS = get_int("OPEN_VIDEO_MAX_TOPICS", 3)
+OPEN_VIDEO_MAX_TOPICS = get_int("OPEN_VIDEO_MAX_TOPICS", 2)
 OPEN_VIDEO_MIN_TERM_FREQ = get_int("OPEN_VIDEO_MIN_TERM_FREQ", 1)
 OPEN_VIDEO_MIN_DOC_FREQ = get_int("OPEN_VIDEO_MIN_DOC_FREQ", 15)
 
diff --git a/vector_search/constants.py b/vector_search/constants.py
index 97445d892d..d057c2ea98 100644
--- a/vector_search/constants.py
+++ b/vector_search/constants.py
@@ -3,6 +3,7 @@
 
 RESOURCES_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.resources"
 CONTENT_FILES_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.content_files"
+TOPICS_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.topics"
 
 QDRANT_CONTENT_FILE_PARAM_MAP = {
     "key": "key",
@@ -43,6 +44,10 @@
 }
 
 
+QDRANT_TOPICS_PARAM_MAP = {
+    "name": "name",
+}
+
 QDRANT_LEARNING_RESOURCE_INDEXES = {
     "readable_id": models.PayloadSchemaType.KEYWORD,
     "resource_type": models.PayloadSchemaType.KEYWORD,
@@ -82,3 +87,7 @@
     "edx_block_id": models.PayloadSchemaType.KEYWORD,
     "url": models.PayloadSchemaType.KEYWORD,
 }
+
+QDRANT_TOPIC_INDEXES = {
+    "name": models.PayloadSchemaType.KEYWORD,
+}
diff --git a/vector_search/management/commands/sync_topic_embeddings.py b/vector_search/management/commands/sync_topic_embeddings.py
new file mode 100644
index 0000000000..8a5d1a8487
--- /dev/null
+++ b/vector_search/management/commands/sync_topic_embeddings.py
@@ -0,0 +1,27 @@
+"""Management command to update or create the topics collection in Qdrant"""
+
+from django.core.management.base import BaseCommand, CommandError
+
+from main.utils import clear_search_cache, now_in_utc
+from vector_search.tasks import sync_topics
+
+
+class Command(BaseCommand):
+    """Syncs embeddings for topics in Qdrant"""
+
+    help = "update or create the topics collection in Qdrant"
+
+    def handle(self, *args, **options):  # noqa: ARG002
+        """Sync the topics collection"""
+        start = now_in_utc()
+        task = sync_topics.apply()
+        self.stdout.write("Waiting on task...")
+        error = task.get()
+        if error:
+            msg = f"Generate embeddings errored: {error}"
+            raise CommandError(msg)
+        clear_search_cache()
+        total_seconds = (now_in_utc() - start).total_seconds()
+        self.stdout.write(
+            f"Embeddings generated and stored, took {total_seconds} seconds"
+        )
diff --git a/vector_search/tasks.py b/vector_search/tasks.py
index ce7ab862ca..3e497a3d9e 100644
--- a/vector_search/tasks.py
+++ b/vector_search/tasks.py
@@ -32,7 +32,11 @@
     chunks,
     now_in_utc,
 )
-from vector_search.utils import embed_learning_resources, remove_qdrant_records
+from vector_search.utils import (
+    embed_learning_resources,
+    embed_topics,
+    remove_qdrant_records,
+)
 
 log = logging.getLogger(__name__)
 
@@ -362,3 +366,11 @@ def remove_run_content_files(run_id):
             for ids in chunks(content_file_ids, chunk_size=settings.QDRANT_CHUNK_SIZE)
         ]
    )
+
+
+@app.task
+def sync_topics():
+    """
+    Sync topics to the Qdrant collection
+    """
+    embed_topics()
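Reviewer note: for context, a sketch of how the new task is expected to be driven. apply() runs it eagerly in-process, which is exactly what the management command above does; delay() would queue it on a Celery worker. This is an assumption-level illustration, not part of the diff:

# Hypothetical invocation, not part of this diff.
from vector_search.tasks import sync_topics

result = sync_topics.apply()  # eager, synchronous run (the management command path)
error = result.get()  # embed_topics() returns None, so a truthy result means failure
assert error is None

# Queued variant, for a worker:
# sync_topics.delay()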
diff --git a/vector_search/utils.py b/vector_search/utils.py
index ab9160b21b..1f9d55a735 100644
--- a/vector_search/utils.py
+++ b/vector_search/utils.py
@@ -2,12 +2,16 @@
 import uuid
 
 from django.conf import settings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_experimental.text_splitter import SemanticChunker
 from qdrant_client import QdrantClient, models
 
 from learning_resources.content_summarizer import ContentSummarizer
-from learning_resources.models import ContentFile, LearningResource
+from learning_resources.models import (
+    ContentFile,
+    LearningResource,
+    LearningResourceTopic,
+)
 from learning_resources.serializers import (
     ContentFileSerializer,
     LearningResourceMetadataDisplaySerializer,
@@ -28,7 +33,10 @@
     QDRANT_CONTENT_FILE_PARAM_MAP,
     QDRANT_LEARNING_RESOURCE_INDEXES,
     QDRANT_RESOURCE_PARAM_MAP,
+    QDRANT_TOPIC_INDEXES,
+    QDRANT_TOPICS_PARAM_MAP,
     RESOURCES_COLLECTION_NAME,
+    TOPICS_COLLECTION_NAME,
 )
 from vector_search.encoders.utils import dense_encoder
 
@@ -83,48 +91,31 @@ def create_qdrant_collections(force_recreate):
         force_recreate (bool):
             Whether to recreate the collections even if they already exist
     """
+
+    collections = [
+        RESOURCES_COLLECTION_NAME,
+        CONTENT_FILES_COLLECTION_NAME,
+        TOPICS_COLLECTION_NAME,
+    ]
+    for collection_name in collections:
+        create_qdrant_collection(collection_name, force_recreate)
+
+    update_qdrant_indexes()
+
+
+def create_qdrant_collection(collection_name, force_recreate):
+    """
+    Create or recreate a QDrant collection
+    """
     client = qdrant_client()
-    resources_collection_name = RESOURCES_COLLECTION_NAME
-    content_files_collection_name = CONTENT_FILES_COLLECTION_NAME
     encoder = dense_encoder()
-    # True if either of the collections were recreated
-
-    if (
-        not client.collection_exists(collection_name=resources_collection_name)
-        or force_recreate
-    ):
-        client.delete_collection(resources_collection_name)
+    if (
+        not client.collection_exists(collection_name=collection_name)
+        or force_recreate
+    ):
+        client.delete_collection(collection_name)
         client.recreate_collection(
-            collection_name=resources_collection_name,
-            on_disk_payload=True,
-            vectors_config={
-                encoder.model_short_name(): models.VectorParams(
-                    size=encoder.dim(), distance=models.Distance.COSINE
-                ),
-            },
-            replication_factor=2,
-            shard_number=6,
-            strict_mode_config=models.StrictModeConfig(
-                enabled=True,
-                unindexed_filtering_retrieve=False,
-                unindexed_filtering_update=False,
-            ),
-            sparse_vectors_config=client.get_fastembed_sparse_vector_params(),
-            optimizers_config=models.OptimizersConfigDiff(default_segment_number=2),
-            quantization_config=models.BinaryQuantization(
-                binary=models.BinaryQuantizationConfig(
-                    always_ram=True,
-                ),
-            ),
-        )
-
-    if (
-        not client.collection_exists(collection_name=content_files_collection_name)
-        or force_recreate
-    ):
-        client.delete_collection(content_files_collection_name)
-        client.recreate_collection(
-            collection_name=content_files_collection_name,
+            collection_name=collection_name,
             on_disk_payload=True,
             vectors_config={
                 encoder.model_short_name(): models.VectorParams(
@@ -146,7 +135,6 @@ def create_qdrant_collections(force_recreate):
                 ),
             ),
         )
-    update_qdrant_indexes()
 
 
 def update_qdrant_indexes():
@@ -158,6 +146,7 @@ def update_qdrant_indexes():
     for index in [
         (QDRANT_LEARNING_RESOURCE_INDEXES, RESOURCES_COLLECTION_NAME),
         (QDRANT_CONTENT_FILE_INDEXES, CONTENT_FILES_COLLECTION_NAME),
+        (QDRANT_TOPIC_INDEXES, TOPICS_COLLECTION_NAME),
     ]:
         indexes = index[0]
         collection_name = index[1]
@@ -188,6 +177,58 @@ def vector_point_id(readable_id):
     return str(uuid.uuid5(uuid.NAMESPACE_DNS, readable_id))
 
 
+def embed_topics():
+    """
+    Embed and store new (sub)topics and remove non-existent ones from Qdrant
+    """
+    client = qdrant_client()
+    create_qdrant_collections(force_recreate=False)
+    indexed_count = client.count(collection_name=TOPICS_COLLECTION_NAME).count
+
+    topic_names = set(
+        LearningResourceTopic.objects.values_list("name", flat=True)
+    )
+
+    if indexed_count > 0:
+        existing = vector_search(
+            query_string="",
+            params={},
+            search_collection=TOPICS_COLLECTION_NAME,
+            limit=indexed_count,
+        )
+        indexed_topic_names = {hit["name"] for hit in existing["hits"]}
+    else:
+        indexed_topic_names = set()
+
+    new_topics = topic_names - indexed_topic_names
+    remove_topics = indexed_topic_names - topic_names
+    for remove_topic in remove_topics:
+        remove_points_matching_params(
+            {"name": remove_topic}, collection_name=TOPICS_COLLECTION_NAME
+        )
+
+    docs = []
+    metadata = []
+    ids = []
+
+    filtered_topics = LearningResourceTopic.objects.filter(name__in=new_topics)
+
+    for topic in filtered_topics:
+        docs.append(topic.name)
+        metadata.append(
+            {
+                "name": topic.name,
+            }
+        )
+        ids.append(str(topic.topic_uuid))
+    if len(docs) > 0:
+        encoder = dense_encoder()
+        embeddings = encoder.embed_documents(docs)
+        vector_name = encoder.model_short_name()
+        points = points_generator(ids, metadata, embeddings, vector_name)
+        client.upload_points(TOPICS_COLLECTION_NAME, points=points, wait=False)
+
+
 def _chunk_documents(encoder, texts, metadatas):
     # chunk the documents. use semantic chunking if enabled
     chunk_params = {
@@ -757,6 +800,8 @@ def qdrant_query_conditions(params, collection_name=RESOURCES_COLLECTION_NAME):
     conditions = []
     if collection_name == RESOURCES_COLLECTION_NAME:
         QDRANT_PARAM_MAP = QDRANT_RESOURCE_PARAM_MAP
+    elif collection_name == TOPICS_COLLECTION_NAME:
+        QDRANT_PARAM_MAP = QDRANT_TOPICS_PARAM_MAP
     else:
         QDRANT_PARAM_MAP = QDRANT_CONTENT_FILE_PARAM_MAP
     if not params:
diff --git a/vector_search/utils_test.py b/vector_search/utils_test.py
index 346ccda029..c4e48cc17b 100644
--- a/vector_search/utils_test.py
+++ b/vector_search/utils_test.py
@@ -1,4 +1,5 @@
 from decimal import Decimal
+from unittest.mock import MagicMock
 
 import pytest
 from django.conf import settings
@@ -10,6 +11,7 @@
     LearningResourceFactory,
     LearningResourcePriceFactory,
     LearningResourceRunFactory,
+    LearningResourceTopicFactory,
 )
 from learning_resources.models import LearningResource
 from learning_resources.serializers import LearningResourceMetadataDisplaySerializer
@@ -35,6 +37,7 @@
     _embed_course_metadata_as_contentfile,
     create_qdrant_collections,
     embed_learning_resources,
+    embed_topics,
     filter_existing_qdrant_points,
     qdrant_query_conditions,
     should_generate_content_embeddings,
@@ -915,3 +918,64 @@ def test_update_qdrant_indexes_updates_mismatched_field_type(mocker):
         for index_field in QDRANT_CONTENT_FILE_INDEXES
     ]
     mock_client.create_payload_index.assert_has_calls(expected_calls, any_order=True)
+
+
+def test_embed_topics_no_new_topics(mocker):
+    """
+    Test embed_topics when there are no new topics to embed
+    """
+    mock_client = MagicMock()
+    mock_qdrant_client = mocker.patch("vector_search.utils.qdrant_client")
+    mock_qdrant_client.return_value = mock_client
+    mock_client.count.return_value.count = 1
+    mock_vector_search = mocker.patch("vector_search.utils.vector_search")
+    mock_vector_search.return_value = {"hits": [{"name": "topic1"}]}
+    LearningResourceTopicFactory.create(name="topic1", parent=None)
+    mock_remove_points_matching_params = mocker.patch(
+        "vector_search.utils.remove_points_matching_params"
+    )
+    embed_topics()
+    mock_remove_points_matching_params.assert_not_called()
+    mock_client.upload_points.assert_not_called()
+
+
+def test_embed_topics_new_topics(mocker):
+    """
+    Test embed_topics when there are new topics
+    """
+    mock_client = MagicMock()
+    mock_qdrant_client = mocker.patch("vector_search.utils.qdrant_client")
+    mock_qdrant_client.return_value = mock_client
+    mock_client.count.return_value.count = 1
+    mock_vector_search = mocker.patch("vector_search.utils.vector_search")
+    mock_vector_search.return_value = {"hits": [{"name": "topic1"}]}
+    LearningResourceTopicFactory.create(name="topic1", parent=None)
+    LearningResourceTopicFactory.create(name="topic2", parent=None)
+    LearningResourceTopicFactory.create(name="topic3", parent=None)
+    mocker.patch("vector_search.utils.remove_points_matching_params")
+    embed_topics()
+    mock_client.upload_points.assert_called_once()
+    assert len(list(mock_client.upload_points.mock_calls[0][2]["points"])) == 2
+
+
+def test_embed_topics_remove_topics(mocker):
+    """
+    Test embed_topics when there are topics to remove
+    """
+    mock_client = MagicMock()
+    mock_qdrant_client = mocker.patch("vector_search.utils.qdrant_client")
+    mock_qdrant_client.return_value = mock_client
+    mock_client.count.return_value.count = 1
+    mock_vector_search = mocker.patch("vector_search.utils.vector_search")
+    mock_vector_search.return_value = {"hits": [{"name": "remove-topic"}]}
+
+    LearningResourceTopicFactory.create(name="topic2", parent=None)
+    LearningResourceTopicFactory.create(name="topic3", parent=None)
+    mock_remove_points_matching_params = mocker.patch(
+        "vector_search.utils.remove_points_matching_params"
+    )
+    embed_topics()
+    mock_remove_points_matching_params.assert_called_once()
+    assert (
+        mock_remove_points_matching_params.mock_calls[0][1][0]["name"] == "remove-topic"
+    )
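Reviewer note: end to end, topic tagging now reduces to one embedding lookup against the topics collection. A minimal sketch of querying it directly, mirroring the _qdrant_similar_results path with the defaults introduced above; nearest_topic_names is a hypothetical name, not part of this diff:

# Hypothetical helper: embed free text and return the closest topic names,
# the same query path get_similar_topics_qdrant takes when no cached
# resource embedding exists.
from vector_search.constants import TOPICS_COLLECTION_NAME
from vector_search.encoders.utils import dense_encoder
from vector_search.utils import qdrant_client


def nearest_topic_names(text, limit=2, score_threshold=0.2):
    """Return the closest topic names for a piece of free text."""
    encoder = dense_encoder()
    client = qdrant_client()
    points = client.query_points(
        collection_name=TOPICS_COLLECTION_NAME,
        query=encoder.embed(text),  # dense query vector
        limit=limit,  # matches the new OPEN_VIDEO_MAX_TOPICS default of 2
        using=encoder.model_short_name(),  # named vector for this collection
        score_threshold=score_threshold,  # same cutoff as get_similar_topics_qdrant
    ).points
    return [point.payload["name"] for point in points]


# Usage: nearest_topic_names("Introduction to cell biology and genetics")
# might return something like ["Biology", "Genetics"] once the
# sync_topic_embeddings command has populated the collection.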