From 8fd7e7aa34e1d330705ed55ada6dd910b26642ed Mon Sep 17 00:00:00 2001
From: MagellaX
Date: Wed, 13 Aug 2025 23:31:55 +0530
Subject: [PATCH] feat(lib): add high-level retrieval pipeline with RRF + optional MMR; tests and example

---
 examples/retrieval_basic.py      |  24 +++
 src/zeroentropy/lib/retrieval.py | 278 +++++++++++++++++++++++++++++++
 tests/test_retrieval.py          | 129 ++++++++++++++
 3 files changed, 431 insertions(+)
 create mode 100644 examples/retrieval_basic.py
 create mode 100644 src/zeroentropy/lib/retrieval.py
 create mode 100644 tests/test_retrieval.py

diff --git a/examples/retrieval_basic.py b/examples/retrieval_basic.py
new file mode 100644
index 0000000..b04050d
--- /dev/null
+++ b/examples/retrieval_basic.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env -S rye run python
+import os
+from zeroentropy import ZeroEntropy
+from zeroentropy.lib.retrieval import search_documents
+
+
+def main() -> None:
+    client = ZeroEntropy(api_key=os.environ.get("ZEROENTROPY_API_KEY"))
+    results = search_documents(
+        client,
+        query="how to reset password",
+        collections=["product_docs", "engineering_notes"],
+        k=5,
+        include_metadata=True,
+        diversify=True,
+    )
+    for r in results:
+        print(f"{r.fused_score:.4f}\t{r.best_collection}\t{r.path}")
+
+
+if __name__ == "__main__":
+    main()
+
+
diff --git a/src/zeroentropy/lib/retrieval.py b/src/zeroentropy/lib/retrieval.py
new file mode 100644
index 0000000..86913c9
--- /dev/null
+++ b/src/zeroentropy/lib/retrieval.py
@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing_extensions import Literal
+
+from .. import ZeroEntropy, AsyncZeroEntropy
+from .._models import BaseModel
+from ..types.query_top_documents_response import QueryTopDocumentsResponse
+
+
+class SourceRank(BaseModel):
+    collection_name: str
+    rank: int
+    score: float
+
+
+class UnifiedDocumentResult(BaseModel):
+    path: str
+    fused_score: float
+    best_collection: str
+    file_url: Optional[str] = None
+    metadata: Optional[Dict[str, Union[str, List[str]]]] = None
+    source_ranks: List[SourceRank]
+
+
+def _tokenize_for_similarity(text: str) -> List[str]:
+    lowered = text.lower()
+    tokens: List[str] = []
+    token_chars: List[str] = []
+    for ch in lowered:
+        if ch.isalnum():
+            token_chars.append(ch)
+        else:
+            if token_chars:
+                tokens.append("".join(token_chars))
+                token_chars = []
+    if token_chars:
+        tokens.append("".join(token_chars))
+    # de-duplicate tokens while preserving order
+    seen: set[str] = set()
+    unique: List[str] = []
+    for t in tokens:
+        if t not in seen:
+            unique.append(t)
+            seen.add(t)
+    return unique
+
+
+def _jaccard(a_tokens: Sequence[str], b_tokens: Sequence[str]) -> float:
+    if not a_tokens and not b_tokens:
+        return 0.0
+    a = set(a_tokens)
+    b = set(b_tokens)
+    inter = len(a.intersection(b))
+    union = len(a.union(b))
+    if union == 0:
+        return 0.0
+    return inter / union
+
+
+def _mmr(
+    candidates: List[UnifiedDocumentResult],
+    *,
+    k: int,
+    lambda_: float,
+) -> List[UnifiedDocumentResult]:
+    if k <= 0 or not candidates:
+        return []
+    k = min(k, len(candidates))
+
+    # precompute normalized scores and tokenization for a simple similarity based on path
+    max_score = max(c.fused_score for c in candidates) or 1.0
+    norm_scores = {c.path: (c.fused_score / max_score) for c in candidates}
+    tokens = {c.path: _tokenize_for_similarity(c.path) for c in candidates}
+
+    selected: List[UnifiedDocumentResult] = []
+    remaining = candidates.copy()
+
+    while remaining and len(selected) < k:
+        best: Tuple[float, UnifiedDocumentResult] | None = None
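+        # Greedy MMR step: each remaining candidate is scored as
+        #   lambda_ * normalized_fused_score - (1 - lambda_) * max_jaccard_to_selected
+        # so a higher lambda_ favours raw relevance and a lower lambda_ favours
+        # path-level diversity relative to the documents already picked.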
+        for cand in remaining:
+            # diversity penalty: similarity to closest already selected item
+            if selected:
+                max_sim = max(
+                    _jaccard(tokens[cand.path], tokens[sel.path]) for sel in selected
+                )
+            else:
+                max_sim = 0.0
+
+            score = lambda_ * norm_scores[cand.path] - (1.0 - lambda_) * max_sim
+            if best is None or score > best[0]:
+                best = (score, cand)
+
+        assert best is not None
+        _, chosen = best
+        selected.append(chosen)
+        remaining = [c for c in remaining if c.path != chosen.path]
+
+    return selected
+
+
+def _rrf(scores_by_collection: Dict[str, List[Tuple[str, float, Optional[str], Optional[Dict[str, Union[str, List[str]]]]]]]) -> List[UnifiedDocumentResult]:
+    # Reciprocal Rank Fusion across collections.
+    K = 60  # standard constant stabilizer
+
+    # aggregate by path
+    by_path: Dict[str, Dict[str, Any]] = {}
+    for collection, ranked in scores_by_collection.items():
+        for idx, (path, score, file_url, metadata) in enumerate(ranked, start=1):
+            entry = by_path.setdefault(
+                path,
+                {
+                    "fused": 0.0,
+                    "best_collection": collection,
+                    "best_rank": idx,
+                    "file_url": file_url,
+                    "metadata": metadata,
+                    "sources": [],
+                },
+            )
+            entry["fused"] += 1.0 / (K + idx)
+            if idx < entry["best_rank"]:
+                entry["best_rank"] = idx
+                entry["best_collection"] = collection
+                # prefer latest non-null file_url/metadata for best source
+                if file_url is not None:
+                    entry["file_url"] = file_url
+                if metadata is not None:
+                    entry["metadata"] = metadata
+            entry["sources"].append(SourceRank(collection_name=collection, rank=idx, score=score))
+
+    fused: List[UnifiedDocumentResult] = []
+    for path, data in by_path.items():
+        fused.append(
+            UnifiedDocumentResult(
+                path=path,
+                fused_score=data["fused"],
+                best_collection=data["best_collection"],
+                file_url=data.get("file_url"),
+                metadata=data.get("metadata"),
+                source_ranks=data["sources"],
+            )
+        )
+
+    fused.sort(key=lambda r: r.fused_score, reverse=True)
+    return fused
+
+
+def search_documents(
+    client: ZeroEntropy,
+    *,
+    query: str,
+    collections: Sequence[str],
+    k: int = 20,
+    filter: Optional[Mapping[str, object]] = None,
+    include_metadata: bool = False,
+    latency_mode: Literal["low", "high"] = "low",
+    strategy: Literal["rrf", "concat"] = "rrf",
+    diversify: bool = True,
+    mmr_lambda: float = 0.5,
+) -> List[UnifiedDocumentResult]:
+    """High-level retrieval that queries multiple collections and fuses results.
+
+    This function performs multi-collection search and merges results using
+    Reciprocal Rank Fusion (RRF). Optionally applies a simple MMR diversification
+    step based on path similarity to improve diversity.
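+
+    The fused score follows the usual RRF form: each document accumulates
+    1 / (60 + rank) for every collection in which it appears, so a document
+    ranked well in several collections rises above one that is top ranked in
+    only a single collection. With strategy="concat" the raw per-collection
+    scores are concatenated and sorted instead, which is only meaningful when
+    those scores are comparable across collections.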
+ """ + + if not collections: + return [] + + scores_by_collection: Dict[str, List[Tuple[str, float, Optional[str], Optional[Dict[str, Union[str, List[str]]]]]]] = {} + for collection_name in collections: + resp: QueryTopDocumentsResponse = client.queries.top_documents( + collection_name=collection_name, + k=k, + query=query, + filter=dict(filter) if filter is not None else None, + include_metadata=include_metadata, + latency_mode=latency_mode, + ) + + ranked: List[Tuple[str, float, Optional[str], Optional[Dict[str, Union[str, List[str]]]]]] = [] + for r in resp.results: + ranked.append((r.path, r.score, getattr(r, "file_url", None), getattr(r, "metadata", None))) + scores_by_collection[collection_name] = ranked + + if strategy == "concat": + # simply concatenate, keeping per-collection order + concatenated: List[UnifiedDocumentResult] = [] + for c in collections: + ranked = scores_by_collection.get(c, []) + for idx, (path, score, file_url, metadata) in enumerate(ranked, start=1): + concatenated.append( + UnifiedDocumentResult( + path=path, + fused_score=score, + best_collection=c, + file_url=file_url, + metadata=metadata, + source_ranks=[SourceRank(collection_name=c, rank=idx, score=score)], + ) + ) + concatenated.sort(key=lambda r: r.fused_score, reverse=True) + results = concatenated[:k] + else: + results = _rrf(scores_by_collection) + results = results[: max(k * 2, k)] # keep a bit more for MMR to work with + + if diversify and results: + results = _mmr(results, k=k, lambda_=mmr_lambda) + else: + results = results[:k] + + return results + + +async def asearch_documents( + client: AsyncZeroEntropy, + *, + query: str, + collections: Sequence[str], + k: int = 20, + filter: Optional[Mapping[str, object]] = None, + include_metadata: bool = False, + latency_mode: Literal["low", "high"] = "low", + strategy: Literal["rrf", "concat"] = "rrf", + diversify: bool = True, + mmr_lambda: float = 0.5, +) -> List[UnifiedDocumentResult]: + if not collections: + return [] + + scores_by_collection: Dict[str, List[Tuple[str, float, Optional[str], Optional[Dict[str, Union[str, List[str]]]]]]] = {} + for collection_name in collections: + resp: QueryTopDocumentsResponse = await client.queries.top_documents( + collection_name=collection_name, + k=k, + query=query, + filter=dict(filter) if filter is not None else None, + include_metadata=include_metadata, + latency_mode=latency_mode, + ) + + ranked: List[Tuple[str, float, Optional[str], Optional[Dict[str, Union[str, List[str]]]]]] = [] + for r in resp.results: + ranked.append((r.path, r.score, getattr(r, "file_url", None), getattr(r, "metadata", None))) + scores_by_collection[collection_name] = ranked + + if strategy == "concat": + concatenated: List[UnifiedDocumentResult] = [] + for c in collections: + ranked = scores_by_collection.get(c, []) + for idx, (path, score, file_url, metadata) in enumerate(ranked, start=1): + concatenated.append( + UnifiedDocumentResult( + path=path, + fused_score=score, + best_collection=c, + file_url=file_url, + metadata=metadata, + source_ranks=[SourceRank(collection_name=c, rank=idx, score=score)], + ) + ) + concatenated.sort(key=lambda r: r.fused_score, reverse=True) + results = concatenated[:k] + else: + results = _rrf(scores_by_collection) + results = results[: max(k * 2, k)] + + if diversify and results: + results = _mmr(results, k=k, lambda_=mmr_lambda) + else: + results = results[:k] + + return results + + diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py new file mode 100644 index 0000000..304407f --- 
diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py
new file mode 100644
index 0000000..304407f
--- /dev/null
+++ b/tests/test_retrieval.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import json
+from typing import Any, Dict, List
+
+import httpx
+import pytest
+from respx import MockRouter  # type: ignore[import-not-found]
+
+from zeroentropy import ZeroEntropy, AsyncZeroEntropy
+from zeroentropy.lib.retrieval import (
+    UnifiedDocumentResult,
+    asearch_documents,
+    search_documents,
+)
+
+
+base_url = "http://127.0.0.1:4010"
+
+
+def _mock_top_documents_for_collections(respx_mock: Any, mapping: Dict[str, List[Dict[str, Any]]]) -> None:
+    """Install a handler that returns different results per collection_name"""
+
+    def handler(request: httpx.Request) -> httpx.Response:
+        data: Dict[str, Any] = json.loads(request.content.decode("utf-8")) if request.content else {}
+        collection_name = str(data.get("collection_name", ""))
+        results = mapping.get(collection_name, [])
+        return httpx.Response(200, json={"results": results})
+
+    respx_mock.post("/queries/top-documents").mock(side_effect=handler)
+
+
+class TestRetrieval:
+    parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+    @parametrize
+    @pytest.mark.respx(base_url=base_url)
+    def test_rrf_multi_collection(self, client: ZeroEntropy, respx_mock: Any) -> None:
+        # Collection A ranks: doc1 (1), doc2 (2)
+        # Collection B ranks: doc2 (1), doc3 (2)
+        _mock_top_documents_for_collections(
+            respx_mock,
+            {
+                "A": [
+                    {"path": "doc1", "score": 0.9, "file_url": "u1", "metadata": {"a": "1"}},
+                    {"path": "doc2", "score": 0.5, "file_url": "u2", "metadata": {"a": "2"}},
+                ],
+                "B": [
+                    {"path": "doc2", "score": 0.95, "file_url": "u3", "metadata": {"b": "1"}},
+                    {"path": "doc3", "score": 0.4, "file_url": "u4", "metadata": {"b": "2"}},
+                ],
+            },
+        )
+
+        results = search_documents(
+            client,
+            query="q",
+            collections=["A", "B"],
+            k=3,
+            include_metadata=True,
+            diversify=False,
+        )
+
+        assert [r.path for r in results] == ["doc2", "doc1", "doc3"]
+        # best_collection for doc2 should be B (rank 1 there)
+        assert results[0].best_collection == "B"
+        assert isinstance(results[0], UnifiedDocumentResult)
+
+    @parametrize
+    @pytest.mark.respx(base_url=base_url)
+    def test_concat_strategy(self, client: ZeroEntropy, respx_mock: Any) -> None:
+        _mock_top_documents_for_collections(
+            respx_mock,
+            {
+                "TeamX": [
+                    {"path": "x1", "score": 0.6, "file_url": "ux1", "metadata": {}},
+                ],
+                "TeamY": [
+                    {"path": "y1", "score": 0.9, "file_url": "uy1", "metadata": {}},
+                ],
+            },
+        )
+
+        results = search_documents(
+            client,
+            query="hello",
+            collections=["TeamX", "TeamY"],
+            k=2,
+            strategy="concat",
+            diversify=False,
+        )
+
+        # Sorted by score descending after concatenation
+        assert [r.path for r in results] == ["y1", "x1"]
+
+
+class TestAsyncRetrieval:
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
+
+    @parametrize
+    @pytest.mark.asyncio
+    @pytest.mark.respx(base_url=base_url)
+    async def test_async_rrf_multi_collection(self, async_client: AsyncZeroEntropy, respx_mock: Any) -> None:
+        _mock_top_documents_for_collections(
+            respx_mock,
+            {
+                "C1": [
+                    {"path": "a", "score": 0.8, "file_url": "ua", "metadata": {}},
+                    {"path": "b", "score": 0.6, "file_url": "ub", "metadata": {}},
+                ],
+                "C2": [
+                    {"path": "b", "score": 0.95, "file_url": "ub2", "metadata": {}},
+                ],
+            },
+        )
+
+        results = await asearch_documents(
+            async_client,
+            query="q",
+            collections=["C1", "C2"],
+            k=2,
+            diversify=True,
+        )
+
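+        # "b" appears in both collections (1/(60+2) from C1 plus 1/(60+1) from
+        # C2) while "a" only scores 1/(60+1) from C1, so RRF puts "b" first and
+        # MMR keeps both results for k=2.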
+        assert [r.path for r in results] == ["b", "a"]
+
+
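
The patch's example only exercises the synchronous client. A minimal async
sketch of the same pipeline (not part of the patch; it assumes AsyncZeroEntropy
accepts the same api_key keyword argument as the synchronous client used in
examples/retrieval_basic.py):

    import asyncio
    import os

    from zeroentropy import AsyncZeroEntropy
    from zeroentropy.lib.retrieval import asearch_documents


    async def main() -> None:
        client = AsyncZeroEntropy(api_key=os.environ.get("ZEROENTROPY_API_KEY"))
        results = await asearch_documents(
            client,
            query="how to reset password",
            collections=["product_docs", "engineering_notes"],
            k=5,
            diversify=True,
        )
        for r in results:
            print(f"{r.fused_score:.4f}\t{r.best_collection}\t{r.path}")


    asyncio.run(main())

asearch_documents mirrors search_documents parameter for parameter, so the
strategy, diversify, and mmr_lambda knobs behave the same in both entry points.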