
Commit 59d4cfa

feat: add OpenAI-compatible Bedrock provider
Implements an AWS Bedrock inference provider using the OpenAI-compatible endpoint for Llama models available through Bedrock.

Changes:
- Add BedrockInferenceAdapter using the LiteLLMOpenAIMixin base
- Configure region-specific endpoint URLs
- Support cross-region inference profiles with retry logic
- Implement comprehensive unit tests and integration tests
- Add provider registry configuration with litellm dependency
1 parent: 96886af

File tree

14 files changed: +530 additions, −188 deletions

Lines changed: 6 additions & 13 deletions

@@ -1,5 +1,5 @@
 ---
-description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
+description: "AWS Bedrock inference provider using OpenAI compatible endpoint."
 sidebar_label: Remote - Bedrock
 title: remote::bedrock
 ---
@@ -8,27 +8,20 @@ title: remote::bedrock

 ## Description

-AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
+AWS Bedrock inference provider using OpenAI compatible endpoint.

 ## Configuration

 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
-| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
-| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
-| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2. Default use environment variable: AWS_DEFAULT_REGION |
-| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use. Default use environment variable: AWS_PROFILE |
-| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
-| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform. Default use environment variable: AWS_RETRY_MODE |
-| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
-| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection. The default is 60 seconds. |
-| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
+| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
+| `region_name` | `<class 'str'>` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |

 ## Sample Configuration

 ```yaml
-{}
+api_key: ${env.AWS_BEDROCK_API_KEY:=}
+region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
 ```
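
For orientation, here is a minimal sketch (not part of this commit) of what the sample configuration wires up: the adapter builds its OpenAI-compatible base URL from `region_name`, so a plain OpenAI client can exercise the same endpoint directly. The model ID below is a hypothetical example.

```python
import asyncio
import os

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        # Same URL shape the adapter builds in bedrock.py from `region_name`
        base_url="https://bedrock-runtime.us-east-2.amazonaws.com/openai/v1",
        api_key=os.environ["AWS_BEDROCK_API_KEY"],
    )
    resp = await client.chat.completions.create(
        model="meta.llama3-1-8b-instruct-v1:0",  # hypothetical Bedrock model ID
        messages=[{"role": "user", "content": "Hello from Bedrock"}],
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```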

llama_stack/distributions/ci-tests/run.yaml

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
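
The `${env.VAR:=default}` references above resolve when the stack loads the run config. A rough sketch of those substitution semantics, assuming they mirror bash's `:=` operator (use the default when the variable is unset or empty):

```python
import os
import re

_ENV_REF = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):=([^}]*)\}")


def resolve_env_refs(value: str) -> str:
    """Replace ${env.VAR:=default} with the environment value, falling back
    to the default when VAR is unset or empty."""
    return _ENV_REF.sub(lambda m: os.environ.get(m.group(1)) or m.group(2), value)


os.environ.pop("AWS_DEFAULT_REGION", None)
print(resolve_env_refs("${env.AWS_DEFAULT_REGION:=us-east-2}"))  # -> "us-east-2"
```

The same config block is repeated verbatim in the starter-gpu and starter run configs below.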

llama_stack/distributions/starter-gpu/run.yaml

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:

llama_stack/distributions/starter/run.yaml

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
    config:

llama_stack/providers/registry/inference.py

Lines changed: 3 additions & 2 deletions

@@ -131,10 +131,11 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="bedrock",
             provider_type="remote::bedrock",
-            pip_packages=["boto3"],
+            pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.bedrock",
             config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
-            description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
+            provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
+            description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
         ),
         RemoteProviderSpec(
             api=Api.inference,
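
The new `provider_data_validator` lets callers supply the Bedrock key per request instead of baking it into config. A hedged sketch of what that looks like from a client, assuming the stack's usual `X-LlamaStack-Provider-Data` header (JSON validated against `BedrockProviderDataValidator`); the endpoint URL and model ID are illustrative:

```python
import json

import httpx

resp = httpx.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",  # illustrative local stack URL
    headers={
        # Field name matches BedrockProviderDataValidator.aws_bedrock_api_key
        "X-LlamaStack-Provider-Data": json.dumps({"aws_bedrock_api_key": "my-bedrock-key"}),
    },
    json={
        "model": "meta.llama3-1-8b-instruct-v1:0",  # illustrative model ID
        "messages": [{"role": "user", "content": "hi"}],
    },
)
print(resp.json())
```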

llama_stack/providers/remote/inference/bedrock/bedrock.py

Lines changed: 83 additions & 132 deletions

@@ -4,162 +4,55 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import json
 from collections.abc import AsyncIterator
-from typing import Any
+from typing import Any, cast

-from botocore.client import BaseClient
+from openai import AsyncOpenAI, AuthenticationError, BadRequestError, NotFoundError

 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    Inference,
-    OpenAIEmbeddingsResponse,
-)
-from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
-from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
+    prepare_openai_completion_params,
 )
+from llama_stack.providers.utils.telemetry.tracing import get_current_span

 from .models import MODEL_ENTRIES

-REGION_PREFIX_MAP = {
-    "us": "us.",
-    "eu": "eu.",
-    "ap": "ap.",
-}
-
-
-def _get_region_prefix(region: str | None) -> str:
-    # AWS requires region prefixes for inference profiles
-    if region is None:
-        return "us."  # default to US when we don't know
-
-    # Handle case insensitive region matching
-    region_lower = region.lower()
-    for prefix in REGION_PREFIX_MAP:
-        if region_lower.startswith(f"{prefix}-"):
-            return REGION_PREFIX_MAP[prefix]
-
-    # Fallback to US for anything we don't recognize
-    return "us."
-
-
-def _to_inference_profile_id(model_id: str, region: str = None) -> str:
-    # Return ARNs unchanged
-    if model_id.startswith("arn:"):
-        return model_id
-
-    # Return inference profile IDs that already have regional prefixes
-    if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
-        return model_id
+logger = get_logger(__name__)

-    # Default to US East when no region is provided
-    if region is None:
-        region = "us-east-1"

-    return _get_region_prefix(region) + model_id
+class BedrockInferenceAdapter(LiteLLMOpenAIMixin):
+    _config: BedrockConfig

-
-class BedrockInferenceAdapter(
-    ModelRegistryHelper,
-    Inference,
-):
     def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            model_entries=MODEL_ENTRIES,
+            litellm_provider_name="openai",
+            api_key_from_config=config.api_key,
+            provider_data_api_key_field="aws_bedrock_api_key",  # Fixed: Match validator field name
+            openai_compat_api_base=f"https://bedrock-runtime.{config.region_name}.amazonaws.com/openai/v1",
+        )
         self._config = config
-        self._client = None
-
-    @property
-    def client(self) -> BaseClient:
-        if self._client is None:
-            self._client = create_bedrock_client(self._config)
-        return self._client

     async def initialize(self) -> None:
-        pass
+        await super().initialize()

     async def shutdown(self) -> None:
-        if self._client is not None:
-            self._client.close()
-
-    async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
-        bedrock_model = request.model
-
-        sampling_params = request.sampling_params
-        options = get_sampling_strategy_options(sampling_params)
-
-        if sampling_params.max_tokens:
-            options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
+        await super().shutdown()

-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-
-        # Convert foundation model ID to inference profile ID
-        region_name = self.client.meta.region_name
-        inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
-
-        return {
-            "modelId": inference_profile_id,
-            "body": json.dumps(
-                {
-                    "prompt": prompt,
-                    **options,
-                }
-            ),
-        }
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def openai_completion(
-        self,
-        # Standard OpenAI completion parameters
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        # vLLM-specific parameters
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        # for fill-in-the-middle type completion
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
+    def _get_openai_client(self) -> AsyncOpenAI:
+        return AsyncOpenAI(
+            base_url=self.api_base,
+            api_key=self.get_api_key(),
+        )

     async def openai_chat_completion(
         self,
@@ -187,4 +80,62 @@ async def openai_chat_completion(
         top_p: float | None = None,
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
+        assert self.model_store is not None
+        model_obj = await self.model_store.get_model(model)
+
+        # Bedrock OpenAI-compatible endpoint expects base model IDs (e.g. "openai.gpt-oss-20b-1:0").
+        # Cross-region inference profile IDs prefixed with "us." are not recognized by the endpoint.
+        # Normalize to base model ID, then try both base and prefixed forms for compatibility.
+        provider_model_id: str = model_obj.provider_resource_id or model
+        base_model_id = provider_model_id[3:] if provider_model_id.startswith("us.") else provider_model_id
+        candidate_models = [base_model_id, f"us.{base_model_id}"]
+
+        # Enable streaming usage metrics when telemetry is active
+        if stream and get_current_span() is not None:
+            if stream_options is None:
+                stream_options = {"include_usage": True}
+            elif "include_usage" not in stream_options:
+                stream_options = {**stream_options, "include_usage": True}
+
+        last_error: Exception | None = None
+        for candidate in candidate_models:
+            params = await prepare_openai_completion_params(
+                model=candidate,
+                messages=messages,
+                frequency_penalty=frequency_penalty,
+                function_call=function_call,
+                functions=functions,
+                logit_bias=logit_bias,
+                logprobs=logprobs,
+                max_completion_tokens=max_completion_tokens,
+                max_tokens=max_tokens,
+                n=n,
+                parallel_tool_calls=parallel_tool_calls,
+                presence_penalty=presence_penalty,
+                response_format=response_format,
+                seed=seed,
+                stop=stop,
+                stream=stream,
+                stream_options=stream_options,
+                temperature=temperature,
+                tool_choice=tool_choice,
+                tools=tools,
+                top_logprobs=top_logprobs,
+                top_p=top_p,
+                user=user,
+            )
+            try:
+                _resp = await self._get_openai_client().chat.completions.create(**params)
+                response = cast(OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk], _resp)
+                return response
+            except AuthenticationError as e:
+                # Authentication errors - no retry with different model IDs
+                # Raise immediately with proper error message
+                raise ValueError(f"Authentication failed: {str(e)}") from e
+            except (NotFoundError, BadRequestError) as e:
+                last_error = e
+                continue
+
+        if last_error:
+            raise last_error
+        raise RuntimeError("Bedrock chat completion failed")
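
The heart of the cross-region handling is this candidate-model fallback: strip a `us.` inference-profile prefix, try the base ID first and the prefixed ID second, and retry only on `NotFoundError`/`BadRequestError`. (Note that, unlike the removed `REGION_PREFIX_MAP`, only the `us.` prefix is normalized here.) Distilled to a standalone sketch:

```python
def candidate_model_ids(provider_model_id: str) -> list[str]:
    """Mirror the normalization in openai_chat_completion: strip a cross-region
    "us." prefix when present, then try base and prefixed IDs in order."""
    base = provider_model_id[3:] if provider_model_id.startswith("us.") else provider_model_id
    return [base, f"us.{base}"]


assert candidate_model_ids("us.meta.llama3-1-8b-instruct-v1:0") == [
    "meta.llama3-1-8b-instruct-v1:0",
    "us.meta.llama3-1-8b-instruct-v1:0",
]
assert candidate_model_ids("openai.gpt-oss-20b-1:0")[0] == "openai.gpt-oss-20b-1:0"
```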

llama_stack/providers/remote/inference/bedrock/config.py

Lines changed: 28 additions & 3 deletions

@@ -4,8 +4,33 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
+import os

+from pydantic import BaseModel, Field

-class BedrockConfig(BedrockBaseConfig):
-    pass
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
+
+
+class BedrockProviderDataValidator(BaseModel):
+    aws_bedrock_api_key: str | None = Field(
+        default=None,
+        description="API key for Amazon Bedrock",
+    )
+
+
+class BedrockConfig(RemoteInferenceProviderConfig):
+    api_key: str | None = Field(
+        default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"),
+        description="Amazon Bedrock API key",
+    )
+    region_name: str = Field(
+        default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
+        description="AWS Region for the Bedrock Runtime endpoint",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs):
+        return {
+            "api_key": "${env.AWS_BEDROCK_API_KEY:=}",
+            "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
+        }
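
Because both fields use `default_factory` with `os.getenv`, the environment is read each time a `BedrockConfig` is instantiated, not at import time. A quick sketch of the resulting behavior, assuming `RemoteInferenceProviderConfig` adds no other required fields (the docs table above lists only optional ones):

```python
import os

from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

os.environ.pop("AWS_BEDROCK_API_KEY", None)
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"

cfg = BedrockConfig()
assert cfg.api_key is None             # unset env var -> None
assert cfg.region_name == "eu-west-1"  # env var overrides the us-east-2 default
```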
