Skip to content

Commit b3fa2c0

Browse files
Merge pull request #228 from oracle-samples/218-testbed
218 testbed
2 parents bc01221 + febc893 commit b3fa2c0

File tree

7 files changed

+97
-7
lines changed

7 files changed

+97
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,4 @@ spring_ai/create_user.sql
6666
spring_ai/drop.sql
6767
src/client/spring_ai/target/classes/*
6868
api_server_key
69+
.env

src/.streamlit/config.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
disableWidgetStateDuplicationWarning = true
55

66
[theme]
7-
font = "Source Sans Pro"
7+
font = "sans-serif-pro"
8+
headingFont = "sans-serif"
9+
codeFont = "monospace"
10+
11+
[theme.sidebar]
12+
font = "sans-serif"
13+
headingFont = "sans-serif"
14+
codeFont = "monospace"
815

916
[browser]
1017
gatherUsageStats = false

src/client/content/testbed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def create_gauge(value):
127127
st.dataframe(full_report, hide_index=True)
128128

129129
# Download Button
130-
download_file("Download Report", report["html_report"], "evaluation_report.html", "text/html")
130+
# download_file("Download Report", report["html_report"], "evaluation_report.html", "text/html") #CDB
131131

132132

133133
@st.cache_data

src/launch_server.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
Copyright (c) 2024, 2025, Oracle and/or its affiliates.
33
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
44
"""
5-
# spell-checker:ignore fastapi, laddr, checkpointer, langgraph, litellm, noauth, apiserver, configfile, selectai
5+
# spell-checker:ignore fastapi laddr checkpointer langgraph litellm
6+
# spell-checker:ignore noauth apiserver configfile selectai giskard ollama llms
67
# pylint: disable=redefined-outer-name,wrong-import-position
78

89
import os
@@ -20,6 +21,9 @@
2021
if "TNS_ADMIN" not in os.environ:
2122
os.environ["TNS_ADMIN"] = os.path.join(app_home, "tns_admin")
2223

24+
# Patch litellm for Giskard/Ollama issue
25+
import server.patches.litellm_patch # pylint: disable=unused-import
26+
2327
import argparse
2428
import queue
2529
import secrets
@@ -148,9 +152,9 @@ def verify_key(
148152

149153
def register_endpoints(noauth: APIRouter, auth: APIRouter):
150154
"""Register API Endpoints - Imports to avoid bootstrapping before config file read
151-
New endpoints need to be registered in server.api.v1.__init__.py
155+
New endpoints need to be registered in server.api.v1.__init__.py
152156
"""
153-
import server.api.v1 as api_v1 # pylint: disable=import-outside-toplevel
157+
import server.api.v1 as api_v1 # pylint: disable=import-outside-toplevel
154158

155159
# No-Authentication (probes only)
156160
noauth.include_router(api_v1.probes.noauth, prefix="/v1", tags=["Probes"])
@@ -166,6 +170,7 @@ def register_endpoints(noauth: APIRouter, auth: APIRouter):
166170
auth.include_router(api_v1.settings.auth, prefix="/v1/settings", tags=["Tools - Settings"])
167171
auth.include_router(api_v1.testbed.auth, prefix="/v1/testbed", tags=["Tools - Testbed"])
168172

173+
169174
#############################################################################
170175
# APP FACTORY
171176
#############################################################################

src/server/api/utils/testbed.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ def clean(orig_html):
334334
"report": full_report.to_dict(),
335335
"correct_by_topic": by_topic.to_dict(),
336336
"failures": failures.to_dict(),
337-
"html_report": clean(html_report),
337+
#"html_report": clean(html_report), #CDB
338+
"html_report": '<html><body></body></html>'
338339
}
339340
logger.debug("Evaluation Results: %s", evaluation_results)
340341
evaluation = schema.EvaluationReport(**evaluation_results)

src/server/api/v1/testbed.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ def get_answer(question: str):
225225
oci_config = oci.get_oci(client)
226226
judge_client = core_models.get_client({"model": judge}, oci_config, True)
227227
try:
228-
report = evaluate(get_answer, testset=loaded_testset, llm_client=judge_client, metrics=[correctness_metric])
228+
#report = evaluate(get_answer, testset=loaded_testset, llm_client=judge_client, metrics=[correctness_metric]) #CDB
229+
report = evaluate(get_answer, testset=loaded_testset, llm_client=judge_client, metrics=None) #CDB
230+
229231
except KeyError as ex:
230232
if str(ex) == "'correctness'":
231233
raise HTTPException(status_code=500, detail="Unable to determine the correctness; please retry.") from ex
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
Copyright (c) 2024, 2025, Oracle and/or its affiliates.
3+
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
4+
"""
5+
# spell-checker:ignore litellm giskard ollama llms
6+
# pylint: disable=unused-argument,protected-access
7+
8+
from typing import TYPE_CHECKING, List, Optional, Any
9+
import time
10+
import litellm
11+
from litellm.llms.ollama.completion.transformation import OllamaConfig
12+
from litellm.types.llms.openai import AllMessageValues
13+
from litellm.types.utils import ModelResponse
14+
from httpx._models import Response
15+
16+
import common.logging_config as logging_config
17+
18+
logger = logging_config.logging.getLogger("patches.litellm_patch")
19+
20+
# Only patch if not already patched
21+
if not getattr(OllamaConfig.transform_response, "_is_custom_patch", False):
22+
if TYPE_CHECKING:
23+
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
24+
25+
LiteLLMLoggingObj = _LiteLLMLoggingObj
26+
else:
27+
LiteLLMLoggingObj = Any
28+
29+
def custom_transform_response(
30+
self,
31+
model: str,
32+
raw_response: Response,
33+
model_response: ModelResponse,
34+
logging_obj: LiteLLMLoggingObj,
35+
request_data: dict,
36+
messages: List[AllMessageValues],
37+
optional_params: dict,
38+
litellm_params: dict,
39+
encoding: str,
40+
api_key: Optional[str] = None,
41+
json_mode: Optional[bool] = None,
42+
):
43+
"""Custom transform response from .venv/lib/python3.11/site-packages/litellm/llms/ollama/completion/transformation.py"""
44+
logger.info("Custom transform_response is running")
45+
response_json = raw_response.json()
46+
47+
model_response.choices[0].finish_reason = "stop"
48+
model_response.choices[0].message.content = response_json["response"]
49+
50+
_prompt = request_data.get("prompt", "")
51+
prompt_tokens = response_json.get(
52+
"prompt_eval_count",
53+
len(encoding.encode(_prompt, disallowed_special=())),
54+
)
55+
completion_tokens = response_json.get("eval_count", len(response_json.get("message", {}).get("content", "")))
56+
57+
setattr(
58+
model_response,
59+
"usage",
60+
litellm.Usage(
61+
prompt_tokens=prompt_tokens,
62+
completion_tokens=completion_tokens,
63+
total_tokens=prompt_tokens + completion_tokens,
64+
),
65+
)
66+
model_response.created = int(time.time())
67+
model_response.model = "ollama/" + model
68+
return model_response
69+
70+
# Mark it to avoid double patching
71+
custom_transform_response._is_custom_patch = True
72+
73+
# Patch it
74+
OllamaConfig.transform_response = custom_transform_response

0 commit comments

Comments
 (0)