Skip to content

Commit 653dc11

Browse files
Merge pull request #230 from oracle-samples/spring_ai_fix
fix models for SpringAI
2 parents 23542b3 + f248614 commit 653dc11

File tree

2 files changed

+31
-37
lines changed

2 files changed

+31
-37
lines changed

src/client/content/config/models.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55
This script initializes a web interface for model configuration using Streamlit (`st`).
66
77
Session States Set:
8-
- ll_model_config: Stores all Language Model Configuration
9-
- embed_model_config: Stores all Embedding Model Configuration
10-
11-
- ll_model_enabled: Stores all Enabled Language Models
12-
- embed_model_enabled: Stores all Enabled Embedding Models
8+
- model_configs: Stores all Model Configurations
139
"""
1410
# spell-checker:ignore selectbox
1511

src/client/content/config/settings.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
# Utilities
2727
import client.utils.api_call as api_call
28+
import client.utils.st_common as st_common
2829
from client.utils.st_footer import remove_footer
2930

3031
import common.logging_config as logging_config
@@ -47,6 +48,7 @@ def get_settings(include_sensitive: bool = False):
4748
)
4849
return settings
4950

51+
5052
def save_settings(settings):
5153
"""Save Settings after changing client"""
5254

@@ -61,12 +63,7 @@ def save_settings(settings):
6163

6264
def compare_settings(current, uploaded, path=""):
6365
"""Compare current settings with uploaded settings."""
64-
differences = {
65-
"Value Mismatch": {},
66-
"Missing in Uploaded": {},
67-
"Missing in Current": {},
68-
"Override on Upload": {}
69-
}
66+
differences = {"Value Mismatch": {}, "Missing in Uploaded": {}, "Missing in Current": {}, "Override on Upload": {}}
7067

7168
sensitive_keys = {"api_key", "password", "wallet_password"}
7269

@@ -98,10 +95,7 @@ def compare_settings(current, uploaded, path=""):
9895
# Both present — compare
9996
if is_sensitive:
10097
if current[key] != uploaded[key]:
101-
differences["Value Mismatch"][new_path] = {
102-
"current": current[key],
103-
"uploaded": uploaded[key]
104-
}
98+
differences["Value Mismatch"][new_path] = {"current": current[key], "uploaded": uploaded[key]}
10599
else:
106100
child_diff = compare_settings(current[key], uploaded[key], new_path)
107101
for diff_type, diff_dict in differences.items():
@@ -123,10 +117,7 @@ def compare_settings(current, uploaded, path=""):
123117

124118
else:
125119
if current != uploaded:
126-
differences["Value Mismatch"][path] = {
127-
"current": current,
128-
"uploaded": uploaded
129-
}
120+
differences["Value Mismatch"][path] = {"current": current, "uploaded": uploaded}
130121

131122
return differences
132123

@@ -153,8 +144,8 @@ def spring_ai_conf_check(ll_model, embed_model) -> str:
153144
if ll_model is None or embed_model is None:
154145
return "hybrid"
155146

156-
ll_api = state.ll_model_enabled[ll_model]["api"]
157-
embed_api = state.embed_model_enabled[embed_model]["api"]
147+
ll_api = ll_model["api"]
148+
embed_api = embed_model["api"]
158149

159150
if "OpenAI" in ll_api and "OpenAI" in embed_api:
160151
return "openai"
@@ -163,7 +154,8 @@ def spring_ai_conf_check(ll_model, embed_model) -> str:
163154

164155
return "hybrid"
165156

166-
def spring_ai_obaas(src_dir, file_name, provider, ll_model):
157+
158+
def spring_ai_obaas(src_dir, file_name, provider, ll_config, embed_config):
167159
"""Get the users CTX Prompt"""
168160
ctx_prompt = next(
169161
item["prompt"]
@@ -174,12 +166,14 @@ def spring_ai_obaas(src_dir, file_name, provider, ll_model):
174166
with open(src_dir / "templates" / file_name, "r", encoding="utf-8") as template:
175167
template_content = template.read()
176168

169+
database_lookup = st_common.state_configs_lookup("database_configs", "name")
170+
177171
formatted_content = template_content.format(
178172
provider=provider,
179173
ctx_prompt=f"{ctx_prompt}",
180-
ll_model=state.client_settings["ll_model"] | state.ll_model_enabled[ll_model],
181-
vector_search=state.client_settings["vector_search"],
182-
database_config=state.database_config[state.client_settings["database"]["alias"]],
174+
ll_model=ll_config,
175+
vector_search=embed_config,
176+
database_config=database_lookup[state.client_settings["database"]["alias"]],
183177
)
184178

185179
if file_name.endswith(".yaml"):
@@ -188,9 +182,9 @@ def spring_ai_obaas(src_dir, file_name, provider, ll_model):
188182
formatted_content = template_content.format(
189183
provider=provider,
190184
ctx_prompt=ctx_prompt,
191-
ll_model=state.client_settings["ll_model"] | state.ll_model_enabled[ll_model],
192-
vector_search=state.client_settings["vector_search"],
193-
database_config=state.database_config[state.client_settings["database"]["alias"]],
185+
ll_model=ll_config,
186+
vector_search=embed_config,
187+
database_config=database_lookup[state.client_settings["database"]["alias"]],
194188
)
195189

196190
yaml_data = yaml.safe_load(formatted_content)
@@ -203,7 +197,7 @@ def spring_ai_obaas(src_dir, file_name, provider, ll_model):
203197
return formatted_content
204198

205199

206-
def spring_ai_zip(provider, ll_model):
200+
def spring_ai_zip(provider, ll_config, embed_config):
207201
"""Create SpringAI Zip File"""
208202
# Source directory that you want to copy
209203
files = ["mvnw", "mvnw.cmd", "pom.xml", "README.md"]
@@ -227,8 +221,8 @@ def spring_ai_zip(provider, ll_model):
227221

228222
arc_name = os.path.relpath(file_path, dst_dir) # Make the path relative
229223
zip_file.write(file_path, arc_name)
230-
env_content = spring_ai_obaas(src_dir, "start.sh", provider, ll_model)
231-
yaml_content = spring_ai_obaas(src_dir, "obaas.yaml", provider, ll_model)
224+
env_content = spring_ai_obaas(src_dir, "start.sh", provider, ll_config, embed_config)
225+
yaml_content = spring_ai_obaas(src_dir, "obaas.yaml", provider, ll_config, embed_config)
232226
zip_file.writestr("start.sh", env_content.encode("utf-8"))
233227
zip_file.writestr("src/main/resources/application-obaas.yml", yaml_content.encode("utf-8"))
234228
zip_buffer.seek(0)
@@ -291,21 +285,25 @@ def main():
291285
st.info("Please upload a Settings file.")
292286

293287
st.header("SpringAI Settings", divider="red")
294-
ll_model = state.client_settings["ll_model"]["model"]
295-
embed_model = state.client_settings["vector_search"]["model"]
296-
spring_ai_conf = spring_ai_conf_check(ll_model, embed_model)
288+
# Merge the User Settings into the Model Config
289+
model_lookup = st_common.state_configs_lookup("model_configs", "id")
290+
ll_config = model_lookup[state.client_settings["ll_model"]["model"]] | state.client_settings["ll_model"]
291+
embed_config = (
292+
model_lookup[state.client_settings["vector_search"]["model"]] | state.client_settings["vector_search"]
293+
)
294+
spring_ai_conf = spring_ai_conf_check(ll_config, embed_config)
297295

298296
if spring_ai_conf == "hybrid":
299297
st.markdown(f"""
300298
The current configuration combination of embedding and language models
301299
is currently **not supported** for SpringAI.
302-
- Language Model: **{ll_model}**
303-
- Embedding Model: **{embed_model}**
300+
- Language Model: **{ll_config["model"]}**
301+
- Embedding Model: **{embed_config["model"]}**
304302
""")
305303
else:
306304
st.download_button(
307305
label="Download SpringAI",
308-
data=spring_ai_zip(spring_ai_conf, ll_model), # Generate zip on the fly
306+
data=spring_ai_zip(spring_ai_conf, ll_config, embed_config), # Generate zip on the fly
309307
file_name="spring_ai.zip", # Zip file name
310308
mime="application/zip", # Mime type for zip file
311309
disabled=spring_ai_conf == "hybrid",

0 commit comments

Comments
 (0)