Skip to content

Commit e1de652

Browse files
authored
Vector based topics tagging for videos (#2649)
* some refactoring and de-duping * adding topics collection name * adding topic collection indexes * adding method to embed topics * keeping old topic generation method * adding qdrant based topic generator * fix test * fix tests * adding score thresholding * adding management command to sync topics and adding a cache mecahnism for assignment * fix test * docstring update * update docstrings * update topic query * adding topic sync tests * switch default number of topics to 2 * test for cached embedding * relocating import and making test topics a fixture * fixing mock
1 parent 3bfdf5e commit e1de652

File tree

11 files changed

+320
-60
lines changed

11 files changed

+320
-60
lines changed

learning_resources/etl/loaders_test.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@ def mock_duplicates(mocker):
131131
)
132132

133133

134+
@pytest.fixture
135+
def mock_get_similar_topics_qdrant(mocker):
136+
mocker.patch(
137+
"learning_resources_search.plugins.get_similar_topics_qdrant",
138+
return_value=["topic1", "topic2"],
139+
)
140+
141+
134142
@pytest.fixture(autouse=True)
135143
def mock_upsert_tasks(mocker):
136144
"""Mock out the upsert task helpers"""
@@ -1465,9 +1473,10 @@ def test_load_video(mocker, video_exists, is_published, pass_topics):
14651473
assert getattr(result, key) == value, f"Property {key} should equal {value}"
14661474

14671475

1468-
def test_load_videos():
1476+
def test_load_videos(mocker, mock_get_similar_topics_qdrant):
14691477
"""Verify that load_videos loads a list of videos"""
14701478
assert Video.objects.count() == 0
1479+
14711480
video_resources = [video.learning_resource for video in VideoFactory.build_batch(5)]
14721481
videos_data = [
14731482
{
@@ -1486,13 +1495,14 @@ def test_load_videos():
14861495

14871496

14881497
@pytest.mark.parametrize("playlist_exists", [True, False])
1489-
def test_load_playlist(mocker, playlist_exists):
1498+
def test_load_playlist(mocker, playlist_exists, mock_get_similar_topics_qdrant):
14901499
"""Test load_playlist"""
14911500
expected_topics = [{"name": "Biology"}, {"name": "Physics"}]
14921501
[
14931502
LearningResourceTopicFactory.create(name=topic["name"])
14941503
for topic in expected_topics
14951504
]
1505+
14961506
mock_most_common_topics = mocker.patch(
14971507
"learning_resources.etl.loaders.most_common_topics",
14981508
return_value=expected_topics,
@@ -1904,7 +1914,7 @@ def test_course_with_unpublished_force_ingest_is_test_mode():
19041914

19051915

19061916
@pytest.mark.django_db
1907-
def test_load_articles(mocker, climate_platform):
1917+
def test_load_articles(mocker, climate_platform, mock_get_similar_topics_qdrant):
19081918
articles_data = [
19091919
{
19101920
"title": "test",

learning_resources_search/api.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@
3535
adjust_search_for_percolator,
3636
document_percolated_actions,
3737
)
38-
from vector_search.constants import RESOURCES_COLLECTION_NAME
38+
from vector_search.constants import (
39+
RESOURCES_COLLECTION_NAME,
40+
TOPICS_COLLECTION_NAME,
41+
)
42+
from vector_search.encoders.utils import dense_encoder
3943

4044
log = logging.getLogger(__name__)
4145

@@ -830,6 +834,51 @@ def user_subscribed_to_query(
830834
)
831835

832836

837+
def get_similar_topics_qdrant(
838+
resource: LearningResource, value_doc: dict, num_topics: int
839+
) -> list[str]:
840+
from vector_search.utils import qdrant_client, vector_point_id
841+
842+
"""
843+
Get a list of similar topics based on vector similarity
844+
845+
Args:
846+
value_doc (dict):
847+
a document representing the data fields we want to search with
848+
num_topics (int):
849+
number of topics to return
850+
Returns:
851+
list of str:
852+
list of topic values
853+
"""
854+
encoder = dense_encoder()
855+
client = qdrant_client()
856+
857+
response = client.retrieve(
858+
collection_name=RESOURCES_COLLECTION_NAME,
859+
ids=[vector_point_id(resource.readable_id)],
860+
with_vectors=True,
861+
)
862+
863+
embedding_context = "\n".join(
864+
[value_doc[key] for key in value_doc if value_doc[key] is not None]
865+
)
866+
if response and len(response) > 0:
867+
embeddings = response[0].vector.get(encoder.model_short_name())
868+
else:
869+
embeddings = encoder.embed(embedding_context)
870+
871+
return [
872+
hit["name"]
873+
for hit in _qdrant_similar_results(
874+
input_query=embeddings,
875+
num_resources=num_topics,
876+
collection_name=TOPICS_COLLECTION_NAME,
877+
score_threshold=0.2,
878+
)
879+
]
880+
881+
833882
def get_similar_topics(
834883
value_doc: dict, num_topics: int, min_term_freq: int, min_doc_freq: int
835884
) -> list[str]:
@@ -909,7 +958,12 @@ def get_similar_resources(
909958
)
910959

911960

912-
def _qdrant_similar_results(doc, num_resources):
961+
def _qdrant_similar_results(
962+
input_query,
963+
num_resources=6,
964+
collection_name=RESOURCES_COLLECTION_NAME,
965+
score_threshold=0,
966+
):
913967
"""
914968
Get similar resources from qdrant
915969
@@ -924,20 +978,19 @@ def _qdrant_similar_results(doc, num_resources):
924978
list of serialized resources
925979
"""
926980
from vector_search.utils import (
927-
dense_encoder,
928981
qdrant_client,
929-
vector_point_id,
930982
)
931983

932984
encoder = dense_encoder()
933985
client = qdrant_client()
934986
return [
935987
hit.payload
936988
for hit in client.query_points(
937-
collection_name=RESOURCES_COLLECTION_NAME,
938-
query=vector_point_id(doc["readable_id"]),
989+
collection_name=collection_name,
990+
query=input_query,
939991
limit=num_resources,
940992
using=encoder.model_short_name(),
993+
score_threshold=score_threshold,
941994
).points
942995
]
943996

@@ -956,7 +1009,12 @@ def get_similar_resources_qdrant(value_doc: dict, num_resources: int):
9561009
list of str:
9571010
list of learning resources
9581011
"""
959-
hits = _qdrant_similar_results(value_doc, num_resources)
1012+
from vector_search.utils import vector_point_id
1013+
1014+
hits = _qdrant_similar_results(
1015+
input_query=vector_point_id(value_doc["readable_id"]),
1016+
num_resources=num_resources,
1017+
)
9601018
return (
9611019
LearningResource.objects.for_search_serialization()
9621020
.filter(

learning_resources_search/api_test.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Search API function tests"""
22

3-
from unittest.mock import Mock
3+
from unittest.mock import MagicMock, Mock
44

55
import pytest
66
from freezegun import freeze_time
@@ -21,6 +21,7 @@
2121
generate_sort_clause,
2222
generate_suggest_clause,
2323
get_similar_topics,
24+
get_similar_topics_qdrant,
2425
percolate_matches_for_document,
2526
relevant_indexes,
2627
)
@@ -3266,3 +3267,39 @@ def test_dev_mode(dev_mode):
32663267
assert construct_search(search_params).to_dict().get("explain")
32673268
else:
32683269
assert construct_search(search_params).to_dict().get("explain") is None
3270+
3271+
3272+
@pytest.mark.django_db
3273+
def test_get_similar_topics_qdrant_uses_cached_embedding(mocker):
3274+
"""
3275+
Test that get_similar_topics_qdrant uses a cached embedding when available
3276+
"""
3277+
resource = MagicMock()
3278+
resource.readable_id = "test-resource"
3279+
value_doc = {"title": "Test Title", "description": "Test Description"}
3280+
num_topics = 3
3281+
3282+
mock_encoder = mocker.patch("learning_resources_search.api.dense_encoder")
3283+
encoder_instance = mock_encoder.return_value
3284+
encoder_instance.model_short_name.return_value = "test-model"
3285+
encoder_instance.embed.return_value = [0.1, 0.2, 0.3]
3286+
3287+
mock_client = mocker.patch("vector_search.utils.qdrant_client")
3288+
client_instance = mock_client.return_value
3289+
3290+
# Simulate a cached embedding in the response
3291+
client_instance.retrieve.return_value = [
3292+
MagicMock(vector={"test-model": [0.9, 0.8, 0.7]})
3293+
]
3294+
3295+
mocker.patch(
3296+
"learning_resources_search.api._qdrant_similar_results",
3297+
return_value=[{"name": "topic1"}, {"name": "topic2"}],
3298+
)
3299+
3300+
result = get_similar_topics_qdrant(resource, value_doc, num_topics)
3301+
3302+
# Assert that embed was NOT called (cached embedding used)
3303+
encoder_instance.embed.assert_not_called()
3304+
# Assert that the result is as expected
3305+
assert result == ["topic1", "topic2"]

learning_resources_search/plugins.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.conf import settings as django_settings
88

99
from learning_resources_search import tasks
10-
from learning_resources_search.api import get_similar_topics
10+
from learning_resources_search.api import get_similar_topics_qdrant
1111
from learning_resources_search.constants import (
1212
COURSE_TYPE,
1313
PERCOLATE_INDEX_TYPE,
@@ -125,11 +125,10 @@ def resource_similar_topics(self, resource) -> list[dict]:
125125
"full_description": resource.full_description,
126126
}
127127

128-
topic_names = get_similar_topics(
128+
topic_names = get_similar_topics_qdrant(
129+
resource,
129130
text_doc,
130131
settings.OPEN_VIDEO_MAX_TOPICS,
131-
settings.OPEN_VIDEO_MIN_TERM_FREQ,
132-
settings.OPEN_VIDEO_MIN_DOC_FREQ,
133132
)
134133
return [{"name": topic_name} for topic_name in topic_names]
135134

learning_resources_search/plugins_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,18 @@ def test_resource_similar_topics(mocker, settings):
128128
"""The plugin function should return expected topics for a resource"""
129129
expected_topics = ["topic1", "topic2"]
130130
mock_similar_topics = mocker.patch(
131-
"learning_resources_search.plugins.get_similar_topics",
131+
"learning_resources_search.plugins.get_similar_topics_qdrant",
132132
return_value=expected_topics,
133133
)
134134
resource = LearningResourceFactory.create()
135135
topics = SearchIndexPlugin().resource_similar_topics(resource)
136136
assert topics == [{"name": topic} for topic in expected_topics]
137137
mock_similar_topics.assert_called_once_with(
138+
resource,
138139
{
139140
"title": resource.title,
140141
"description": resource.description,
141142
"full_description": resource.full_description,
142143
},
143144
settings.OPEN_VIDEO_MAX_TOPICS,
144-
settings.OPEN_VIDEO_MIN_TERM_FREQ,
145-
settings.OPEN_VIDEO_MIN_DOC_FREQ,
146145
)

main/settings_course_etl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
# course catalog video etl settings
9797
OPEN_VIDEO_DATA_BRANCH = get_string("OPEN_VIDEO_DATA_BRANCH", "master")
9898
OPEN_VIDEO_USER_LIST_OWNER = get_string("OPEN_VIDEO_USER_LIST_OWNER", None)
99-
OPEN_VIDEO_MAX_TOPICS = get_int("OPEN_VIDEO_MAX_TOPICS", 3)
99+
OPEN_VIDEO_MAX_TOPICS = get_int("OPEN_VIDEO_MAX_TOPICS", 2)
100100
OPEN_VIDEO_MIN_TERM_FREQ = get_int("OPEN_VIDEO_MIN_TERM_FREQ", 1)
101101
OPEN_VIDEO_MIN_DOC_FREQ = get_int("OPEN_VIDEO_MIN_DOC_FREQ", 15)
102102

vector_search/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
RESOURCES_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.resources"
55
CONTENT_FILES_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.content_files"
6+
TOPICS_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.topics"
67

78
QDRANT_CONTENT_FILE_PARAM_MAP = {
89
"key": "key",
@@ -43,6 +44,10 @@
4344
}
4445

4546

47+
QDRANT_TOPICS_PARAM_MAP = {
48+
"name": "name",
49+
}
50+
4651
QDRANT_LEARNING_RESOURCE_INDEXES = {
4752
"readable_id": models.PayloadSchemaType.KEYWORD,
4853
"resource_type": models.PayloadSchemaType.KEYWORD,
@@ -82,3 +87,7 @@
8287
"edx_block_id": models.PayloadSchemaType.KEYWORD,
8388
"url": models.PayloadSchemaType.KEYWORD,
8489
}
90+
91+
QDRANT_TOPIC_INDEXES = {
92+
"name": models.PayloadSchemaType.KEYWORD,
93+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Management command to update or create the topics collection in Qdrant"""
2+
3+
from django.core.management.base import BaseCommand, CommandError
4+
5+
from main.utils import clear_search_cache, now_in_utc
6+
from vector_search.tasks import sync_topics
7+
8+
9+
class Command(BaseCommand):
10+
"""Syncs embeddings for topics in Qdrant"""
11+
12+
help = "update or create the topics collection in Qdrant"
13+
14+
def handle(self, *args, **options): # noqa: ARG002
15+
"""Sync the topics collection"""
16+
task = sync_topics.apply()
17+
self.stdout.write("Waiting on task...")
18+
start = now_in_utc()
19+
error = task.get()
20+
if error:
21+
msg = f"Geenerate embeddings errored: {error}"
22+
raise CommandError(msg)
23+
clear_search_cache()
24+
total_seconds = (now_in_utc() - start).total_seconds()
25+
self.stdout.write(
26+
f"Embeddings generated and stored, took {total_seconds} seconds"
27+
)

vector_search/tasks.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232
chunks,
3333
now_in_utc,
3434
)
35-
from vector_search.utils import embed_learning_resources, remove_qdrant_records
35+
from vector_search.utils import (
36+
embed_learning_resources,
37+
embed_topics,
38+
remove_qdrant_records,
39+
)
3640

3741
log = logging.getLogger(__name__)
3842

@@ -362,3 +366,11 @@ def remove_run_content_files(run_id):
362366
for ids in chunks(content_file_ids, chunk_size=settings.QDRANT_CHUNK_SIZE)
363367
]
364368
)
369+
370+
371+
@app.task
372+
def sync_topics():
373+
"""
374+
Sync topics to the Qdrant collection
375+
"""
376+
embed_topics()

0 commit comments

Comments
 (0)