diff --git a/pyproject.toml b/pyproject.toml
index c7fa840..a2b53a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,8 @@ extra-dependencies = [
   "pytest",
   "llama-index-core",
   "smolagents",
+  "cleanlab-studio",
+  "thefuzz",
   "langchain-core",
 ]
 [tool.hatch.envs.types.scripts]
@@ -54,6 +56,8 @@ allow-direct-references = true
 extra-dependencies = [
   "llama-index-core",
   "smolagents; python_version >= '3.10'",
+  "cleanlab-studio",
+  "thefuzz",
   "langchain-core",
 ]
diff --git a/src/cleanlab_codex/__init__.py b/src/cleanlab_codex/__init__.py
index d1b8ef6..67bbc87 100644
--- a/src/cleanlab_codex/__init__.py
+++ b/src/cleanlab_codex/__init__.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: MIT
 from cleanlab_codex.client import Client
+from cleanlab_codex.codex_backup import CodexBackup
 from cleanlab_codex.codex_tool import CodexTool
 from cleanlab_codex.project import Project
 
-__all__ = ["Client", "CodexTool", "Project"]
+__all__ = ["Client", "CodexTool", "CodexBackup", "Project"]
diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py
new file mode 100644
index 0000000..70cc1f0
--- /dev/null
+++ b/src/cleanlab_codex/codex_backup.py
@@ -0,0 +1,142 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Optional, Protocol, cast
+
+from cleanlab_codex.response_validation import BadResponseDetectionConfig, is_bad_response
+
+if TYPE_CHECKING:
+    from cleanlab_studio.studio.trustworthy_language_model import TLM  # type: ignore
+
+    from cleanlab_codex.project import Project
+
+
+def handle_backup_default(codex_response: str, primary_system: Any) -> None:  # noqa: ARG001
+    """Default implementation is a no-op."""
+    return None
+
+
+class BackupHandler(Protocol):
+    """Protocol defining how to handle backup responses from Codex.
+
+    This protocol defines a callable interface for processing Codex responses that are
+    retrieved when the primary response system (e.g., a RAG system) fails to provide
+    an adequate answer. Implementations of this protocol can be used to:
+
+    - Update the primary system's context or knowledge base
+    - Log Codex responses for analysis
+    - Trigger system improvements or retraining
+    - Perform any other necessary side effects
+
+    Args:
+        codex_response (str): The response received from Codex
+        primary_system (Any): The instance of the primary RAG system that
+            generated the inadequate response. This allows the handler to
+            update or modify the primary system if needed.
+
+    Returns:
+        None: The handler performs side effects but doesn't return a value
+    """
+
+    def __call__(self, codex_response: str, primary_system: Any) -> None: ...
+
+
+class CodexBackup:
+    """A backup wrapper that connects to a Codex project to answer questions that
+    cannot be adequately answered by the existing agent.
+
+    Args:
+        project: The Codex project to use for backup responses
+        fallback_answer: The fallback answer to use if the primary system fails to provide an adequate response
+        backup_handler: A callback function that processes Codex's response and updates the primary RAG system. This handler is called whenever Codex provides a backup response after the primary system fails. By default, the backup handler is a no-op.
+        primary_system: The existing RAG system that needs to be backed up by Codex
+        tlm: The client for the Trustworthy Language Model, which evaluates the quality of responses from the primary system
+        is_bad_response_kwargs: Additional keyword arguments to pass to the is_bad_response function, for detecting inadequate responses from the primary system
+    """
+
+    DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question."
+
+    def __init__(
+        self,
+        *,
+        project: Project,
+        fallback_answer: str = DEFAULT_FALLBACK_ANSWER,
+        backup_handler: BackupHandler = handle_backup_default,
+        primary_system: Optional[Any] = None,
+        tlm: Optional[TLM] = None,
+        is_bad_response_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        self._project = project
+        self._fallback_answer = fallback_answer
+        self._backup_handler = backup_handler
+        self._primary_system: Optional[Any] = primary_system
+        self._tlm = tlm
+        self._is_bad_response_kwargs = is_bad_response_kwargs
+
+    @classmethod
+    def from_project(cls, project: Project, **kwargs: Any) -> CodexBackup:
+        return cls(project=project, **kwargs)
+
+    @property
+    def primary_system(self) -> Any:
+        if self._primary_system is None:
+            error_message = "Primary system not set. Please set one by assigning the `primary_system` property."
+            raise ValueError(error_message)
+        return self._primary_system
+
+    @primary_system.setter
+    def primary_system(self, primary_system: Any) -> None:
+        """Set the primary RAG system that will be used to generate responses."""
+        self._primary_system = primary_system
+
+    @property
+    def is_bad_response_kwargs(self) -> dict[str, Any]:
+        return self._is_bad_response_kwargs or {}
+
+    @is_bad_response_kwargs.setter
+    def is_bad_response_kwargs(self, is_bad_response_kwargs: dict[str, Any]) -> None:
+        self._is_bad_response_kwargs = is_bad_response_kwargs
+
+    def run(
+        self,
+        response: str,
+        query: str,
+        context: Optional[str] = None,
+    ) -> str:
+        """Check if a response is adequate and provide a backup from Codex if needed.
+
+        Args:
+            response: The response to evaluate
+            query: The original query that generated the response
+            context: Optional context used to generate the response
+
+        Returns:
+            str: Either the original response if adequate, or a backup response from Codex
+        """
+
+        _is_bad_response_kwargs = self.is_bad_response_kwargs
+        if not is_bad_response(
+            response,
+            query=query,
+            context=context,
+            config=cast(
+                BadResponseDetectionConfig,
+                {
+                    "tlm": self._tlm,
+                    "fallback_answer": self._fallback_answer,
+                    **_is_bad_response_kwargs,
+                },
+            ),
+        ):
+            return response
+
+        cache_result = self._project.query(query, fallback_answer=self._fallback_answer)[0]
+        if not cache_result:
+            return response
+
+        if self._primary_system is not None:
+            self._backup_handler(
+                codex_response=cache_result,
+                primary_system=self._primary_system,
+            )
+        return cache_result
diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py
new file mode 100644
index 0000000..f239397
--- /dev/null
+++ b/src/cleanlab_codex/response_validation.py
@@ -0,0 +1,291 @@
+"""
+This module provides validation functions for evaluating LLM responses and determining if they should be replaced with Codex-generated alternatives.
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Optional, Sequence, TypedDict, Union, cast + +from cleanlab_codex.utils.errors import MissingDependencyError +from cleanlab_codex.utils.prompt import default_format_prompt + +if TYPE_CHECKING: + try: + from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore + except ImportError: + from typing import Any, Dict, Protocol, Sequence + + class _TLMProtocol(Protocol): + def get_trustworthiness_score( + self, + prompt: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], + **kwargs: Any, + ) -> Dict[str, Any]: ... + + def prompt( + self, + prompt: Union[str, Sequence[str]], + /, + **kwargs: Any, + ) -> Dict[str, Any]: ... + + TLM = _TLMProtocol + + +DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." +DEFAULT_PARTIAL_RATIO_THRESHOLD = 70 +DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 + + +class BadResponseDetectionConfig(TypedDict, total=False): + """Configuration for bad response detection functions. + See get_bad_response_config() for default values. + + Attributes: + fallback_answer: Known unhelpful response to compare against + partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity + trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses + format_prompt: Function to format (query, context) into a prompt string + unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification + tlm: TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks) + """ + + # Fallback check config + fallback_answer: str + partial_ratio_threshold: int + + # Untrustworthy check config + trustworthiness_threshold: float + format_prompt: Callable[[str, str], str] + + # Unhelpful check config + unhelpfulness_confidence_threshold: Optional[float] + + # Shared config (for untrustworthiness and unhelpfulness checks) + tlm: Optional[TLM] + + +def get_bad_response_config() -> BadResponseDetectionConfig: + """Get the default configuration for bad response detection functions. + + Returns: + BadResponseDetectionConfig: Default configuration for bad response detection functions + """ + return { + "fallback_answer": DEFAULT_FALLBACK_ANSWER, + "partial_ratio_threshold": DEFAULT_PARTIAL_RATIO_THRESHOLD, + "trustworthiness_threshold": DEFAULT_TRUSTWORTHINESS_THRESHOLD, + "format_prompt": default_format_prompt, + "unhelpfulness_confidence_threshold": None, + "tlm": None, + } + + +def is_bad_response( + response: str, + *, + context: Optional[str] = None, + query: Optional[str] = None, + config: Optional[BadResponseDetectionConfig] = None, +) -> bool: + """Run a series of checks to determine if a response is bad. + + If any check detects an issue (i.e. fails), the function returns True, indicating the response is bad. + + This function runs three possible validation checks: + 1. **Fallback check**: Detects if response is too similar to a known fallback answer. + 2. **Untrustworthy check**: Assesses response trustworthiness based on the given context and query. + 3. **Unhelpful check**: Predicts if the response adequately answers the query or not, in a useful way. + + Note: + Each validation check runs conditionally based on whether the required arguments are provided. + As soon as any validation check fails, the function returns True. + + Args: + response: The response to check. 
+        context: Optional context/documents used for answering. Required for untrustworthy check.
+        query: Optional user question. Required for untrustworthy and unhelpful checks.
+        config: Optional, typed dictionary of configuration parameters. See `BadResponseDetectionConfig` for details.
+
+    Returns:
+        bool: True if any validation check fails, False if all pass.
+    """
+    default_cfg = get_bad_response_config()
+    cfg: BadResponseDetectionConfig
+    cfg = {**default_cfg, **(config or {})}
+
+    validation_checks: list[Callable[[], bool]] = []
+
+    # All required inputs are available for checking fallback responses
+    validation_checks.append(
+        lambda: is_fallback_response(
+            response,
+            cfg["fallback_answer"],
+            threshold=cfg["partial_ratio_threshold"],
+        )
+    )
+
+    can_run_untrustworthy_check = query is not None and context is not None and cfg["tlm"] is not None
+    if can_run_untrustworthy_check:
+        # The if condition guarantees these are not None
+        validation_checks.append(
+            lambda: is_untrustworthy_response(
+                response=response,
+                context=cast(str, context),
+                query=cast(str, query),
+                tlm=cfg["tlm"],
+                trustworthiness_threshold=cfg["trustworthiness_threshold"],
+                format_prompt=cfg["format_prompt"],
+            )
+        )
+
+    can_run_unhelpful_check = query is not None and cfg["tlm"] is not None
+    if can_run_unhelpful_check:
+        validation_checks.append(
+            lambda: is_unhelpful_response(
+                response=response,
+                query=cast(str, query),
+                tlm=cfg["tlm"],
+                trustworthiness_score_threshold=cast(float, cfg["unhelpfulness_confidence_threshold"]),
+            )
+        )
+
+    return any(check() for check in validation_checks)
+
+
+def is_fallback_response(
+    response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD
+) -> bool:
+    """Check if a response is too similar to a known fallback answer.
+
+    Uses fuzzy string matching to compare the response against a known fallback answer.
+    Returns True if the response is similar enough to be considered unhelpful.
+
+    Args:
+        response: The response to check.
+        fallback_answer: A known unhelpful/fallback response to compare against.
+        threshold: Similarity threshold (0-100). Higher values require more similarity.
+            Default 70 means responses that are 70% or more similar are considered bad.
+
+    Returns:
+        bool: True if the response is too similar to the fallback answer, False otherwise
+    """
+    try:
+        from thefuzz import fuzz  # type: ignore
+    except ImportError as e:
+        raise MissingDependencyError(
+            import_name=e.name or "thefuzz",
+            package_url="https://github.com/seatgeek/thefuzz",
+        ) from e
+
+    partial_ratio: int = fuzz.partial_ratio(fallback_answer.lower(), response.lower())
+    return bool(partial_ratio >= threshold)
+
+
+def is_untrustworthy_response(
+    response: str,
+    context: str,
+    query: str,
+    tlm: TLM,
+    trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD,
+    format_prompt: Callable[[str, str], str] = default_format_prompt,
+) -> bool:
+    """Check if a response is untrustworthy.
+
+    Uses TLM to evaluate whether a response is trustworthy given the context and query.
+    Returns True if TLM's trustworthiness score falls below the threshold, indicating
+    the response may be incorrect or unreliable.
+
+    Args:
+        response: The response to check from the assistant
+        context: The context information available for answering the query
+        query: The user's question or request
+        tlm: The TLM model to use for evaluation
+        trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses.
+            Default 0.5, meaning responses with scores less than 0.5 are considered untrustworthy.
+        format_prompt: Function that takes (query, context) and returns a formatted prompt string.
+            Users should provide their RAG app's own prompt formatting function here
+            to match how their LLM is prompted.
+
+    Returns:
+        bool: True if the response is deemed untrustworthy by TLM, False otherwise
+    """
+    try:
+        from cleanlab_studio import Studio  # type: ignore[import-untyped]  # noqa: F401
+    except ImportError as e:
+        raise MissingDependencyError(
+            import_name=e.name or "cleanlab_studio",
+            package_name="cleanlab-studio",
+            package_url="https://github.com/cleanlab/cleanlab-studio",
+        ) from e
+
+    prompt = format_prompt(query, context)
+    result = tlm.get_trustworthiness_score(prompt, response)
+    score: float = result["trustworthiness_score"]
+    return score < trustworthiness_threshold
+
+
+def is_unhelpful_response(
+    response: str,
+    query: str,
+    tlm: TLM,
+    trustworthiness_score_threshold: Optional[float] = None,
+) -> bool:
+    """Check if a response is unhelpful by asking TLM to evaluate it.
+
+    Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment.
+    The evaluation considers both the TLM's binary classification of helpfulness and its
+    confidence score. Returns True only if TLM classifies the response as unhelpful AND
+    is sufficiently confident in that assessment (if a threshold is provided).
+
+    Args:
+        response: The response to check
+        query: User query that will be used to evaluate if the response is helpful
+        tlm: The TLM model to use for evaluation
+        trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0).
+            If provided and the response is marked as unhelpful,
+            the confidence score must exceed this threshold for
+            the response to be considered truly unhelpful.
+
+    Returns:
+        bool: True if TLM determines the response is unhelpful with sufficient confidence,
+        False otherwise
+    """
+    try:
+        from cleanlab_studio import Studio  # noqa: F401
+    except ImportError as e:
+        raise MissingDependencyError(
+            import_name=e.name or "cleanlab_studio",
+            package_name="cleanlab-studio",
+            package_url="https://github.com/cleanlab/cleanlab-studio",
+        ) from e
+
+    # The question and expected "unhelpful" response are linked:
+    # - When asking "is helpful?" -> "no" means unhelpful
+    # - When asking "is unhelpful?" -> "yes" means unhelpful
+    question = (
+        "Does the AI Assistant Response seem unhelpful? "
+        "Things that are not helpful include answers that:\n"
+        "- Are not useful, incomplete, incorrect, uncertain or unclear.\n"
+        "- Abstain or refuse to answer the question\n"
+        "- Statements which are similar to 'I don't know', 'Sorry', or 'No information available'.\n"
+        "- Leave the original question unresolved\n"
+        "- Are irrelevant to the question\n"
+        "Answer Yes/No only."
+    )
+    expected_unhelpful_response = "yes"
+
+    prompt = (
+        "Consider the following User Query and AI Assistant Response.\n\n"
+        f"User Query: {query}\n\n"
+        f"AI Assistant Response: {response}\n\n"
+        f"{question}"
+    )
+
+    output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"])
+    response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response
+    is_trustworthy = trustworthiness_score_threshold is None or (
+        output["trustworthiness_score"] > trustworthiness_score_threshold
+    )
+    return response_marked_unhelpful and is_trustworthy
diff --git a/src/cleanlab_codex/utils/__init__.py b/src/cleanlab_codex/utils/__init__.py
index 5eefc17..86d8c51 100644
--- a/src/cleanlab_codex/utils/__init__.py
+++ b/src/cleanlab_codex/utils/__init__.py
@@ -5,6 +5,7 @@
 from cleanlab_codex.utils.openai import Function as OpenAIFunction
 from cleanlab_codex.utils.openai import Tool as OpenAITool
 from cleanlab_codex.utils.openai import format_as_openai_tool
+from cleanlab_codex.utils.prompt import default_format_prompt
 
 __all__ = [
     "FunctionParameters",
@@ -14,4 +15,5 @@
     "AWSToolSpec",
     "format_as_openai_tool",
     "format_as_aws_converse_tool",
+    "default_format_prompt",
 ]
diff --git a/src/cleanlab_codex/utils/prompt.py b/src/cleanlab_codex/utils/prompt.py
new file mode 100644
index 0000000..c04fc71
--- /dev/null
+++ b/src/cleanlab_codex/utils/prompt.py
@@ -0,0 +1,21 @@
+"""
+Utility functions for RAG (Retrieval Augmented Generation) operations.
+"""
+
+
+def default_format_prompt(query: str, context: str) -> str:
+    """Default function for formatting RAG prompts.
+
+    Args:
+        query: The user's question
+        context: The context/documents to use for answering
+
+    Returns:
+        str: A formatted prompt combining the query and context
+    """
+    template = (
+        "Using only information from the following Context, answer the following Query.\n\n"
+        "Context:\n{context}\n\n"
+        "Query: {query}"
+    )
+    return template.format(context=context, query=query)
diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py
new file mode 100644
index 0000000..e8e9001
--- /dev/null
+++ b/tests/test_codex_backup.py
@@ -0,0 +1,71 @@
+from unittest.mock import MagicMock
+
+from cleanlab_codex.codex_backup import CodexBackup
+
+MOCK_BACKUP_RESPONSE = "This is a test response"
+FALLBACK_MESSAGE = "Based on the available information, I cannot provide a complete answer to this question."
+TEST_MESSAGE = "Hello, world!"
+
+
+def test_codex_backup() -> None:
+    # Create a mock project directly
+    mock_project = MagicMock()
+    mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,)
+
+    class MockApp:
+        def chat(self, user_message: str) -> str:
+            # Just echo the user message
+            return user_message
+
+    app = MockApp()
+    codex_backup = CodexBackup.from_project(mock_project)
+
+    # Echo works well
+    query = TEST_MESSAGE
+    response = app.chat(query)
+    assert response == query
+
+    # Backup works well for fallback responses
+    query = FALLBACK_MESSAGE
+    response = app.chat(query)
+    assert response == query
+    response = codex_backup.run(response, query=query)
+    assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}"
+
+
+def test_backup_handler() -> None:
+    mock_project = MagicMock()
+    mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,)
+
+    mock_handler = MagicMock()
+    mock_handler.return_value = None
+
+    class MockApp:
+        def chat(self, user_message: str) -> str:
+            # Just echo the user message
+            return user_message
+
+    app = MockApp()
+    codex_backup = CodexBackup.from_project(mock_project, primary_system=app, backup_handler=mock_handler)
+
+    query = TEST_MESSAGE
+    response = app.chat(query)
+    assert response == query
+
+    response = codex_backup.run(response, query=query)
+    assert response == query, f"Response was {response}"
+
+    # Handler should not be called for good responses
+    assert mock_handler.call_count == 0
+
+    query = FALLBACK_MESSAGE
+    response = app.chat(query)
+    assert response == query
+    response = codex_backup.run(response, query=query)
+    assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}"
+
+    # Handler should be called for bad responses
+    assert mock_handler.call_count == 1
+    # The MockApp instance is passed to the handler as `primary_system`, so it has the
+    # necessary context to handle the new response
+    assert mock_handler.call_args.kwargs["primary_system"] == app
diff --git a/tests/test_response_validation.py b/tests/test_response_validation.py
new file mode 100644
index 0000000..d10e661
--- /dev/null
+++ b/tests/test_response_validation.py
@@ -0,0 +1,174 @@
+"""Unit tests for validation module functions."""
+
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from cleanlab_codex.response_validation import (
+    is_bad_response,
+    is_fallback_response,
+    is_unhelpful_response,
+    is_untrustworthy_response,
+)
+
+# Mock responses for testing
+GOOD_RESPONSE = "This is a helpful and specific response that answers the question completely."
+BAD_RESPONSE = "Based on the available information, I cannot provide a complete answer."
+QUERY = "What is the capital of France?"
+CONTEXT = "Paris is the capital and largest city of France."
+
+
+@pytest.fixture
+def mock_tlm() -> Mock:
+    """Create a mock TLM instance."""
+    mock = Mock()
+    # Configure default return values
+    mock.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8}
+    mock.prompt.return_value = {"response": "No", "trustworthiness_score": 0.9}
+    return mock
+
+
+@pytest.mark.parametrize(
+    ("response", "threshold", "fallback_answer", "expected"),
+    [
+        # Test threshold variations
+        (GOOD_RESPONSE, 30, None, True),
+        (GOOD_RESPONSE, 55, None, False),
+        # Test default behavior (BAD_RESPONSE should be flagged)
+        (BAD_RESPONSE, None, None, True),
+        # Test default behavior for different response (GOOD_RESPONSE should not be flagged)
+        (GOOD_RESPONSE, None, None, False),
+        # Test custom fallback answer
+        (GOOD_RESPONSE, 80, "This is an unhelpful response", False),
+    ],
+)
+def test_is_fallback_response(
+    response: str,
+    threshold: float | None,
+    fallback_answer: str | None,
+    *,
+    expected: bool,
+) -> None:
+    """Test fallback response detection."""
+    kwargs: dict[str, float | str] = {}
+    if threshold is not None:
+        kwargs["threshold"] = threshold
+    if fallback_answer is not None:
+        kwargs["fallback_answer"] = fallback_answer
+
+    assert is_fallback_response(response, **kwargs) is expected  # type: ignore
+
+
+def test_is_untrustworthy_response(mock_tlm: Mock) -> None:
+    """Test untrustworthy response detection."""
+    # Test trustworthy response
+    mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8}
+    assert is_untrustworthy_response(GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is False
+
+    # Test untrustworthy response
+    mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.3}
+    assert is_untrustworthy_response(BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is True
+
+
+@pytest.mark.parametrize(
+    ("response", "tlm_response", "tlm_score", "threshold", "expected"),
+    [
+        # Test helpful response
+        (GOOD_RESPONSE, "No", 0.9, 0.5, False),
+        # Test unhelpful response
+        (BAD_RESPONSE, "Yes", 0.9, 0.5, True),
+        # Test unhelpful response but low trustworthiness score
+        (BAD_RESPONSE, "Yes", 0.3, 0.5, False),
+        # Test without threshold - Yes prediction
+        (BAD_RESPONSE, "Yes", 0.3, None, True),
+        (GOOD_RESPONSE, "Yes", 0.3, None, True),
+        # Test without threshold - No prediction
+        (BAD_RESPONSE, "No", 0.3, None, False),
+        (GOOD_RESPONSE, "No", 0.3, None, False),
+    ],
+)
+def test_is_unhelpful_response(
+    mock_tlm: Mock,
+    response: str,
+    tlm_response: str,
+    tlm_score: float,
+    threshold: float | None,
+    *,
+    expected: bool,
+) -> None:
+    """Test unhelpful response detection."""
+    mock_tlm.prompt.return_value = {"response": tlm_response, "trustworthiness_score": tlm_score}
+    assert is_unhelpful_response(response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold) is expected
+
+
+@pytest.mark.parametrize(
+    ("response", "trustworthiness_score", "prompt_response", "prompt_score", "expected"),
+    [
+        # Good response passes all checks
+        (GOOD_RESPONSE, 0.8, "No", 0.9, False),
+        # Bad response fails at least one check
+        (BAD_RESPONSE, 0.3, "Yes", 0.9, True),
+    ],
+)
+def test_is_bad_response(
+    mock_tlm: Mock,
+    response: str,
+    trustworthiness_score: float,
+    prompt_response: str,
+    prompt_score: float,
+    *,
+    expected: bool,
+) -> None:
+    """Test the main is_bad_response function."""
+    mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": trustworthiness_score}
+    mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score}
+
+    assert (
+        is_bad_response(
+            response,
+            context=CONTEXT,
+            query=QUERY,
+            config={"tlm": mock_tlm},
+        )
+        is expected
+    )
+
+
+@pytest.mark.parametrize(
+    ("response", "fuzz_ratio", "prompt_response", "prompt_score", "query", "tlm", "expected"),
+    [
+        # Test with only fallback check (no context/query/tlm)
+        (BAD_RESPONSE, 90, None, None, None, None, True),
+        # Test with fallback and unhelpful checks (no context)
+        (GOOD_RESPONSE, 30, "No", 0.9, QUERY, "mock_tlm", False),
+    ],
+)
+def test_is_bad_response_partial_inputs(
+    mock_tlm: Mock,
+    response: str,
+    fuzz_ratio: int,
+    prompt_response: str,
+    prompt_score: float,
+    query: str,
+    tlm: Mock,
+    *,
+    expected: bool,
+) -> None:
+    """Test is_bad_response with partial inputs (some checks disabled)."""
+    mock_fuzz = Mock()
+    mock_fuzz.partial_ratio.return_value = fuzz_ratio
+    with patch.dict("sys.modules", {"thefuzz": Mock(fuzz=mock_fuzz)}):
+        if prompt_response is not None:
+            mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score}
+            tlm = mock_tlm
+
+        assert (
+            is_bad_response(
+                response,
+                query=query,
+                config={"tlm": tlm},
+            )
+            is expected
+        )