
Commit ae72a4a

add helper functions to detect bad responses
Files copied from #11. Specifically: 49f9a9d
1 parent cd6f2f8 commit ae72a4a

File tree

3 files changed: +486 -0 lines changed

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
"""
This module provides validation functions for evaluating LLM responses and determining if they should be replaced with Codex-generated alternatives.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional, Sequence, TypedDict, Union, cast

from cleanlab_codex.utils.errors import MissingDependencyError
from cleanlab_codex.utils.prompt import default_format_prompt

if TYPE_CHECKING:
    try:
        from cleanlab_studio.studio.trustworthy_language_model import TLM  # type: ignore
    except ImportError:
        from typing import Any, Dict, Protocol

        class _TLMProtocol(Protocol):
            def get_trustworthiness_score(
                self,
                prompt: Union[str, Sequence[str]],
                response: Union[str, Sequence[str]],
                **kwargs: Any,
            ) -> Dict[str, Any]: ...

            def prompt(
                self,
                prompt: Union[str, Sequence[str]],
                /,
                **kwargs: Any,
            ) -> Dict[str, Any]: ...

        TLM = _TLMProtocol


DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question."
DEFAULT_PARTIAL_RATIO_THRESHOLD = 70
DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5


class BadResponseDetectionConfig(TypedDict, total=False):
    """Configuration for bad response detection functions.
    See get_bad_response_config() for default values.

    Attributes:
        fallback_answer: Known unhelpful response to compare against
        partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity
        trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses
        format_prompt: Function to format (query, context) into a prompt string
        unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification
        tlm: TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks)
    """

    # Fallback check config
    fallback_answer: str
    partial_ratio_threshold: int

    # Untrustworthy check config
    trustworthiness_threshold: float
    format_prompt: Callable[[str, str], str]

    # Unhelpful check config
    unhelpfulness_confidence_threshold: Optional[float]

    # Shared config (for untrustworthiness and unhelpfulness checks)
    tlm: Optional[TLM]


def get_bad_response_config() -> BadResponseDetectionConfig:
    """Get the default configuration for bad response detection functions.

    Returns:
        BadResponseDetectionConfig: Default configuration for bad response detection functions
    """
    return {
        "fallback_answer": DEFAULT_FALLBACK_ANSWER,
        "partial_ratio_threshold": DEFAULT_PARTIAL_RATIO_THRESHOLD,
        "trustworthiness_threshold": DEFAULT_TRUSTWORTHINESS_THRESHOLD,
        "format_prompt": default_format_prompt,
        "unhelpfulness_confidence_threshold": None,
        "tlm": None,
    }

def is_bad_response(
    response: str,
    *,
    context: Optional[str] = None,
    query: Optional[str] = None,
    config: Optional[BadResponseDetectionConfig] = None,
) -> bool:
    """Run a series of checks to determine if a response is bad.

    If any check detects an issue (i.e. fails), the function returns True, indicating the response is bad.

    This function runs three possible validation checks:
    1. **Fallback check**: Detects if response is too similar to a known fallback answer.
    2. **Untrustworthy check**: Assesses response trustworthiness based on the given context and query.
    3. **Unhelpful check**: Predicts if the response adequately answers the query or not, in a useful way.

    Note:
        Each validation check runs conditionally based on whether the required arguments are provided.
        As soon as any validation check fails, the function returns True.

    Args:
        response: The response to check.
        context: Optional context/documents used for answering. Required for untrustworthy check.
        query: Optional user question. Required for untrustworthy and unhelpful checks.
        config: Optional, typed dictionary of configuration parameters. See BadResponseDetectionConfig for details.

    Returns:
        bool: True if any validation check fails, False if all pass.
    """
    default_cfg = get_bad_response_config()
    cfg: BadResponseDetectionConfig
    cfg = {**default_cfg, **(config or {})}

    validation_checks: list[Callable[[], bool]] = []

    # The fallback check requires no optional inputs, so it always runs
    validation_checks.append(
        lambda: is_fallback_response(
            response,
            cfg["fallback_answer"],
            threshold=cfg["partial_ratio_threshold"],
        )
    )

    can_run_untrustworthy_check = query is not None and context is not None and cfg["tlm"] is not None
    if can_run_untrustworthy_check:
        # The if condition guarantees these are not None
        validation_checks.append(
            lambda: is_untrustworthy_response(
                response=response,
                context=cast(str, context),
                query=cast(str, query),
                tlm=cfg["tlm"],
                trustworthiness_threshold=cfg["trustworthiness_threshold"],
                format_prompt=cfg["format_prompt"],
            )
        )

    can_run_unhelpful_check = query is not None and cfg["tlm"] is not None
    if can_run_unhelpful_check:
        validation_checks.append(
            lambda: is_unhelpful_response(
                response=response,
                query=cast(str, query),
                tlm=cfg["tlm"],
                trustworthiness_score_threshold=cast(float, cfg["unhelpfulness_confidence_threshold"]),
            )
        )

    return any(check() for check in validation_checks)

def is_fallback_response(
    response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD
) -> bool:
    """Check if a response is too similar to a known fallback answer.

    Uses fuzzy string matching to compare the response against a known fallback answer.
    Returns True if the response is similar enough to be considered unhelpful.

    Args:
        response: The response to check.
        fallback_answer: A known unhelpful/fallback response to compare against.
        threshold: Similarity threshold (0-100). Higher values require more similarity.
            Default 70 means responses that are 70% or more similar are considered bad.

    Returns:
        bool: True if the response is too similar to the fallback answer, False otherwise
    """
    try:
        from thefuzz import fuzz  # type: ignore
    except ImportError as e:
        raise MissingDependencyError(
            import_name=e.name or "thefuzz",
            package_url="https://github.com/seatgeek/thefuzz",
        ) from e

    partial_ratio: int = fuzz.partial_ratio(fallback_answer.lower(), response.lower())
    return bool(partial_ratio >= threshold)

def is_untrustworthy_response(
    response: str,
    context: str,
    query: str,
    tlm: TLM,
    trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD,
    format_prompt: Callable[[str, str], str] = default_format_prompt,
) -> bool:
    """Check if a response is untrustworthy.

    Uses TLM to evaluate whether a response is trustworthy given the context and query.
    Returns True if TLM's trustworthiness score falls below the threshold, indicating
    the response may be incorrect or unreliable.

    Args:
        response: The response to check from the assistant
        context: The context information available for answering the query
        query: The user's question or request
        tlm: The TLM model to use for evaluation
        trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses.
            Default 0.5 means responses with scores below 0.5 are considered untrustworthy.
        format_prompt: Function that takes (query, context) and returns a formatted prompt string.
            Users should provide their RAG app's own prompt formatting function here
            to match how their LLM is prompted.

    Returns:
        bool: True if the response is deemed untrustworthy by TLM, False otherwise
    """
    try:
        from cleanlab_studio import Studio  # type: ignore[import-untyped]  # noqa: F401
    except ImportError as e:
        raise MissingDependencyError(
            import_name=e.name or "cleanlab_studio",
            package_name="cleanlab-studio",
            package_url="https://github.com/cleanlab/cleanlab-studio",
        ) from e

    prompt = format_prompt(query, context)
    result = tlm.get_trustworthiness_score(prompt, response)
    score: float = result["trustworthiness_score"]
    return score < trustworthiness_threshold

def is_unhelpful_response(
    response: str,
    query: str,
    tlm: TLM,
    trustworthiness_score_threshold: Optional[float] = None,
) -> bool:
    """Check if a response is unhelpful by asking TLM to evaluate it.

    Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment.
    The evaluation considers both the TLM's binary classification of helpfulness and its
    confidence score. Returns True only if TLM classifies the response as unhelpful AND
    is sufficiently confident in that assessment (if a threshold is provided).

    Args:
        response: The response to check
        query: User query that will be used to evaluate if the response is helpful
        tlm: The TLM model to use for evaluation
        trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0).
            If provided and the response is marked as unhelpful,
            the confidence score must exceed this threshold for
            the response to be considered truly unhelpful.

    Returns:
        bool: True if TLM determines the response is unhelpful with sufficient confidence,
            False otherwise
    """
    try:
        from cleanlab_studio import Studio  # noqa: F401
    except ImportError as e:
        raise MissingDependencyError(
            import_name=e.name or "cleanlab_studio",
            package_name="cleanlab-studio",
            package_url="https://github.com/cleanlab/cleanlab-studio",
        ) from e

    # The question and expected "unhelpful" response are linked:
    # - When asking "is helpful?" -> "no" means unhelpful
    # - When asking "is unhelpful?" -> "yes" means unhelpful
    question = (
        "Does the AI Assistant Response seem unhelpful? "
        "Things that are not helpful include answers that:\n"
        "- Are not useful, incomplete, incorrect, uncertain or unclear.\n"
        "- Abstain or refuse to answer the question\n"
        "- Statements which are similar to 'I don't know', 'Sorry', or 'No information available'.\n"
        "- Leave the original question unresolved\n"
        "- Are irrelevant to the question\n"
        "Answer Yes/No only."
    )
    expected_unhelpful_response = "yes"

    prompt = (
        "Consider the following User Query and AI Assistant Response.\n\n"
        f"User Query: {query}\n\n"
        f"AI Assistant Response: {response}\n\n"
        f"{question}"
    )

    output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"])
    response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response
    is_trustworthy = trustworthiness_score_threshold is None or (
        output["trustworthiness_score"] > trustworthiness_score_threshold
    )
    return response_marked_unhelpful and is_trustworthy
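
For reference, here is a minimal usage sketch of the new helpers (illustrative only, not part of this commit). With the default config and no TLM supplied, is_bad_response runs only the fallback-similarity check, which needs the optional thefuzz dependency; the import path below is a guess, since this diff does not show the new module's filename. A sketch of wiring up TLM for the other two checks follows the prompt.py listing below.

# Illustrative usage sketch (not part of this commit).
# Assumes `thefuzz` is installed; the module path below is hypothetical.
from cleanlab_codex.response_validation import is_bad_response

# A near-verbatim echo of DEFAULT_FALLBACK_ANSWER fuzzy-matches at or above the
# default partial_ratio threshold of 70, so the fallback check fails.
print(is_bad_response("Based on the available information, I cannot provide a complete answer."))  # True

# A substantive answer scores well below the 70% similarity threshold.
print(is_bad_response("The capital of France is Paris."))  # False

# The threshold can be tightened per call via the config dict.
print(is_bad_response("I cannot answer this.", config={"partial_ratio_threshold": 95}))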

src/cleanlab_codex/utils/prompt.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
"""
Utility functions for RAG (Retrieval Augmented Generation) operations.
"""


def default_format_prompt(query: str, context: str) -> str:
    """Default function for formatting RAG prompts.

    Args:
        query: The user's question
        context: The context/documents to use for answering

    Returns:
        str: A formatted prompt combining the query and context
    """
    template = (
        "Using only information from the following Context, answer the following Query.\n\n"
        "Context:\n{context}\n\n"
        "Query: {query}"
    )
    return template.format(context=context, query=query)
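
As a quick illustration (not part of the commit), the default formatter simply interpolates the two fields into a fixed template:

from cleanlab_codex.utils.prompt import default_format_prompt

print(default_format_prompt("What is the capital of France?", "France's capital is Paris."))
# Using only information from the following Context, answer the following Query.
#
# Context:
# France's capital is Paris.
#
# Query: What is the capital of France?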

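To enable the untrustworthy and unhelpful checks, a TLM instance must be passed in through the config. The sketch below follows cleanlab-studio's documented Studio(...).TLM() pattern; the API key, threshold value, and import path for the new module are placeholder assumptions, not part of this commit.

# Illustrative wiring for the TLM-backed checks (not part of this commit).
# Requires the optional cleanlab-studio dependency; the module path is hypothetical.
from cleanlab_studio import Studio

from cleanlab_codex.response_validation import get_bad_response_config, is_bad_response

studio = Studio("<your Cleanlab Studio API key>")  # placeholder key
tlm = studio.TLM()

config = get_bad_response_config()
config["tlm"] = tlm
config["unhelpfulness_confidence_threshold"] = 0.8  # example value

# With query, context, and tlm all provided, all three checks run.
is_bad = is_bad_response(
    "I don't know.",
    query="What is the capital of France?",
    context="France's capital is Paris.",
    config=config,
)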