Skip to content

Commit 4e5ac14

Browse files
[Feat] Implement keyword search in Qdrant
This commit implements keyword search in Qdrant. Signed-off-by: Varsha Prasad Narsing <[email protected]>
1 parent a6e2c18 commit 4e5ac14

File tree

4 files changed

+148
-23
lines changed

4 files changed

+148
-23
lines changed

llama_stack/providers/remote/vector_io/qdrant/qdrant.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,35 @@ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float)
128128

129129
return QueryChunksResponse(chunks=chunks, scores=scores)
130130

131-
async def query_keyword(
132-
self,
133-
query_string: str,
134-
k: int,
135-
score_threshold: float,
136-
) -> QueryChunksResponse:
137-
raise NotImplementedError("Keyword search is not supported in Qdrant")
131+
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
132+
results = (
133+
await self.client.query_points(
134+
collection_name=self.collection_name,
135+
query_filter=models.Filter(
136+
must=[models.FieldCondition(key="chunk_content.content", match=models.MatchText(text=query_string))]
137+
),
138+
limit=k,
139+
with_payload=True,
140+
with_vectors=False,
141+
score_threshold=score_threshold,
142+
)
143+
).points
144+
145+
chunks, scores = [], []
146+
for point in results:
147+
assert isinstance(point, models.ScoredPoint)
148+
assert point.payload is not None
149+
150+
try:
151+
chunk = Chunk(**point.payload["chunk_content"])
152+
except Exception:
153+
log.exception("Failed to parse chunk")
154+
continue
155+
156+
chunks.append(chunk)
157+
scores.append(point.score)
158+
159+
return QueryChunksResponse(chunks=chunks, scores=scores)
138160

139161
async def query_hybrid(
140162
self,

tests/integration/vector_io/test_openai_vector_stores.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
5555
],
5656
"keyword": [
5757
"inline::sqlite-vec",
58+
"inline::qdrant",
59+
"remote::qdrant",
5860
"remote::milvus",
5961
],
6062
"hybrid": [

tests/unit/providers/vector_io/conftest.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
import os
78
import random
9+
import tempfile
810

911
import numpy as np
1012
import pytest
@@ -17,7 +19,7 @@
1719
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
1820
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
1921
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
20-
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
22+
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
2123
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
2224
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
2325
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
@@ -29,7 +31,7 @@
2931
MILVUS_ALIAS = "test_milvus"
3032

3133

32-
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
34+
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "qdrant"])
3335
def vector_provider(request):
3436
return request.param
3537

@@ -283,40 +285,53 @@ async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_d
283285

284286

285287
@pytest.fixture
286-
def qdrant_vec_db_path(tmp_path_factory):
288+
def qdrant_vec_db_path(tmp_path):
289+
"""Use tmp_path with additional isolation to ensure unique path per test."""
287290
import uuid
288291

289-
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
290-
return db_path
292+
# Create a completely isolated temporary directory
293+
temp_dir = tempfile.mkdtemp(prefix=f"qdrant_test_{uuid.uuid4()}_")
294+
return temp_dir
291295

292296

293297
@pytest.fixture
294298
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
299+
import shutil
295300
import uuid
296301

297-
config = QdrantVectorIOConfig(
298-
db_path=qdrant_vec_db_path,
302+
config = InlineQdrantVectorIOConfig(
303+
path=qdrant_vec_db_path,
299304
kvstore=SqliteKVStoreConfig(),
300305
)
301306
adapter = QdrantVectorIOAdapter(
302307
config=config,
303308
inference_api=mock_inference_api,
304309
files_api=None,
305310
)
306-
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
311+
312+
original_initialize = adapter.initialize
313+
314+
async def safe_initialize():
315+
if not hasattr(adapter, "_initialized") or not adapter._initialized:
316+
await original_initialize()
317+
adapter._initialized = True
318+
319+
adapter.initialize = safe_initialize
307320
await adapter.initialize()
308-
await adapter.register_vector_db(
309-
VectorDB(
310-
identifier=collection_id,
311-
provider_id="test_provider",
312-
embedding_model="test_model",
313-
embedding_dimension=embedding_dimension,
314-
)
315-
)
321+
322+
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
316323
adapter.test_collection_id = collection_id
324+
adapter._test_db_path = qdrant_vec_db_path
317325
yield adapter
326+
318327
await adapter.shutdown()
319328

329+
try:
330+
if os.path.exists(qdrant_vec_db_path):
331+
shutil.rmtree(qdrant_vec_db_path, ignore_errors=True)
332+
except Exception:
333+
pass
334+
320335

321336
@pytest.fixture
322337
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):

tests/unit/providers/vector_io/test_qdrant.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,89 @@ async def test_qdrant_register_and_unregister_vector_db(
136136
await qdrant_adapter.unregister_vector_db(vector_db_id)
137137
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
138138
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
139+
140+
141+
# Keyword search tests
142+
async def test_query_chunks_keyword_search(qdrant_vec_index, sample_chunks, sample_embeddings):
143+
"""Test keyword search functionality in Qdrant."""
144+
await qdrant_vec_index.add_chunks(sample_chunks, sample_embeddings)
145+
query_string = "Sentence 5"
146+
response = await qdrant_vec_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
147+
148+
assert isinstance(response, QueryChunksResponse)
149+
assert len(response.chunks) > 0, f"Expected some chunks, but got {len(response.chunks)}"
150+
151+
non_existent_query_str = "blablabla"
152+
response_no_results = await qdrant_vec_index.query_keyword(
153+
query_string=non_existent_query_str, k=1, score_threshold=0.0
154+
)
155+
156+
assert isinstance(response_no_results, QueryChunksResponse)
157+
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
158+
159+
160+
async def test_query_chunks_keyword_search_k_greater_than_results(qdrant_vec_index, sample_chunks, sample_embeddings):
161+
"""Test keyword search when k is greater than available results."""
162+
await qdrant_vec_index.add_chunks(sample_chunks, sample_embeddings)
163+
164+
query_str = "Sentence 1 from document 0" # Should match only one chunk
165+
response = await qdrant_vec_index.query_keyword(k=5, score_threshold=0.0, query_string=query_str)
166+
167+
assert isinstance(response, QueryChunksResponse)
168+
assert 0 < len(response.chunks) < 5, f"Expected results between [1, 4], got {len(response.chunks)}"
169+
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
170+
171+
172+
async def test_query_chunks_keyword_search_score_threshold(qdrant_vec_index, sample_chunks, sample_embeddings):
173+
"""Test keyword search with score threshold filtering."""
174+
await qdrant_vec_index.add_chunks(sample_chunks, sample_embeddings)
175+
176+
query_string = "Sentence 5"
177+
178+
# Test with low threshold (should return results)
179+
response_low_threshold = await qdrant_vec_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
180+
assert len(response_low_threshold.chunks) > 0
181+
182+
# Test with negative threshold (should return results since scores are 0.0)
183+
response_negative_threshold = await qdrant_vec_index.query_keyword(
184+
query_string=query_string, k=3, score_threshold=-1.0
185+
)
186+
assert len(response_negative_threshold.chunks) > 0
187+
188+
189+
async def test_query_chunks_keyword_search_edge_cases(qdrant_vec_index, sample_chunks, sample_embeddings):
190+
"""Test keyword search edge cases."""
191+
await qdrant_vec_index.add_chunks(sample_chunks, sample_embeddings)
192+
193+
# Test with empty string
194+
response_empty = await qdrant_vec_index.query_keyword(query_string="", k=3, score_threshold=0.0)
195+
assert isinstance(response_empty, QueryChunksResponse)
196+
197+
# Test with very long query string
198+
long_query = "a" * 1000
199+
response_long = await qdrant_vec_index.query_keyword(query_string=long_query, k=3, score_threshold=0.0)
200+
assert isinstance(response_long, QueryChunksResponse)
201+
202+
# Test with special characters
203+
special_query = "!@#$%^&*()_+-=[]{}|;':\",./<>?"
204+
response_special = await qdrant_vec_index.query_keyword(query_string=special_query, k=3, score_threshold=0.0)
205+
assert isinstance(response_special, QueryChunksResponse)
206+
207+
208+
async def test_query_chunks_keyword_search_metadata_preservation(
209+
qdrant_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata
210+
):
211+
"""Test that keyword search preserves chunk metadata."""
212+
await qdrant_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
213+
214+
query_string = "Sentence 0"
215+
response = await qdrant_vec_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
216+
217+
assert len(response.chunks) > 0
218+
for chunk in response.chunks:
219+
# Check that metadata is preserved
220+
assert hasattr(chunk, "metadata") or hasattr(chunk, "chunk_metadata")
221+
if hasattr(chunk, "chunk_metadata") and chunk.chunk_metadata:
222+
assert chunk.chunk_metadata.document_id is not None
223+
assert chunk.chunk_metadata.chunk_id is not None
224+
assert chunk.chunk_metadata.source is not None

0 commit comments

Comments
 (0)