7 | 7 | import warnings
8 | 8 | from collections.abc import AsyncIterator
9 | 9 |
10 |    | -from openai import APIConnectionError, BadRequestError
   | 10 | +from openai import NOT_GIVEN, APIConnectionError, BadRequestError
11 | 11 |
12 | 12 | from llama_stack.apis.common.content_types import (
13 | 13 |     InterleavedContent,

26 | 26 |     Inference,
27 | 27 |     LogProbConfig,
28 | 28 |     Message,
   | 29 | +    OpenAIEmbeddingData,
   | 30 | +    OpenAIEmbeddingsResponse,
   | 31 | +    OpenAIEmbeddingUsage,
29 | 32 |     ResponseFormat,
30 | 33 |     SamplingParams,
31 | 34 |     TextTruncation,

@@ -210,6 +213,57 @@ async def embeddings(

210 | 213 |         #
211 | 214 |         return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
212 | 215 |
    | 216 | +    async def openai_embeddings(
    | 217 | +        self,
    | 218 | +        model: str,
    | 219 | +        input: str | list[str],
    | 220 | +        encoding_format: str | None = "float",
    | 221 | +        dimensions: int | None = None,
    | 222 | +        user: str | None = None,
    | 223 | +    ) -> OpenAIEmbeddingsResponse:
    | 224 | +        """
    | 225 | +        OpenAI-compatible embeddings for NVIDIA NIM.
    | 226 | +
    | 227 | +        Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
    | 228 | +        We default this to "query" to ensure requests succeed when using the
    | 229 | +        OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
    | 230 | +        `task_type='document'`.
    | 231 | +        """
    | 232 | +        extra_body: dict[str, object] = {"input_type": "query"}
    | 233 | +        logger.warning(
    | 234 | +            "NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
    | 235 | +            "For passage embeddings, use the embeddings API with task_type='document'."
    | 236 | +        )
    | 237 | +
    | 238 | +        response = await self.client.embeddings.create(
    | 239 | +            model=await self._get_provider_model_id(model),
    | 240 | +            input=input,
    | 241 | +            encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
    | 242 | +            dimensions=dimensions if dimensions is not None else NOT_GIVEN,
    | 243 | +            user=user if user is not None else NOT_GIVEN,
    | 244 | +            extra_body=extra_body,
    | 245 | +        )
    | 246 | +
    | 247 | +        data = []
    | 248 | +        for i, embedding_data in enumerate(response.data):
    | 249 | +            data.append(
    | 250 | +                OpenAIEmbeddingData(
    | 251 | +                    embedding=embedding_data.embedding,
    | 252 | +                    index=i,
    | 253 | +                )
    | 254 | +            )
    | 255 | +
    | 256 | +        usage = OpenAIEmbeddingUsage(
    | 257 | +            prompt_tokens=response.usage.prompt_tokens,
    | 258 | +            total_tokens=response.usage.total_tokens,
    | 259 | +        )
    | 260 | +
    | 261 | +        return OpenAIEmbeddingsResponse(
    | 262 | +            data=data,
    | 263 | +            model=response.model,
    | 264 | +            usage=usage,
    | 265 | +        )
    | 266 | +
213 | 267 |     async def chat_completion(
214 | 268 |         self,
215 | 269 |         model_id: str,
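A minimal usage sketch of the new method (not part of the diff): it assumes `adapter` is an already-initialized NVIDIA inference adapter instance exposing the `openai_embeddings()` method added above, and the model id is illustrative; only the method signature and the `input_type='query'` default come from this change.

```python
import asyncio

async def demo(adapter) -> None:
    # Query embeddings via the new OpenAI-compatible path; the adapter injects
    # extra_body={"input_type": "query"} so asymmetric NIM models accept the call.
    response = await adapter.openai_embeddings(
        model="nvidia/nv-embedqa-e5-v5",  # illustrative model id
        input=["What is the capital of France?"],
        encoding_format="float",
    )
    print(response.model, len(response.data), response.usage.total_tokens)

    # Passage ("document") embeddings are intentionally not covered by this path;
    # per the docstring, use the existing embeddings() API with task_type='document'.

# asyncio.run(demo(adapter))  # run with a configured adapter instance
```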