Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions learning_resources/etl/loaders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ def mock_duplicates(mocker):
)


@pytest.fixture
def mock_get_similar_topics_qdrant(mocker):
    """Mock get_similar_topics_qdrant with canned topic names.

    Returns the patch object so tests can assert on call args if needed.
    """
    return mocker.patch(
        "learning_resources_search.plugins.get_similar_topics_qdrant",
        return_value=["topic1", "topic2"],
    )


@pytest.fixture(autouse=True)
def mock_upsert_tasks(mocker):
"""Mock out the upsert task helpers"""
Expand Down Expand Up @@ -1465,9 +1473,10 @@ def test_load_video(mocker, video_exists, is_published, pass_topics):
assert getattr(result, key) == value, f"Property {key} should equal {value}"


def test_load_videos():
def test_load_videos(mocker, mock_get_similar_topics_qdrant):
"""Verify that load_videos loads a list of videos"""
assert Video.objects.count() == 0

video_resources = [video.learning_resource for video in VideoFactory.build_batch(5)]
videos_data = [
{
Expand All @@ -1486,13 +1495,14 @@ def test_load_videos():


@pytest.mark.parametrize("playlist_exists", [True, False])
def test_load_playlist(mocker, playlist_exists):
def test_load_playlist(mocker, playlist_exists, mock_get_similar_topics_qdrant):
"""Test load_playlist"""
expected_topics = [{"name": "Biology"}, {"name": "Physics"}]
[
LearningResourceTopicFactory.create(name=topic["name"])
for topic in expected_topics
]

mock_most_common_topics = mocker.patch(
"learning_resources.etl.loaders.most_common_topics",
return_value=expected_topics,
Expand Down Expand Up @@ -1904,7 +1914,7 @@ def test_course_with_unpublished_force_ingest_is_test_mode():


@pytest.mark.django_db
def test_load_articles(mocker, climate_platform):
def test_load_articles(mocker, climate_platform, mock_get_similar_topics_qdrant):
articles_data = [
{
"title": "test",
Expand Down
72 changes: 65 additions & 7 deletions learning_resources_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
adjust_search_for_percolator,
document_percolated_actions,
)
from vector_search.constants import RESOURCES_COLLECTION_NAME
from vector_search.constants import (
RESOURCES_COLLECTION_NAME,
TOPICS_COLLECTION_NAME,
)
from vector_search.encoders.utils import dense_encoder

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -830,6 +834,51 @@ def user_subscribed_to_query(
)


def get_similar_topics_qdrant(
    resource: LearningResource, value_doc: dict, num_topics: int
) -> list[str]:
    """
    Get a list of similar topics based on vector similarity

    Args:
        resource (LearningResource):
            the resource whose stored embedding is reused when available
        value_doc (dict):
            a document representing the data fields we want to search with
        num_topics (int):
            number of topics to return
    Returns:
        list of str:
            list of topic values
    """
    # NOTE: the docstring must precede this import; previously the string
    # literal followed the import and was therefore not a real docstring.
    # Imported locally to avoid a circular import with vector_search.utils.
    from vector_search.utils import qdrant_client, vector_point_id

    encoder = dense_encoder()
    client = qdrant_client()

    # Prefer the embedding already stored for this resource in the
    # resources collection over re-encoding the document text.
    response = client.retrieve(
        collection_name=RESOURCES_COLLECTION_NAME,
        ids=[vector_point_id(resource.readable_id)],
        with_vectors=True,
    )

    if response:
        embeddings = response[0].vector.get(encoder.model_short_name())
    else:
        # Only build (and encode) the text context when no cached vector exists
        embedding_context = "\n".join(
            value_doc[key] for key in value_doc if value_doc[key] is not None
        )
        embeddings = encoder.embed(embedding_context)

    return [
        hit["name"]
        for hit in _qdrant_similar_results(
            input_query=embeddings,
            num_resources=num_topics,
            collection_name=TOPICS_COLLECTION_NAME,
            # empirically chosen minimum similarity for topic matches
            score_threshold=0.2,
        )
    ]


def get_similar_topics(
value_doc: dict, num_topics: int, min_term_freq: int, min_doc_freq: int
) -> list[str]:
Expand Down Expand Up @@ -909,7 +958,12 @@ def get_similar_resources(
)


def _qdrant_similar_results(
    input_query,
    num_resources=6,
    collection_name=RESOURCES_COLLECTION_NAME,
    score_threshold=0,
):
    """
    Get similar results from a qdrant collection

    Args:
        input_query:
            query passed to qdrant — either a point id (to search by an
            existing point's vector) or an embedding vector
        num_resources (int):
            maximum number of hits to return
        collection_name (str):
            name of the qdrant collection to query
        score_threshold (float):
            minimum similarity score a hit must have to be returned
    Returns:
        list of dict:
            payloads of the matching points
    """
    # Imported locally to avoid a circular import with vector_search.utils.
    # vector_point_id is no longer needed here: callers now pass the query
    # (point id or embedding) directly via input_query.
    from vector_search.utils import qdrant_client

    encoder = dense_encoder()
    client = qdrant_client()
    return [
        hit.payload
        for hit in client.query_points(
            collection_name=collection_name,
            query=input_query,
            limit=num_resources,
            using=encoder.model_short_name(),
            score_threshold=score_threshold,
        ).points
    ]

Expand All @@ -956,7 +1009,12 @@ def get_similar_resources_qdrant(value_doc: dict, num_resources: int):
list of str:
list of learning resources
"""
hits = _qdrant_similar_results(value_doc, num_resources)
from vector_search.utils import vector_point_id
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move to the top of the file

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment about circular import above


hits = _qdrant_similar_results(
input_query=vector_point_id(value_doc["readable_id"]),
num_resources=num_resources,
)
return (
LearningResource.objects.for_search_serialization()
.filter(
Expand Down
39 changes: 38 additions & 1 deletion learning_resources_search/api_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Search API function tests"""

from unittest.mock import Mock
from unittest.mock import MagicMock, Mock

import pytest
from freezegun import freeze_time
Expand All @@ -21,6 +21,7 @@
generate_sort_clause,
generate_suggest_clause,
get_similar_topics,
get_similar_topics_qdrant,
percolate_matches_for_document,
relevant_indexes,
)
Expand Down Expand Up @@ -3266,3 +3267,39 @@ def test_dev_mode(dev_mode):
assert construct_search(search_params).to_dict().get("explain")
else:
assert construct_search(search_params).to_dict().get("explain") is None


@pytest.mark.django_db
def test_get_similar_topics_qdrant_uses_cached_embedding(mocker):
    """
    Test that get_similar_topics_qdrant uses a cached embedding when available
    """
    fake_resource = MagicMock()
    fake_resource.readable_id = "test-resource"
    doc = {"title": "Test Title", "description": "Test Description"}

    patched_encoder = mocker.patch("learning_resources_search.api.dense_encoder")
    encoder = patched_encoder.return_value
    encoder.model_short_name.return_value = "test-model"
    encoder.embed.return_value = [0.1, 0.2, 0.3]

    patched_client = mocker.patch("vector_search.utils.qdrant_client")

    # Simulate a cached embedding in the response
    patched_client.return_value.retrieve.return_value = [
        MagicMock(vector={"test-model": [0.9, 0.8, 0.7]})
    ]

    mocker.patch(
        "learning_resources_search.api._qdrant_similar_results",
        return_value=[{"name": "topic1"}, {"name": "topic2"}],
    )

    topics = get_similar_topics_qdrant(fake_resource, doc, 3)

    # The cached vector should be used, so embed() must never run
    encoder.embed.assert_not_called()
    # Topic names are extracted from the similar-results payloads
    assert topics == ["topic1", "topic2"]
7 changes: 3 additions & 4 deletions learning_resources_search/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.conf import settings as django_settings

from learning_resources_search import tasks
from learning_resources_search.api import get_similar_topics
from learning_resources_search.api import get_similar_topics_qdrant
from learning_resources_search.constants import (
COURSE_TYPE,
PERCOLATE_INDEX_TYPE,
Expand Down Expand Up @@ -125,11 +125,10 @@ def resource_similar_topics(self, resource) -> list[dict]:
"full_description": resource.full_description,
}

topic_names = get_similar_topics(
topic_names = get_similar_topics_qdrant(
resource,
text_doc,
settings.OPEN_VIDEO_MAX_TOPICS,
settings.OPEN_VIDEO_MIN_TERM_FREQ,
settings.OPEN_VIDEO_MIN_DOC_FREQ,
)
return [{"name": topic_name} for topic_name in topic_names]

Expand Down
5 changes: 2 additions & 3 deletions learning_resources_search/plugins_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,18 @@ def test_resource_similar_topics(mocker, settings):
"""The plugin function should return expected topics for a resource"""
expected_topics = ["topic1", "topic2"]
mock_similar_topics = mocker.patch(
"learning_resources_search.plugins.get_similar_topics",
"learning_resources_search.plugins.get_similar_topics_qdrant",
return_value=expected_topics,
)
resource = LearningResourceFactory.create()
topics = SearchIndexPlugin().resource_similar_topics(resource)
assert topics == [{"name": topic} for topic in expected_topics]
mock_similar_topics.assert_called_once_with(
resource,
{
"title": resource.title,
"description": resource.description,
"full_description": resource.full_description,
},
settings.OPEN_VIDEO_MAX_TOPICS,
settings.OPEN_VIDEO_MIN_TERM_FREQ,
settings.OPEN_VIDEO_MIN_DOC_FREQ,
)
2 changes: 1 addition & 1 deletion main/settings_course_etl.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
# course catalog video etl settings
OPEN_VIDEO_DATA_BRANCH = get_string("OPEN_VIDEO_DATA_BRANCH", "master")
OPEN_VIDEO_USER_LIST_OWNER = get_string("OPEN_VIDEO_USER_LIST_OWNER", None)
OPEN_VIDEO_MAX_TOPICS = get_int("OPEN_VIDEO_MAX_TOPICS", 3)
OPEN_VIDEO_MAX_TOPICS = get_int("OPEN_VIDEO_MAX_TOPICS", 2)
OPEN_VIDEO_MIN_TERM_FREQ = get_int("OPEN_VIDEO_MIN_TERM_FREQ", 1)
OPEN_VIDEO_MIN_DOC_FREQ = get_int("OPEN_VIDEO_MIN_DOC_FREQ", 15)

Expand Down
9 changes: 9 additions & 0 deletions vector_search/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

RESOURCES_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.resources"
CONTENT_FILES_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.content_files"
TOPICS_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.topics"

QDRANT_CONTENT_FILE_PARAM_MAP = {
"key": "key",
Expand Down Expand Up @@ -43,6 +44,10 @@
}


QDRANT_TOPICS_PARAM_MAP = {
"name": "name",
}

QDRANT_LEARNING_RESOURCE_INDEXES = {
"readable_id": models.PayloadSchemaType.KEYWORD,
"resource_type": models.PayloadSchemaType.KEYWORD,
Expand Down Expand Up @@ -82,3 +87,7 @@
"edx_block_id": models.PayloadSchemaType.KEYWORD,
"url": models.PayloadSchemaType.KEYWORD,
}

QDRANT_TOPIC_INDEXES = {
"name": models.PayloadSchemaType.KEYWORD,
}
27 changes: 27 additions & 0 deletions vector_search/management/commands/sync_topic_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Management command to update or create the topics collection in Qdrant"""

from django.core.management.base import BaseCommand, CommandError

from main.utils import clear_search_cache, now_in_utc
from vector_search.tasks import sync_topics


class Command(BaseCommand):
    """Syncs embeddings for topics in Qdrant"""

    help = "update or create the topics collection in Qdrant"

    def handle(self, *args, **options):  # noqa: ARG002
        """Run the sync_topics task synchronously and report timing.

        Raises:
            CommandError: if the task returns an error result
        """
        task = sync_topics.apply()
        self.stdout.write("Waiting on task...")
        start = now_in_utc()
        error = task.get()
        if error:
            # fixed typo: "Geenerate" -> "Generate"
            msg = f"Generate embeddings errored: {error}"
            raise CommandError(msg)
        clear_search_cache()
        total_seconds = (now_in_utc() - start).total_seconds()
        self.stdout.write(
            f"Embeddings generated and stored, took {total_seconds} seconds"
        )
14 changes: 13 additions & 1 deletion vector_search/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
chunks,
now_in_utc,
)
from vector_search.utils import embed_learning_resources, remove_qdrant_records
from vector_search.utils import (
embed_learning_resources,
embed_topics,
remove_qdrant_records,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -362,3 +366,11 @@ def remove_run_content_files(run_id):
for ids in chunks(content_file_ids, chunk_size=settings.QDRANT_CHUNK_SIZE)
]
)


@app.task
def sync_topics():
    """
    Celery task: generate embeddings for all topics and upsert them
    into the Qdrant topics collection via embed_topics().
    """
    embed_topics()
Loading
Loading