diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index f895658928..6b69038c67 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -55,6 +55,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=[], module="llama_stack.providers.remote.inference.cerebras", config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator", description="Cerebras inference provider for running models on Cerebras Cloud platform.", ), RemoteProviderSpec( @@ -143,6 +144,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=["databricks-sdk"], module="llama_stack.providers.remote.inference.databricks", config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator", description="Databricks inference provider for running models on Databricks' unified analytics platform.", ), RemoteProviderSpec( @@ -152,6 +154,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=[], module="llama_stack.providers.remote.inference.nvidia", config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", + provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator", description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", ), RemoteProviderSpec( @@ -161,6 +164,7 @@ def available_providers() -> list[ProviderSpec]: pip_packages=[], module="llama_stack.providers.remote.inference.runpod", config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator", description="RunPod inference provider for running models on RunPod's cloud GPU platform.", ), RemoteProviderSpec( diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 11ef218a12..291336f864 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -15,6 +15,8 @@ class CerebrasInferenceAdapter(OpenAIMixin): config: CerebrasImplConfig + provider_data_api_key_field: str = "cerebras_api_key" + def get_api_key(self) -> str: return self.config.api_key.get_secret_value() diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py index 40db38935d..dbab60a4ba 100644 --- a/llama_stack/providers/remote/inference/cerebras/config.py +++ b/llama_stack/providers/remote/inference/cerebras/config.py @@ -7,7 +7,7 @@ import os from typing import Any -from pydantic import Field, SecretStr +from pydantic import BaseModel, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @@ -15,6 +15,13 @@ DEFAULT_BASE_URL = "https://api.cerebras.ai" +class CerebrasProviderDataValidator(BaseModel): + cerebras_api_key: str | None = Field( + default=None, + description="API key for Cerebras models", + ) + + @json_schema_type class CerebrasImplConfig(RemoteInferenceProviderConfig): base_url: str = Field( diff --git a/llama_stack/providers/remote/inference/databricks/config.py b/llama_stack/providers/remote/inference/databricks/config.py index 68e94151ea..279e741be6 100644 --- a/llama_stack/providers/remote/inference/databricks/config.py +++ b/llama_stack/providers/remote/inference/databricks/config.py @@ -6,12 +6,19 @@ from typing import Any -from pydantic import Field, SecretStr +from pydantic import BaseModel, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type +class DatabricksProviderDataValidator(BaseModel): + databricks_api_token: str | None = Field( + default=None, + description="API token for Databricks models", + ) + + @json_schema_type class DatabricksImplConfig(RemoteInferenceProviderConfig): url: str | None = Field( diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 200b361711..cf7c729245 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -21,6 +21,8 @@ class DatabricksInferenceAdapter(OpenAIMixin): config: DatabricksImplConfig + provider_data_api_key_field: str = "databricks_api_token" + # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models embedding_model_metadata: dict[str, dict[str, int]] = { "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192}, diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py index 4b310d770c..df623934bd 100644 --- a/llama_stack/providers/remote/inference/nvidia/config.py +++ b/llama_stack/providers/remote/inference/nvidia/config.py @@ -7,12 +7,19 @@ import os from typing import Any -from pydantic import Field, SecretStr +from pydantic import BaseModel, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type +class NVIDIAProviderDataValidator(BaseModel): + nvidia_api_key: str | None = Field( + default=None, + description="API key for NVIDIA NIM models", + ) + + @json_schema_type class NVIDIAConfig(RemoteInferenceProviderConfig): """ diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 7a2697327b..d30d8b0e1e 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -24,6 +24,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin): config: NVIDIAConfig + provider_data_api_key_field: str = "nvidia_api_key" + """ NVIDIA Inference Adapter for Llama Stack. diff --git a/llama_stack/providers/remote/inference/runpod/config.py b/llama_stack/providers/remote/inference/runpod/config.py index cdfe0f885a..93db2d0f56 100644 --- a/llama_stack/providers/remote/inference/runpod/config.py +++ b/llama_stack/providers/remote/inference/runpod/config.py @@ -6,12 +6,19 @@ from typing import Any -from pydantic import Field +from pydantic import BaseModel, Field from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type +class RunpodProviderDataValidator(BaseModel): + runpod_api_token: str | None = Field( + default=None, + description="API token for RunPod models", + ) + + @json_schema_type class RunpodImplConfig(RemoteInferenceProviderConfig): url: str | None = Field( diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index f752740e5d..6d5968f82e 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -24,6 +24,8 @@ class RunpodInferenceAdapter(OpenAIMixin): config: RunpodImplConfig + provider_data_api_key_field: str = "runpod_api_token" + def get_api_key(self) -> str: """Get API key for OpenAI client.""" return self.config.api_token diff --git a/pyproject.toml b/pyproject.toml index 5f086bd9d3..63108349bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,8 @@ dev = [ ] # These are the dependencies required for running unit tests. unit = [ + "anthropic", + "databricks-sdk", "sqlite-vec", "ollama", "aiosqlite", diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index 55a6793c2b..aa3a2c77a5 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -10,47 +10,124 @@ import pytest from llama_stack.core.request_headers import request_provider_data_context +from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter +from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig +from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter +from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig +from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig +from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter +from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig +from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter +from llama_stack.providers.remote.inference.gemini.config import GeminiConfig +from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter +from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter +from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig +from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter +from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig +from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter +from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig +from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter @pytest.mark.parametrize( - "config_cls,adapter_cls,provider_data_validator", + "config_cls,adapter_cls,provider_data_validator,config_params", [ ( GroqConfig, GroqInferenceAdapter, "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", + {}, ), ( OpenAIConfig, OpenAIInferenceAdapter, "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", + {}, ), ( TogetherImplConfig, TogetherInferenceAdapter, "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", + {}, ), ( LlamaCompatConfig, LlamaCompatInferenceAdapter, "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", + {}, + ), + ( + CerebrasImplConfig, + CerebrasInferenceAdapter, + "llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator", + {}, + ), + ( + DatabricksImplConfig, + DatabricksInferenceAdapter, + "llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator", + {}, + ), + ( + NVIDIAConfig, + NVIDIAInferenceAdapter, + "llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator", + {}, + ), + ( + RunpodImplConfig, + RunpodInferenceAdapter, + "llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator", + {}, + ), + ( + FireworksImplConfig, + FireworksInferenceAdapter, + "llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", + {}, + ), + ( + AnthropicConfig, + AnthropicInferenceAdapter, + "llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", + {}, + ), + ( + GeminiConfig, + GeminiInferenceAdapter, + "llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", + {}, + ), + ( + SambaNovaImplConfig, + SambaNovaInferenceAdapter, + "llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", + {}, + ), + ( + VLLMInferenceAdapterConfig, + VLLMInferenceAdapter, + "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", + { + "url": "http://fake", + }, ), ], ) -def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str): +def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict): """Ensure the OpenAI provider does not cache api keys across client requests""" - - inference_adapter = adapter_cls(config=config_cls()) + inference_adapter = adapter_cls(config=config_cls(**config_params)) inference_adapter.__provider_spec__ = MagicMock() inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator diff --git a/uv.lock b/uv.lock index fea1d40c91..d560b06b27 100644 --- a/uv.lock +++ b/uv.lock @@ -129,6 +129,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anthropic" +version = "0.69.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "docstring-parser" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c8/9d/9ad1778b95f15c5b04e7d328c1b5f558f1e893857b7c33cd288c19c0057a/anthropic-0.69.0.tar.gz", hash = "sha256:c604d287f4d73640f40bd2c0f3265a2eb6ce034217ead0608f6b07a8bc5ae5f2", size = 480622, upload-time = "2025-09-29T16:53:45.282Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/38/75129688de5637eb5b383e5f2b1570a5cc3aecafa4de422da8eea4b90a6c/anthropic-0.69.0-py3-none-any.whl", hash = "sha256:1f73193040f33f11e27c2cd6ec25f24fe7c3f193dc1c5cde6b7a08b18a16bcc5", size = 337265, upload-time = "2025-09-29T16:53:43.686Z" }, +] + [[package]] name = "anyio" version = "4.9.0" @@ -741,6 +760,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/b3/28ac139109d9005ad3f6b6f8976ffede6706a6478e21c889ce36c840918e/cryptography-45.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:90cb0a7bb35959f37e23303b7eed0a32280510030daba3f7fdfbb65defde6a97", size = 3390016, upload-time = "2025-07-02T13:05:50.811Z" }, ] +[[package]] +name = "databricks-sdk" +version = "0.67.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/5b/df3e5424d833e4f3f9b42c409ef8b513e468c9cdf06c2a9935c6cbc4d128/databricks_sdk-0.67.0.tar.gz", hash = "sha256:f923227babcaad428b0c2eede2755ebe9deb996e2c8654f179eb37f486b37a36", size = 761000, upload-time = "2025-09-25T13:32:10.858Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/ca/2aff3817041483fb8e4f75a74a36ff4ca3a826e276becd1179a591b6348f/databricks_sdk-0.67.0-py3-none-any.whl", hash = "sha256:ef49e49db45ed12c015a32a6f9d4ba395850f25bb3dcffdcaf31a5167fe03ee2", size = 718422, upload-time = "2025-09-25T13:32:09.011Z" }, +] + [[package]] name = "datasets" version = "4.0.0" @@ -839,6 +871,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "docstring-parser" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, +] + [[package]] name = "docutils" version = "0.21.2" @@ -1202,17 +1243,16 @@ wheels = [ [[package]] name = "google-auth" -version = "1.6.3" +version = "2.41.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, { name = "pyasn1-modules" }, { name = "rsa" }, - { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ef/77/eb1d3288dbe2ba6f4fe50b9bb41770bac514cd2eb91466b56d44a99e2f8d/google-auth-1.6.3.tar.gz", hash = "sha256:0f7c6a64927d34c1a474da92cfc59e552a5d3b940d3266606c6a28b72888b9e4", size = 80899, upload-time = "2019-02-19T21:14:58.34Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/af/5129ce5b2f9688d2fa49b463e544972a7c82b0fdb50980dafee92e121d9f/google_auth-2.41.1.tar.gz", hash = "sha256:b76b7b1f9e61f0cb7e88870d14f6a94aeef248959ef6992670efee37709cbfd2", size = 292284, upload-time = "2025-09-30T22:51:26.363Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/9b/ed0516cc1f7609fb0217e3057ff4f0f9f3e3ce79a369c6af4a6c5ca25664/google_auth-1.6.3-py2.py3-none-any.whl", hash = "sha256:20705f6803fd2c4d1cc2dcb0df09d4dfcb9a7d51fd59e94a3a28231fd93119ed", size = 73441, upload-time = "2019-02-19T21:14:56.623Z" }, + { url = "https://files.pythonhosted.org/packages/be/a4/7319a2a8add4cc352be9e3efeff5e2aacee917c85ca2fa1647e29089983c/google_auth-2.41.1-py2.py3-none-any.whl", hash = "sha256:754843be95575b9a19c604a848a41be03f7f2afd8c019f716dc1f51ee41c639d", size = 221302, upload-time = "2025-09-30T22:51:24.212Z" }, ] [[package]] @@ -1855,10 +1895,12 @@ test = [ unit = [ { name = "aiohttp" }, { name = "aiosqlite" }, + { name = "anthropic" }, { name = "blobfile" }, { name = "chardet" }, { name = "chromadb" }, { name = "coverage" }, + { name = "databricks-sdk" }, { name = "faiss-cpu" }, { name = "litellm" }, { name = "mcp" }, @@ -1974,10 +2016,12 @@ test = [ unit = [ { name = "aiohttp" }, { name = "aiosqlite" }, + { name = "anthropic" }, { name = "blobfile" }, { name = "chardet" }, { name = "chromadb", specifier = ">=1.0.15" }, { name = "coverage" }, + { name = "databricks-sdk" }, { name = "faiss-cpu" }, { name = "litellm" }, { name = "mcp" },