|
4 | 4 | # This source code is licensed under the terms described in the LICENSE file in |
5 | 5 | # the root directory of this source tree. |
6 | 6 |
|
7 | | -import json |
8 | 7 | from collections.abc import AsyncIterator |
9 | | -from typing import Any |
| 8 | +from typing import Any, cast |
10 | 9 |
|
11 | | -from botocore.client import BaseClient |
| 10 | +from openai import AsyncOpenAI, AuthenticationError, BadRequestError, NotFoundError |
12 | 11 |
|
13 | 12 | from llama_stack.apis.inference import ( |
14 | | - ChatCompletionRequest, |
15 | | - Inference, |
16 | | - OpenAIEmbeddingsResponse, |
17 | | -) |
18 | | -from llama_stack.apis.inference.inference import ( |
19 | 13 | OpenAIChatCompletion, |
20 | 14 | OpenAIChatCompletionChunk, |
21 | | - OpenAICompletion, |
22 | 15 | OpenAIMessageParam, |
23 | 16 | OpenAIResponseFormatParam, |
24 | 17 | ) |
| 18 | +from llama_stack.log import get_logger |
25 | 19 | from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig |
26 | | -from llama_stack.providers.utils.bedrock.client import create_bedrock_client |
27 | | -from llama_stack.providers.utils.inference.model_registry import ( |
28 | | - ModelRegistryHelper, |
29 | | -) |
| 20 | +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin |
30 | 21 | from llama_stack.providers.utils.inference.openai_compat import ( |
31 | | - get_sampling_strategy_options, |
32 | | -) |
33 | | -from llama_stack.providers.utils.inference.prompt_adapter import ( |
34 | | - chat_completion_request_to_prompt, |
| 22 | + prepare_openai_completion_params, |
35 | 23 | ) |
| 24 | +from llama_stack.providers.utils.telemetry.tracing import get_current_span |
36 | 25 |
|
37 | 26 | from .models import MODEL_ENTRIES |
38 | 27 |
|
39 | | -REGION_PREFIX_MAP = { |
40 | | - "us": "us.", |
41 | | - "eu": "eu.", |
42 | | - "ap": "ap.", |
43 | | -} |
44 | | - |
45 | | - |
46 | | -def _get_region_prefix(region: str | None) -> str: |
47 | | - # AWS requires region prefixes for inference profiles |
48 | | - if region is None: |
49 | | - return "us." # default to US when we don't know |
50 | | - |
51 | | - # Handle case insensitive region matching |
52 | | - region_lower = region.lower() |
53 | | - for prefix in REGION_PREFIX_MAP: |
54 | | - if region_lower.startswith(f"{prefix}-"): |
55 | | - return REGION_PREFIX_MAP[prefix] |
56 | | - |
57 | | - # Fallback to US for anything we don't recognize |
58 | | - return "us." |
59 | | - |
60 | | - |
61 | | -def _to_inference_profile_id(model_id: str, region: str = None) -> str: |
62 | | - # Return ARNs unchanged |
63 | | - if model_id.startswith("arn:"): |
64 | | - return model_id |
65 | | - |
66 | | - # Return inference profile IDs that already have regional prefixes |
67 | | - if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()): |
68 | | - return model_id |
| 28 | +logger = get_logger(__name__) |
69 | 29 |
|
70 | | - # Default to US East when no region is provided |
71 | | - if region is None: |
72 | | - region = "us-east-1" |
73 | 30 |
|
74 | | - return _get_region_prefix(region) + model_id |
| 31 | +class BedrockInferenceAdapter(LiteLLMOpenAIMixin): |
| 32 | + _config: BedrockConfig |
75 | 33 |
|
76 | | - |
77 | | -class BedrockInferenceAdapter( |
78 | | - ModelRegistryHelper, |
79 | | - Inference, |
80 | | -): |
81 | 34 | def __init__(self, config: BedrockConfig) -> None: |
82 | | - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) |
| 35 | + LiteLLMOpenAIMixin.__init__( |
| 36 | + self, |
| 37 | + model_entries=MODEL_ENTRIES, |
| 38 | + litellm_provider_name="openai", |
| 39 | + api_key_from_config=config.api_key, |
| 40 | + provider_data_api_key_field="aws_bedrock_api_key",  # must match the provider data validator's field name
| 41 | + openai_compat_api_base=f"https://bedrock-runtime.{config.region_name}.amazonaws.com/openai/v1", |
| 42 | + ) |
83 | 43 | self._config = config |
84 | | - self._client = None |
85 | | - |
86 | | - @property |
87 | | - def client(self) -> BaseClient: |
88 | | - if self._client is None: |
89 | | - self._client = create_bedrock_client(self._config) |
90 | | - return self._client |
91 | 44 |
|
92 | 45 | async def initialize(self) -> None: |
93 | | - pass |
| 46 | + await super().initialize() |
94 | 47 |
|
95 | 48 | async def shutdown(self) -> None: |
96 | | - if self._client is not None: |
97 | | - self._client.close() |
98 | | - |
99 | | - async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict: |
100 | | - bedrock_model = request.model |
101 | | - |
102 | | - sampling_params = request.sampling_params |
103 | | - options = get_sampling_strategy_options(sampling_params) |
104 | | - |
105 | | - if sampling_params.max_tokens: |
106 | | - options["max_gen_len"] = sampling_params.max_tokens |
107 | | - if sampling_params.repetition_penalty > 0: |
108 | | - options["repetition_penalty"] = sampling_params.repetition_penalty |
| 49 | + await super().shutdown() |
109 | 50 |
|
110 | | - prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) |
111 | | - |
112 | | - # Convert foundation model ID to inference profile ID |
113 | | - region_name = self.client.meta.region_name |
114 | | - inference_profile_id = _to_inference_profile_id(bedrock_model, region_name) |
115 | | - |
116 | | - return { |
117 | | - "modelId": inference_profile_id, |
118 | | - "body": json.dumps( |
119 | | - { |
120 | | - "prompt": prompt, |
121 | | - **options, |
122 | | - } |
123 | | - ), |
124 | | - } |
125 | | - |
126 | | - async def openai_embeddings( |
127 | | - self, |
128 | | - model: str, |
129 | | - input: str | list[str], |
130 | | - encoding_format: str | None = "float", |
131 | | - dimensions: int | None = None, |
132 | | - user: str | None = None, |
133 | | - ) -> OpenAIEmbeddingsResponse: |
134 | | - raise NotImplementedError() |
135 | | - |
136 | | - async def openai_completion( |
137 | | - self, |
138 | | - # Standard OpenAI completion parameters |
139 | | - model: str, |
140 | | - prompt: str | list[str] | list[int] | list[list[int]], |
141 | | - best_of: int | None = None, |
142 | | - echo: bool | None = None, |
143 | | - frequency_penalty: float | None = None, |
144 | | - logit_bias: dict[str, float] | None = None, |
145 | | - logprobs: bool | None = None, |
146 | | - max_tokens: int | None = None, |
147 | | - n: int | None = None, |
148 | | - presence_penalty: float | None = None, |
149 | | - seed: int | None = None, |
150 | | - stop: str | list[str] | None = None, |
151 | | - stream: bool | None = None, |
152 | | - stream_options: dict[str, Any] | None = None, |
153 | | - temperature: float | None = None, |
154 | | - top_p: float | None = None, |
155 | | - user: str | None = None, |
156 | | - # vLLM-specific parameters |
157 | | - guided_choice: list[str] | None = None, |
158 | | - prompt_logprobs: int | None = None, |
159 | | - # for fill-in-the-middle type completion |
160 | | - suffix: str | None = None, |
161 | | - ) -> OpenAICompletion: |
162 | | - raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") |
| 51 | + def _get_openai_client(self) -> AsyncOpenAI: |
| 52 | + return AsyncOpenAI( |
| 53 | + base_url=self.api_base, |
| 54 | + api_key=self.get_api_key(), |
| 55 | + ) |
163 | 56 |
|
164 | 57 | async def openai_chat_completion( |
165 | 58 | self, |
@@ -187,4 +80,62 @@ async def openai_chat_completion( |
187 | 80 | top_p: float | None = None, |
188 | 81 | user: str | None = None, |
189 | 82 | ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: |
190 | | - raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") |
| 83 | + assert self.model_store is not None |
| 84 | + model_obj = await self.model_store.get_model(model) |
| 85 | + |
| 86 | + # Bedrock OpenAI-compatible endpoint expects base model IDs (e.g. "openai.gpt-oss-20b-1:0"). |
| 87 | + # Cross-region inference profile IDs prefixed with "us." are not recognized by the endpoint. |
| 88 | + # Normalize to base model ID, then try both base and prefixed forms for compatibility. |
| 89 | + provider_model_id: str = model_obj.provider_resource_id or model |
| 90 | + base_model_id = provider_model_id[3:] if provider_model_id.startswith("us.") else provider_model_id |
| 91 | + candidate_models = [base_model_id, f"us.{base_model_id}"] |
| 92 | + |
| 93 | + # Enable streaming usage metrics when telemetry is active |
| 94 | + if stream and get_current_span() is not None: |
| 95 | + if stream_options is None: |
| 96 | + stream_options = {"include_usage": True} |
| 97 | + elif "include_usage" not in stream_options: |
| 98 | + stream_options = {**stream_options, "include_usage": True} |
| 99 | + |
| 100 | + last_error: Exception | None = None |
| 101 | + for candidate in candidate_models: |
| 102 | + params = await prepare_openai_completion_params( |
| 103 | + model=candidate, |
| 104 | + messages=messages, |
| 105 | + frequency_penalty=frequency_penalty, |
| 106 | + function_call=function_call, |
| 107 | + functions=functions, |
| 108 | + logit_bias=logit_bias, |
| 109 | + logprobs=logprobs, |
| 110 | + max_completion_tokens=max_completion_tokens, |
| 111 | + max_tokens=max_tokens, |
| 112 | + n=n, |
| 113 | + parallel_tool_calls=parallel_tool_calls, |
| 114 | + presence_penalty=presence_penalty, |
| 115 | + response_format=response_format, |
| 116 | + seed=seed, |
| 117 | + stop=stop, |
| 118 | + stream=stream, |
| 119 | + stream_options=stream_options, |
| 120 | + temperature=temperature, |
| 121 | + tool_choice=tool_choice, |
| 122 | + tools=tools, |
| 123 | + top_logprobs=top_logprobs, |
| 124 | + top_p=top_p, |
| 125 | + user=user, |
| 126 | + ) |
| 127 | + try: |
| 128 | + _resp = await self._get_openai_client().chat.completions.create(**params) |
| 129 | + response = cast(OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk], _resp) |
| 130 | + return response |
| 131 | + except AuthenticationError as e: |
| 132 | + # Authentication failures are not model-specific, so don't retry other model IDs;
| 133 | + # raise immediately with a clear error message
| 134 | + raise ValueError(f"Authentication failed: {str(e)}") from e |
| 135 | + except (NotFoundError, BadRequestError) as e: |
| 136 | + last_error = e |
| 137 | + continue |
| 138 | + |
| 139 | + if last_error: |
| 140 | + raise last_error |
| 141 | + raise RuntimeError("Bedrock chat completion failed") |
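For illustration only (not part of the change): a minimal sketch of what the adapter now does at the HTTP level, talking to Bedrock's OpenAI-compatible endpoint with a plain AsyncOpenAI client. The region, API key, and model ID below are placeholders.

    import asyncio
    from openai import AsyncOpenAI

    async def main() -> None:
        region = "us-east-1"  # placeholder; the adapter takes this from BedrockConfig.region_name
        client = AsyncOpenAI(
            base_url=f"https://bedrock-runtime.{region}.amazonaws.com/openai/v1",
            api_key="<bedrock-api-key>",  # placeholder; the adapter resolves this from config or provider data
        )
        resp = await client.chat.completions.create(
            model="openai.gpt-oss-20b-1:0",  # base model ID, as the endpoint expects
            messages=[{"role": "user", "content": "Hello"}],
        )
        print(resp.choices[0].message.content)

    asyncio.run(main())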
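The model-ID fallback in openai_chat_completion can also be seen in isolation. A self-contained sketch of the same normalization (the helper name is made up for illustration):

    def candidate_model_ids(provider_model_id: str) -> list[str]:
        # Strip a cross-region inference-profile prefix, then try the base form first.
        base = provider_model_id[3:] if provider_model_id.startswith("us.") else provider_model_id
        return [base, f"us.{base}"]

    assert candidate_model_ids("us.openai.gpt-oss-20b-1:0") == ["openai.gpt-oss-20b-1:0", "us.openai.gpt-oss-20b-1:0"]
    assert candidate_model_ids("openai.gpt-oss-20b-1:0") == ["openai.gpt-oss-20b-1:0", "us.openai.gpt-oss-20b-1:0"]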
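Likewise, the streaming-usage tweak in isolation: when a telemetry span is active, include_usage is merged into stream_options without overwriting anything the caller supplied (helper name made up for illustration):

    def with_include_usage(stream_options: dict | None) -> dict:
        if stream_options is None:
            return {"include_usage": True}
        if "include_usage" not in stream_options:
            return {**stream_options, "include_usage": True}
        return stream_options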