4 changes: 4 additions & 0 deletions llama_stack/providers/registry/inference.py
@@ -55,6 +55,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
),
RemoteProviderSpec(
@@ -143,6 +144,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=["databricks-sdk"],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
),
RemoteProviderSpec(
@@ -152,6 +154,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
provider_data_validator="llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
),
RemoteProviderSpec(
@@ -161,6 +164,7 @@ def available_providers() -> list[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
),
RemoteProviderSpec(
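Registering a provider_data_validator is what lets each of these providers accept per-request credentials. As a hedged illustration (not part of this diff), a caller could pass a key through Llama Stack's X-LlamaStack-Provider-Data header, whose JSON payload is validated against the registered model; the server URL, port, endpoint path, and model name below are assumptions:

# Hedged sketch: supplying a per-request Cerebras key via the provider-data
# header that the registered CerebrasProviderDataValidator checks.
# The base URL, port, endpoint path, and model name are assumptions.
import json

import httpx

response = httpx.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",  # assumed local stack URL
    headers={"X-LlamaStack-Provider-Data": json.dumps({"cerebras_api_key": "csk-..."})},
    json={
        "model": "llama3.1-8b",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.json())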
2 changes: 2 additions & 0 deletions llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -15,6 +15,8 @@
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig

provider_data_api_key_field: str = "cerebras_api_key"

def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()

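The new provider_data_api_key_field attribute tells the shared OpenAIMixin which provider-data field may override the static config key. A minimal sketch of the resolution order this presumably enables, assuming a helper that returns the validated per-request data (the helper name is illustrative, not the mixin's verbatim API):

def resolve_api_key(adapter) -> str:
    # Hedged sketch; get_request_provider_data() is an assumed helper name.
    provider_data = adapter.get_request_provider_data()  # validated per-request data, if any
    field = adapter.provider_data_api_key_field  # e.g. "cerebras_api_key"
    if provider_data is not None and getattr(provider_data, field, None):
        return getattr(provider_data, field)  # per-request key takes precedence
    return adapter.get_api_key()  # fall back to the key from config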
9 changes: 8 additions & 1 deletion llama_stack/providers/remote/inference/cerebras/config.py
@@ -7,14 +7,21 @@
import os
from typing import Any

from pydantic import Field, SecretStr
from pydantic import BaseModel, Field, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type

DEFAULT_BASE_URL = "https://api.cerebras.ai"


class CerebrasProviderDataValidator(BaseModel):
cerebras_api_key: str | None = Field(
default=None,
description="API key for Cerebras models",
)


@json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
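Because the validator is a plain pydantic model with an optional field, an empty payload still validates; an illustrative check (not part of the diff):

# Illustrative only: validating provider-data payloads against the new model.
payload = CerebrasProviderDataValidator.model_validate({"cerebras_api_key": "csk-..."})
assert payload.cerebras_api_key == "csk-..."

empty = CerebrasProviderDataValidator.model_validate({})
assert empty.cerebras_api_key is None  # field defaults to None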
llama_stack/providers/remote/inference/databricks/config.py
@@ -6,12 +6,19 @@

from typing import Any

from pydantic import Field, SecretStr
from pydantic import BaseModel, Field, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


class DatabricksProviderDataValidator(BaseModel):
databricks_api_token: str | None = Field(
default=None,
description="API token for Databricks models",
)


@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
llama_stack/providers/remote/inference/databricks/databricks.py
@@ -21,6 +21,8 @@
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig

provider_data_api_key_field: str = "databricks_api_token"

# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
9 changes: 8 additions & 1 deletion llama_stack/providers/remote/inference/nvidia/config.py
@@ -7,12 +7,19 @@
import os
from typing import Any

from pydantic import Field, SecretStr
from pydantic import BaseModel, Field, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


class NVIDIAProviderDataValidator(BaseModel):
nvidia_api_key: str | None = Field(
default=None,
description="API key for NVIDIA NIM models",
)


@json_schema_type
class NVIDIAConfig(RemoteInferenceProviderConfig):
"""
2 changes: 2 additions & 0 deletions llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -24,6 +24,8 @@
class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig

provider_data_api_key_field: str = "nvidia_api_key"

"""
NVIDIA Inference Adapter for Llama Stack.

9 changes: 8 additions & 1 deletion llama_stack/providers/remote/inference/runpod/config.py
@@ -6,12 +6,19 @@

from typing import Any

from pydantic import Field
from pydantic import BaseModel, Field

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type


class RunpodProviderDataValidator(BaseModel):
runpod_api_token: str | None = Field(
default=None,
description="API token for RunPod models",
)


@json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
2 changes: 2 additions & 0 deletions llama_stack/providers/remote/inference/runpod/runpod.py
@@ -24,6 +24,8 @@ class RunpodInferenceAdapter(OpenAIMixin):

config: RunpodImplConfig

provider_data_api_key_field: str = "runpod_api_token"

def get_api_key(self) -> str:
"""Get API key for OpenAI client."""
return self.config.api_token
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -79,6 +79,8 @@ dev = [
]
# These are the dependencies required for running unit tests.
unit = [
"anthropic",
"databricks-sdk",
"sqlite-vec",
"ollama",
"aiosqlite",
85 changes: 81 additions & 4 deletions tests/unit/providers/inference/test_inference_client_caching.py
@@ -10,47 +10,124 @@
import pytest

from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.anthropic.anthropic import AnthropicInferenceAdapter
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.cerebras.cerebras import CerebrasInferenceAdapter
from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig
from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter
from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig
from llama_stack.providers.remote.inference.fireworks.fireworks import FireworksInferenceAdapter
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.gemini.gemini import GeminiInferenceAdapter
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.runpod.config import RunpodImplConfig
from llama_stack.providers.remote.inference.runpod.runpod import RunpodInferenceAdapter
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
from llama_stack.providers.remote.inference.sambanova.sambanova import SambaNovaInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter


@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
"config_cls,adapter_cls,provider_data_validator,config_params",
[
(
GroqConfig,
GroqInferenceAdapter,
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
{},
),
(
OpenAIConfig,
OpenAIInferenceAdapter,
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
{},
),
(
TogetherImplConfig,
TogetherInferenceAdapter,
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
{},
),
(
LlamaCompatConfig,
LlamaCompatInferenceAdapter,
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
{},
),
(
CerebrasImplConfig,
CerebrasInferenceAdapter,
"llama_stack.providers.remote.inference.cerebras.config.CerebrasProviderDataValidator",
{},
),
(
DatabricksImplConfig,
DatabricksInferenceAdapter,
"llama_stack.providers.remote.inference.databricks.config.DatabricksProviderDataValidator",
{},
),
(
NVIDIAConfig,
NVIDIAInferenceAdapter,
"llama_stack.providers.remote.inference.nvidia.config.NVIDIAProviderDataValidator",
{},
),
(
RunpodImplConfig,
RunpodInferenceAdapter,
"llama_stack.providers.remote.inference.runpod.config.RunpodProviderDataValidator",
{},
),
(
FireworksImplConfig,
FireworksInferenceAdapter,
"llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
{},
),
(
AnthropicConfig,
AnthropicInferenceAdapter,
"llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
{},
),
(
GeminiConfig,
GeminiInferenceAdapter,
"llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
{},
),
(
SambaNovaImplConfig,
SambaNovaInferenceAdapter,
"llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
{},
),
(
VLLMInferenceAdapterConfig,
VLLMInferenceAdapter,
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
{
"url": "http://fake",
},
),
],
)
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str, config_params: dict):
"""Ensure the OpenAI provider does not cache api keys across client requests"""

inference_adapter = adapter_cls(config=config_cls())
inference_adapter = adapter_cls(config=config_cls(**config_params))

inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
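The remainder of the test body is collapsed in this view. A hedged sketch of the kind of assertion it likely makes, shown for the Groq case (the header key, payload field, and assertion are assumptions; json is assumed to be imported in the test module):

# Hedged sketch, not the PR's verbatim code: inside a provider-data request
# context, the adapter's client should pick up the per-request key rather
# than a cached one. "groq_api_key" matches GroqProviderDataValidator.
with request_provider_data_context(
    {"x-llamastack-provider-data": json.dumps({"groq_api_key": "per-request-key"})}
):
    assert inference_adapter.client.api_key == "per-request-key"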