
Commit 5d711d4

fix: Update watsonx.ai provider to use LiteLLM mixin and list all models (#3674)
# What does this PR do?

- The watsonx.ai provider now uses the LiteLLM mixin instead of IBM's library, which does not seem to be working (see #3165 for context).
- The watsonx.ai provider now lists all available models by calling the watsonx.ai server, instead of relying on a hard-coded list of known models. (That list gets out of date quickly.)
- An edge case in [llama_stack/core/routers/inference.py](https://github.com/llamastack/llama-stack/pull/3674/files#diff-a34bc966ed9befd9f13d4883c23705dff49be0ad6211c850438cdda6113f3455) is addressed that was causing my manual tests to fail.
- Fixes `b64_encode_openai_embeddings_response`, which was trying to enumerate over a dictionary and then reference its elements using `.field` instead of `["field"]`. That method is called by the LiteLLM mixin for embedding models, so the fix is needed to get the watsonx.ai embedding models to work. (A small sketch of the access-pattern change appears before the test script below.)
- A unit test along the lines of the one in #3348 is added. A more comprehensive plan for automatically testing the end-to-end functionality of inference providers would be a good idea, but is out of scope for this PR.
- Updates to the watsonx distribution. Some were in response to the switch to LiteLLM (e.g., updating the required Python packages). Others are things that were already broken and were found along the way (e.g., a reference to a watsonx-specific doc template that doesn't seem to exist).

Closes #3165

This is also related to a line item in #3387, but it doesn't really address that goal (because it uses the LiteLLM mixin, not the OpenAI one). I tried the OpenAI mixin and it doesn't work with watsonx.ai, presumably because the watsonx.ai service is not OpenAI compatible. It works with LiteLLM because LiteLLM has a provider implementation for watsonx.ai.

## Test Plan

The test script below alternates between the OpenAI and watsonx providers. The idea is that the OpenAI provider shows how it should work, and the watsonx provider output then shows that it also works with watsonx. Note that the result from the MCP test is not as good (the Llama 3.3 70b model does not choose tools as wisely as gpt-4o), but it still works and provides a valid response. For more details on the setup and the MCP server used for testing, see [the AI Alliance sample notebook](https://github.com/The-AI-Alliance/llama-stack-examples/blob/main/notebooks/01-responses/) that these examples are drawn from.
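Before the test script, here is a minimal sketch of the access-pattern fix flagged in the `b64_encode_openai_embeddings_response` bullet above (the helper name is real, but the signature and data shape shown here are simplified illustrations, not the actual implementation):

```python
import base64
import struct


def b64_encode_embeddings(entries: list[dict]) -> list[dict]:
    """Illustration of the fix: dict entries must be read with ["..."], not attribute access."""
    encoded = []
    for entry in entries:
        floats = entry["embedding"]  # previously accessed as entry.embedding, which fails for dicts
        raw = struct.pack(f"{len(floats)}f", *floats)
        encoded.append({"index": entry["index"], "embedding": base64.b64encode(raw).decode("ascii")})
    return encoded


print(b64_encode_embeddings([{"index": 0, "embedding": [0.1, 0.2, 0.3]}]))
```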
```python
#!/usr/bin/env python3

import json
from llama_stack_client import LlamaStackClient
from litellm import completion
import http.client


def print_response(response):
    """Print response in a nicely formatted way"""
    print(f"ID: {response.id}")
    print(f"Status: {response.status}")
    print(f"Model: {response.model}")
    print(f"Created at: {response.created_at}")
    print(f"Output items: {len(response.output)}")

    for i, output_item in enumerate(response.output):
        if len(response.output) > 1:
            print(f"\n--- Output Item {i+1} ---")
        print(f"Output type: {output_item.type}")
        if output_item.type in ("text", "message"):
            print(f"Response content: {output_item.content[0].text}")
        elif output_item.type == "file_search_call":
            print(f"  Tool Call ID: {output_item.id}")
            print(f"  Tool Status: {output_item.status}")
            # 'queries' is a list, so we join it for clean printing
            print(f"  Queries: {', '.join(output_item.queries)}")
            # Display results if they exist, otherwise note they are empty
            print(f"  Results: {output_item.results if output_item.results else 'None'}")
        elif output_item.type == "mcp_list_tools":
            print_mcp_list_tools(output_item)
        elif output_item.type == "mcp_call":
            print_mcp_call(output_item)
        else:
            print(f"Response content: {output_item.content}")


def print_mcp_call(mcp_call):
    """Print MCP call in a nicely formatted way"""
    print(f"\n🛠️ MCP Tool Call: {mcp_call.name}")
    print(f"  Server: {mcp_call.server_label}")
    print(f"  ID: {mcp_call.id}")
    print(f"  Arguments: {mcp_call.arguments}")

    if mcp_call.error:
        print(f"Error: {mcp_call.error}")
    elif mcp_call.output:
        print("Output:")
        # Try to format JSON output nicely
        try:
            parsed_output = json.loads(mcp_call.output)
            print(json.dumps(parsed_output, indent=4))
        except:
            # If not valid JSON, print as-is
            print(f"  {mcp_call.output}")
    else:
        print("  ⏳ No output yet")


def print_mcp_list_tools(mcp_list_tools):
    """Print MCP list tools in a nicely formatted way"""
    print(f"\n🔧 MCP Server: {mcp_list_tools.server_label}")
    print(f"  ID: {mcp_list_tools.id}")
    print(f"  Available Tools: {len(mcp_list_tools.tools)}")
    print("=" * 80)

    for i, tool in enumerate(mcp_list_tools.tools, 1):
        print(f"\n{i}. {tool.name}")
        print(f"   Description: {tool.description}")

        # Parse and display input schema
        schema = tool.input_schema
        if schema and 'properties' in schema:
            properties = schema['properties']
            required = schema.get('required', [])
            print("   Parameters:")
            for param_name, param_info in properties.items():
                param_type = param_info.get('type', 'unknown')
                param_desc = param_info.get('description', 'No description')
                required_marker = " (required)" if param_name in required else " (optional)"
                print(f"     • {param_name} ({param_type}){required_marker}")
                if param_desc:
                    print(f"       {param_desc}")

        if i < len(mcp_list_tools.tools):
            print("-" * 40)


def main():
    """Main function to run all the tests"""
    # Configuration
    LLAMA_STACK_URL = "http://localhost:8321/"
    LLAMA_STACK_MODEL_IDS = [
        "openai/gpt-3.5-turbo",
        "openai/gpt-4o",
        "llama-openai-compat/Llama-3.3-70B-Instruct",
        "watsonx/meta-llama/llama-3-3-70b-instruct"
    ]

    # Using gpt-4o for this demo, but feel free to try one of the others or add more to run.yaml.
    OPENAI_MODEL_ID = LLAMA_STACK_MODEL_IDS[1]
    WATSONX_MODEL_ID = LLAMA_STACK_MODEL_IDS[-1]

    NPS_MCP_URL = "http://localhost:3005/sse/"

    print("=== Llama Stack Testing Script ===")
    print(f"Using OpenAI model: {OPENAI_MODEL_ID}")
    print(f"Using WatsonX model: {WATSONX_MODEL_ID}")
    print(f"MCP URL: {NPS_MCP_URL}")
    print()

    # Initialize client
    print("Initializing LlamaStackClient...")
    client = LlamaStackClient(base_url="http://localhost:8321")

    # Test 1: List models
    print("\n=== Test 1: List Models ===")
    try:
        models = client.models.list()
        print(f"Found {len(models)} models")
    except Exception as e:
        print(f"Error listing models: {e}")
        raise e

    # Test 2: Basic chat completion with OpenAI
    print("\n=== Test 2: Basic Chat Completion (OpenAI) ===")
    try:
        chat_completion_response = client.chat.completions.create(
            model=OPENAI_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}]
        )
        print("OpenAI Response:")
        for chunk in chat_completion_response.choices[0].message.content:
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with OpenAI chat completion: {e}")
        raise e

    # Test 3: Basic chat completion with WatsonX
    print("\n=== Test 3: Basic Chat Completion (WatsonX) ===")
    try:
        chat_completion_response_wxai = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
        )
        print("WatsonX Response:")
        for chunk in chat_completion_response_wxai.choices[0].message.content:
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with WatsonX chat completion: {e}")
        raise e

    # Test 4: Tool calling with OpenAI
    print("\n=== Test 4: Tool Calling (OpenAI) ===")
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a specific location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA",
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"]
                        },
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [
        {"role": "user", "content": "What's the weather like in Boston, MA?"}
    ]

    try:
        print("--- Initial API Call ---")
        response = client.chat.completions.create(
            model=OPENAI_MODEL_ID,
            messages=messages,
            tools=tools,
            tool_choice="auto",  # "auto" is the default
        )
        print("OpenAI tool calling response received")
    except Exception as e:
        print(f"Error with OpenAI tool calling: {e}")
        raise e

    # Test 5: Tool calling with WatsonX
    print("\n=== Test 5: Tool Calling (WatsonX) ===")
    try:
        wxai_response = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=messages,
            tools=tools,
            tool_choice="auto",  # "auto" is the default
        )
        print("WatsonX tool calling response received")
    except Exception as e:
        print(f"Error with WatsonX tool calling: {e}")
        raise e

    # Test 6: Streaming with WatsonX
    print("\n=== Test 6: Streaming Response (WatsonX) ===")
    try:
        chat_completion_response_wxai_stream = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
            stream=True
        )
        print("Model response: ", end="")
        for chunk in chat_completion_response_wxai_stream:
            # Each 'chunk' is a ChatCompletionChunk object.
            # We want the content from the 'delta' attribute.
            if hasattr(chunk, 'choices') and chunk.choices is not None:
                content = chunk.choices[0].delta.content
                # The first few chunks might have None content, so we check for it.
                if content is not None:
                    print(content, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with streaming: {e}")
        raise e

    # Test 7: MCP with OpenAI
    print("\n=== Test 7: MCP Integration (OpenAI) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=OPENAI_MODEL_ID,
            input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.",
            tools=[
                {
                    "type": "mcp",
                    "server_url": NPS_MCP_URL,
                    "server_label": "National Parks Service tools",
                    "allowed_tools": ["search_parks", "get_park_events"],
                }
            ]
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (OpenAI): {e}")
        raise e

    # Test 8: MCP with WatsonX
    print("\n=== Test 8: MCP Integration (WatsonX) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=WATSONX_MODEL_ID,
            input="What is the capital of France?"
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (WatsonX): {e}")
        raise e

    # Test 9: MCP with Llama 3.3
    print("\n=== Test 9: MCP Integration (Llama 3.3) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=WATSONX_MODEL_ID,
            input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.",
            tools=[
                {
                    "type": "mcp",
                    "server_url": NPS_MCP_URL,
                    "server_label": "National Parks Service tools",
                    "allowed_tools": ["search_parks", "get_park_events"],
                }
            ]
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (Llama 3.3): {e}")
        raise e

    # Test 10: Embeddings
    print("\n=== Test 10: Embeddings ===")
    try:
        conn = http.client.HTTPConnection("localhost:8321")
        payload = json.dumps({
            "model": "watsonx/ibm/granite-embedding-278m-multilingual",
            "input": "Hello, world!",
        })
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        conn.request("POST", "/v1/openai/v1/embeddings", payload, headers)
        res = conn.getresponse()
        data = res.read()
        print(data.decode("utf-8"))
    except Exception as e:
        print(f"Error with Embeddings: {e}")
        raise e

    print("\n=== Testing Complete ===")


if __name__ == "__main__":
    main()
```

---------

Signed-off-by: Bill Murdock <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 62bac0a commit 5d711d4

File tree

14 files changed: +214, -487 lines


docs/docs/providers/inference/remote_watsonx.mdx

Lines changed: 2 additions & 2 deletions

```diff
@@ -17,8 +17,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
-| `project_id` | `str \| None` | No | | The Project ID key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
+| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
 
 ## Sample Configuration
```

llama_stack/core/routers/inference.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -611,7 +611,7 @@ async def stream_tokens_and_compute_metrics_openai_chat(
             completion_text += "".join(choice_data["content_parts"])
 
         # Add metrics to the chunk
-        if self.telemetry and chunk.usage:
+        if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
             metrics = self._construct_metrics(
                 prompt_tokens=chunk.usage.prompt_tokens,
                 completion_tokens=chunk.usage.completion_tokens,
```
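The reason for the extra `hasattr` check: some chunk objects in the stream may not expose a `usage` attribute at all, in which case `chunk.usage` raises `AttributeError` before the truthiness test ever runs. A standalone illustration of the guard (simplified stand-in classes, not the actual streaming types):

```python
class ChunkWithUsage:
    def __init__(self, usage):
        self.usage = usage


class ChunkWithoutUsage:
    """Stand-in for a chunk that carries no usage information at all."""


for chunk in (ChunkWithUsage(usage=None), ChunkWithoutUsage()):
    # Mirrors the guard in the diff above: check the attribute exists before testing it.
    if hasattr(chunk, "usage") and chunk.usage:
        print("emit token metrics")
    else:
        print("skip metrics for this chunk")
```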

llama_stack/distributions/watsonx/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -3,3 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .watsonx import get_distribution_template  # noqa: F401
```

llama_stack/distributions/watsonx/build.yaml

Lines changed: 15 additions & 26 deletions

```diff
@@ -3,44 +3,33 @@ distribution_spec:
   description: Use watsonx for running LLM inference
   providers:
     inference:
-    - provider_id: watsonx
-      provider_type: remote::watsonx
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::watsonx
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: faiss
-      provider_type: inline::faiss
+    - provider_type: inline::faiss
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
     agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     eval:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     datasetio:
-    - provider_id: huggingface
-      provider_type: remote::huggingface
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
     scoring:
-    - provider_id: basic
-      provider_type: inline::basic
-    - provider_id: llm-as-judge
-      provider_type: inline::llm-as-judge
-    - provider_id: braintrust
-      provider_type: inline::braintrust
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
     tool_runtime:
     - provider_type: remote::brave-search
     - provider_type: remote::tavily-search
     - provider_type: inline::rag-runtime
     - provider_type: remote::model-context-protocol
+    files:
+    - provider_type: inline::localfs
 image_type: venv
 additional_pip_packages:
-- sqlalchemy[asyncio]
-- aiosqlite
 - aiosqlite
+- sqlalchemy[asyncio]
```

llama_stack/distributions/watsonx/run.yaml

Lines changed: 3 additions & 100 deletions

```diff
@@ -4,13 +4,13 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring
 - telemetry
 - tool_runtime
 - vector_io
-- files
 providers:
   inference:
   - provider_id: watsonx
@@ -19,8 +19,6 @@ providers:
       url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
      api_key: ${env.WATSONX_API_KEY:=}
       project_id: ${env.WATSONX_PROJECT_ID:=}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -48,7 +46,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
-      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+      sinks: ${env.TELEMETRY_SINKS:=sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   eval:
@@ -109,102 +107,7 @@ metadata_store:
 inference_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
-models:
-- metadata: {}
-  model_id: meta-llama/llama-3-3-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.3-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-2-13b-chat
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-2-13b
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-8b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-11b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-1b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-3b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-90b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-guard-3-11b-vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata:
-    embedding_dimension: 384
-  model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
-  model_type: embedding
+models: []
 shields: []
 vector_dbs: []
 datasets: []
```
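With `models: []` in run.yaml, the watsonx provider is expected to populate the registry dynamically by querying the watsonx.ai server. A quick way to check this from a client (a sketch; the `identifier` field and the `watsonx/` prefix follow the conventions used in the test script above):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Filter the dynamically discovered models down to the watsonx provider.
watsonx_models = [m for m in client.models.list() if m.identifier.startswith("watsonx/")]
print(f"watsonx models discovered: {len(watsonx_models)}")
for m in watsonx_models[:5]:
    print(" -", m.identifier)
```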

llama_stack/distributions/watsonx/watsonx.py

Lines changed: 5 additions & 31 deletions

```diff
@@ -4,17 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pathlib import Path
 
-from llama_stack.apis.models import ModelType
-from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
-from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
-from llama_stack.providers.inline.inference.sentence_transformers import (
-    SentenceTransformersInferenceConfig,
-)
 from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
-from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
 
 
 def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         config=WatsonXConfig.sample_run_config(),
     )
 
-    embedding_provider = Provider(
-        provider_id="sentence-transformers",
-        provider_type="inline::sentence-transformers",
-        config=SentenceTransformersInferenceConfig.sample_run_config(),
-    )
-
-    available_models = {
-        "watsonx": MODEL_ENTRIES,
-    }
     default_tool_groups = [
         ToolGroupInput(
             toolgroup_id="builtin::websearch",
@@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         ),
     ]
 
-    embedding_model = ModelInput(
-        model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
-        model_type=ModelType.embedding,
-        metadata={
-            "embedding_dimension": 384,
-        },
-    )
-
     files_provider = Provider(
         provider_id="meta-reference-files",
         provider_type="inline::localfs",
         config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
-    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name=name,
         distro_type="remote_hosted",
         description="Use watsonx for running LLM inference",
         container_image=None,
-        template_path=Path(__file__).parent / "doc_template.md",
+        template_path=None,
         providers=providers,
-        available_models_by_provider=available_models,
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider, embedding_provider],
+                    "inference": [inference_provider],
                     "files": [files_provider],
                 },
-                default_models=default_models + [embedding_model],
+                default_models=[],
                 default_tool_groups=default_tool_groups,
             ),
         },
```

llama_stack/providers/registry/inference.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -268,7 +268,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="watsonx",
             provider_type="remote::watsonx",
-            pip_packages=["ibm_watsonx_ai"],
+            pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.watsonx",
             config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
             provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
```
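Swapping the pip dependency to `litellm` works because LiteLLM ships its own watsonx.ai provider. Outside of Llama Stack, a direct call looks roughly like this (the environment variable names follow LiteLLM's watsonx documentation and should be treated as assumptions here; verify them for your LiteLLM version):

```python
import os
from litellm import completion

# Credentials for LiteLLM's watsonx provider (names per LiteLLM docs; assumptions here).
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_APIKEY"] = "your-api-key"
os.environ["WATSONX_PROJECT_ID"] = "your-project-id"

# The "watsonx/" prefix routes the request to LiteLLM's watsonx.ai implementation.
response = completion(
    model="watsonx/meta-llama/llama-3-3-70b-instruct",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
print(response.choices[0].message.content)
```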

llama_stack/providers/remote/inference/watsonx/__init__.py

Lines changed: 2 additions & 9 deletions

```diff
@@ -4,19 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.apis.inference import Inference
-
 from .config import WatsonXConfig
 
 
-async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
-    # import dynamically so `llama stack build` does not fail due to missing dependencies
+async def get_adapter_impl(config: WatsonXConfig, _deps):
+    # import dynamically so the import is used only when it is needed
     from .watsonx import WatsonXInferenceAdapter
 
-    if not isinstance(config, WatsonXConfig):
-        raise RuntimeError(f"Unexpected config type: {type(config)}")
     adapter = WatsonXInferenceAdapter(config)
     return adapter
-
-
-__all__ = ["get_adapter_impl", "WatsonXConfig"]
```

llama_stack/providers/remote/inference/watsonx/config.py

Lines changed: 14 additions & 8 deletions

```diff
@@ -7,16 +7,18 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, ConfigDict, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 class WatsonXProviderDataValidator(BaseModel):
-    url: str
-    api_key: str
-    project_id: str
+    model_config = ConfigDict(
+        from_attributes=True,
+        extra="forbid",
+    )
+    watsonx_api_key: str | None
 
 
 @json_schema_type
@@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )
+    # This seems like it should be required, but none of the other remote inference
+    # providers require it, so this is optional here too for consistency.
+    # The OpenAIConfig uses default=None instead, so this is following that precedent.
     api_key: SecretStr | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_API_KEY"),
-        description="The watsonx API key",
+        default=None,
+        description="The watsonx.ai API key",
     )
+    # As above, this is optional here too for consistency.
     project_id: str | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
-        description="The Project ID key",
+        default=None,
+        description="The watsonx.ai project ID",
     )
     timeout: int = Field(
         default=60,
```
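A small sketch of what the new defaults imply when constructing the config directly (e.g., in a unit test); credentials are now expected to arrive via run.yaml env substitution such as `${env.WATSONX_API_KEY:=}` rather than `default_factory` lookups inside the config class:

```python
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig

# With default=None, constructing the config without credentials should no longer
# pick up WATSONX_API_KEY / WATSONX_PROJECT_ID from the process environment.
cfg = WatsonXConfig()
assert cfg.api_key is None and cfg.project_id is None

# Values can still be supplied explicitly (pydantic coerces the string into a SecretStr).
cfg = WatsonXConfig(api_key="my-api-key", project_id="my-project")
print(cfg.url, cfg.timeout)
```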
