Skip to content

Commit 55e9959

Browse files
fix: fix openai_embeddings for asymmetric embedding NIMs (#3205)
# What does this PR do? NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. This PR adds the `input_type="query"` as default and updates the documentation to suggest using the `embeddings` API for passage embeddings. <!-- If resolving an issue, uncomment and update the line below --> Resolves #2892 ## Test Plan ``` pytest -s -v tests/integration/inference/test_openai_embeddings.py --stack-config="inference=nvidia" --embedding-model="nvidia/llama-3.2-nv-embedqa-1b-v2" --env NVIDIA_API_KEY={nvidia_api_key} --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com" ```
1 parent 3f8df16 commit 55e9959

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

llama_stack/providers/remote/inference/nvidia/NVIDIA.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ print(f"Response: {response.completion_message.content}")
7777
```
7878

7979
### Create Embeddings
80+
> **Note on OpenAI embeddings compatibility**
>
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
8084
```python
8185
response = client.inference.embeddings(
8286
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",

llama_stack/providers/remote/inference/nvidia/nvidia.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88
from collections.abc import AsyncIterator
99

10-
from openai import APIConnectionError, BadRequestError
10+
from openai import NOT_GIVEN, APIConnectionError, BadRequestError
1111

1212
from llama_stack.apis.common.content_types import (
1313
InterleavedContent,
@@ -26,6 +26,9 @@
2626
Inference,
2727
LogProbConfig,
2828
Message,
29+
OpenAIEmbeddingData,
30+
OpenAIEmbeddingsResponse,
31+
OpenAIEmbeddingUsage,
2932
ResponseFormat,
3033
SamplingParams,
3134
TextTruncation,
@@ -210,6 +213,57 @@ async def embeddings(
210213
#
211214
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
212215

216+
async def openai_embeddings(
217+
self,
218+
model: str,
219+
input: str | list[str],
220+
encoding_format: str | None = "float",
221+
dimensions: int | None = None,
222+
user: str | None = None,
223+
) -> OpenAIEmbeddingsResponse:
224+
"""
225+
OpenAI-compatible embeddings for NVIDIA NIM.
226+
227+
Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
228+
We default this to "query" to ensure requests succeed when using the
229+
OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
230+
`task_type='document'`.
231+
"""
232+
extra_body: dict[str, object] = {"input_type": "query"}
233+
logger.warning(
234+
"NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
235+
"For passage embeddings, use the embeddings API with task_type='document'."
236+
)
237+
238+
response = await self.client.embeddings.create(
239+
model=await self._get_provider_model_id(model),
240+
input=input,
241+
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
242+
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
243+
user=user if user is not None else NOT_GIVEN,
244+
extra_body=extra_body,
245+
)
246+
247+
data = []
248+
for i, embedding_data in enumerate(response.data):
249+
data.append(
250+
OpenAIEmbeddingData(
251+
embedding=embedding_data.embedding,
252+
index=i,
253+
)
254+
)
255+
256+
usage = OpenAIEmbeddingUsage(
257+
prompt_tokens=response.usage.prompt_tokens,
258+
total_tokens=response.usage.total_tokens,
259+
)
260+
261+
return OpenAIEmbeddingsResponse(
262+
data=data,
263+
model=response.model,
264+
usage=usage,
265+
)
266+
213267
async def chat_completion(
214268
self,
215269
model_id: str,

0 commit comments

Comments
 (0)