
Commit de3d5d1

ahao-anyscale, nrghosh, and angelinalg authored and committed
[serve.llm] Score API Integration for Serve LLM (#55914)
Signed-off-by: ahao-anyscale <[email protected]>
Co-authored-by: Nikhil G <[email protected]>
Co-authored-by: angelinalg <[email protected]>
Signed-off-by: Douglas Strodtman <[email protected]>
1 parent 58f5175 commit de3d5d1

File tree

10 files changed: +212 −4 lines changed

doc/source/serve/llm/index.md

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ This deployment provides an OpenAI-compatible FastAPI ingress and routes traffic
 - `/v1/chat/completions`: Chat interface (ChatGPT-style)
 - `/v1/completions`: Text completion
 - `/v1/embeddings`: Text embeddings
+- `/v1/score`: Text comparison
 - `/v1/models`: List available models
 - `/v1/models/{model}`: Model information
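The new route accepts vLLM-style score requests (text pairs) and returns similarity scores, typically from a cross-encoder or reranker style model. A minimal client sketch, not part of this commit, assuming a Serve LLM app is already running on localhost:8000 with a model that supports the score task; the model id is a placeholder:

import requests

# Hypothetical request: score how well a candidate answer matches a query.
response = requests.post(
    "http://localhost:8000/v1/score",
    json={
        "model": "my-reranker-model",  # placeholder model id
        "text_1": "What is the capital of France?",
        "text_2": "The capital of France is Paris.",
    },
)
response.raise_for_status()
print(response.json())  # e.g. {"object": "list", "data": [{"object": "score", ...}], ...}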

python/ray/llm/_internal/serve/configs/openai_api_models.py

Lines changed: 14 additions & 0 deletions
@@ -22,6 +22,8 @@
     EmbeddingCompletionRequest as vLLMEmbeddingCompletionRequest,
     EmbeddingResponse as vLLMEmbeddingResponse,
     ErrorResponse as vLLMErrorResponse,
+    ScoreRequest as vLLMScoreRequest,
+    ScoreResponse as vLLMScoreResponse,
 )
 from vllm.utils import random_uuid

@@ -89,12 +91,24 @@ class EmbeddingResponse(vLLMEmbeddingResponse):
     model_config = ConfigDict(arbitrary_types_allowed=True)


+class ScoreRequest(vLLMScoreRequest):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class ScoreResponse(vLLMScoreResponse):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

 LLMEmbeddingsResponse = Union[
     AsyncGenerator[Union[EmbeddingResponse, ErrorResponse], None],
 ]

+LLMScoreResponse = Union[
+    AsyncGenerator[Union[ScoreResponse, ErrorResponse], None],
+]
+
 LLMChatResponse = Union[
     AsyncGenerator[Union[str, ChatCompletionResponse, ErrorResponse], None],
 ]
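Like the other wrappers in this module, ScoreRequest and ScoreResponse only relax Pydantic's type checking on top of the vLLM models, and LLMScoreResponse is the async-generator alias that the server and router annotate against. A rough sketch of how such a generator is consumed, illustrative only; `gen` stands in for the value returned by the score methods below:

from ray.llm._internal.serve.configs.openai_api_models import (
    ErrorResponse,
    ScoreResponse,
)

async def unwrap_score(gen) -> ScoreResponse:
    # The generator yields a single ScoreResponse (or an ErrorResponse on
    # failure), mirroring the single-yield pattern used for embeddings.
    async for item in gen:
        if isinstance(item, ErrorResponse):
            raise RuntimeError(item.message)
        return item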

python/ray/llm/_internal/serve/deployments/llm/llm_server.py

Lines changed: 24 additions & 1 deletion
@@ -52,6 +52,8 @@
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
+    ScoreRequest,
+    ScoreResponse,
 )

 logger = get_logger(__name__)

@@ -306,7 +308,10 @@ def _batch_output_stream(
     async def _run_request(
         self,
         request: Union[
-            "ChatCompletionRequest", "CompletionRequest", "EmbeddingRequest"
+            "ChatCompletionRequest",
+            "CompletionRequest",
+            "EmbeddingRequest",
+            "ScoreRequest",
         ],
         *,
         engine_method: str,

@@ -392,6 +397,24 @@ async def embeddings(
             request, engine_method="embeddings", batch_output_stream=False
         )

+    async def score(
+        self, request: "ScoreRequest"
+    ) -> AsyncGenerator[Union["ScoreResponse", "ErrorResponse"], None]:
+        """Runs a score request to the engine and returns the response.
+
+        Returns an AsyncGenerator over the ScoreResponse object. This is so
+        that the caller can have a consistent interface across all the methods
+        of chat, completions, embeddings, and score.
+
+        Args:
+            request: A ScoreRequest object.
+
+        Returns:
+            An AsyncGenerator over the ScoreResponse object.
+        """
+        # NOTE: Score does not need batching, similar to embeddings.
+        return await self._run_request(
+            request, engine_method="score", batch_output_stream=False
+        )
+
     async def check_health(self) -> None:
         """
         Check the health of the replica. Does not return anything. Raise error when
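Through a Ray Serve deployment handle the new method behaves like embeddings: a single-element stream. A hedged sketch of a call site, following the pattern the tests further down use; `handle` and `score_request` are placeholders and this runs inside an async function:

# `handle` is a DeploymentHandle to an LLMServer deployment with streaming
# enabled, and `score_request` is a ScoreRequest.
gen = handle.score.remote(score_request)

responses = []
async for response in gen:
    responses.append(response)

# Score requests are not batched, so exactly one response is expected.
assert len(responses) == 1
score_response = responses[0]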

python/ray/llm/_internal/serve/deployments/llm/vllm/vllm_engine.py

Lines changed: 29 additions & 1 deletion
@@ -19,6 +19,8 @@
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
+    ScoreRequest,
+    ScoreResponse,
 )
 from ray.llm._internal.serve.configs.server_models import (
     DiskMultiplexConfig,

@@ -43,6 +45,7 @@
     from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
     from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
     from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+    from vllm.entrypoints.openai.serving_score import ServingScores

 vllm = try_import("vllm")
 logger = get_logger(__name__)

@@ -134,6 +137,7 @@ def __init__(
         self._oai_serving_chat: Optional["OpenAIServingChat"] = None
         self._oai_serving_completion: Optional["OpenAIServingCompletion"] = None
         self._oai_serving_embedding: Optional["OpenAIServingEmbedding"] = None
+        self._oai_serving_scores: Optional["ServingScores"] = None

     async def start(self) -> None:
         """Start the vLLM engine.

@@ -189,6 +193,7 @@ async def start(self) -> None:
         self._oai_serving_chat = state.openai_serving_chat
         self._oai_serving_completion = state.openai_serving_completion
         self._oai_serving_embedding = state.openai_serving_embedding
+        self._oai_serving_scores = state.openai_serving_scores

         self._validate_openai_serving_models()
         self._validate_engine_client()

@@ -221,6 +226,11 @@ def _validate_openai_serving_embedding(self):
             self._oai_serving_embedding, "create_embedding"
         ), "oai_serving_embedding must have a create_embedding attribute"

+    def _validate_openai_serving_scores(self):
+        assert hasattr(
+            self._oai_serving_scores, "create_score"
+        ), "oai_serving_scores must have a create_score attribute"
+
     def _validate_engine_client(self):
         assert hasattr(
             self._engine_client, "check_health"

@@ -354,7 +364,9 @@ async def resolve_lora(self, disk_lora_model: DiskMultiplexConfig):

     def _create_raw_request(
         self,
-        request: Union[CompletionRequest, ChatCompletionRequest, EmbeddingRequest],
+        request: Union[
+            CompletionRequest, ChatCompletionRequest, EmbeddingRequest, ScoreRequest
+        ],
         path: str,
     ) -> Request:
         scope = {

@@ -442,6 +454,22 @@ async def embeddings(
         else:
             yield EmbeddingResponse(**embedding_response.model_dump())

+    async def score(
+        self, request: ScoreRequest
+    ) -> AsyncGenerator[Union[ScoreResponse, ErrorResponse], None]:
+        self._validate_openai_serving_scores()
+
+        raw_request = self._create_raw_request(request, "/score")
+
+        score_response = await self._oai_serving_scores.create_score(
+            request, raw_request=raw_request
+        )
+
+        if isinstance(score_response, VLLMErrorResponse):
+            yield ErrorResponse(**score_response.model_dump())
+        else:
+            yield ScoreResponse(**score_response.model_dump())
+
     async def check_health(self) -> None:
         assert self._engine_client is not None, "engine_client is not initialized"

python/ray/llm/_internal/serve/deployments/routers/router.py

Lines changed: 39 additions & 2 deletions
@@ -46,9 +46,12 @@
     LLMChatResponse,
     LLMCompletionsResponse,
     LLMEmbeddingsResponse,
+    LLMScoreResponse,
     ModelCard,
     ModelList,
     OpenAIHTTPException,
+    ScoreRequest,
+    ScoreResponse,
     to_model_metadata,
 )
 from ray.llm._internal.serve.configs.server_models import LLMConfig

@@ -310,10 +313,18 @@ def _get_configured_serve_handle(self, model_id: str):
     async def _get_response(
         self,
         *,
-        body: Union[CompletionRequest, ChatCompletionRequest, EmbeddingRequest],
+        body: Union[
+            CompletionRequest, ChatCompletionRequest, EmbeddingRequest, ScoreRequest
+        ],
         call_method: str,
     ) -> AsyncGenerator[
-        Union[LLMChatResponse, LLMCompletionsResponse, LLMEmbeddingsResponse], None
+        Union[
+            LLMChatResponse,
+            LLMCompletionsResponse,
+            LLMEmbeddingsResponse,
+            LLMScoreResponse,
+        ],
+        None,
     ]:
         """Calls the model deployment and returns the stream."""
         model: str = body.model

@@ -478,6 +489,32 @@ async def embeddings(self, body: EmbeddingRequest) -> Response:
         if isinstance(result, EmbeddingResponse):
             return JSONResponse(content=result.model_dump())

+    @fastapi_router_app.post("/v1/score")
+    async def score(self, body: ScoreRequest) -> Response:
+        """Create scores for the provided text pairs.
+
+        Note: This is a vLLM specific endpoint.
+
+        Args:
+            body: The score request containing input text pairs to score.
+
+        Returns:
+            A response object with scores.
+        """
+
+        async with timeout(DEFAULT_LLM_ROUTER_HTTP_TIMEOUT):
+            results = self._get_response(body=body, call_method="score")
+            result = await results.__anext__()
+            if isinstance(result, ErrorResponse):
+                raise OpenAIHTTPException(
+                    message=result.message,
+                    status_code=result.code,
+                    type=result.type,
+                )
+
+            if isinstance(result, ScoreResponse):
+                return JSONResponse(content=result.model_dump())
+
     @classmethod
     def as_deployment(
         cls, llm_configs: Optional[List[LLMConfig]] = None
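For reference, a successful /v1/score response serialized by JSONResponse should look roughly like the payload below; the shape is inferred from the mock engine and validator further down, the exact fields come from vLLM's ScoreResponse, and all values here are made up:

# Illustrative payload only.
{
    "object": "list",
    "data": [
        {"object": "score", "score": 0.87, "index": 0},
    ],
    "model": "my-reranker-model",
    "usage": {"prompt_tokens": 12, "total_tokens": 12},
}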

python/ray/llm/tests/serve/conftest.py

Lines changed: 11 additions & 0 deletions
@@ -15,6 +15,7 @@
     ChatCompletionRequest,
     CompletionRequest,
     EmbeddingCompletionRequest,
+    ScoreRequest,
 )
 from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import (
     VLLMEngineConfig,

@@ -112,6 +113,16 @@ def mock_embedding_request(dimensions):
     return request


+@pytest.fixture
+def mock_score_request():
+    """Fixture for creating score requests for mock testing."""
+    return ScoreRequest(
+        model=MOCK_MODEL_ID,
+        text_1="What is the capital of France?",
+        text_2="The capital of France is Paris.",
+    )
+
+
 def get_test_model_path(yaml_file: str) -> pathlib.Path:
     current_file_dir = pathlib.Path(__file__).absolute().parent
     test_model_path = current_file_dir / yaml_file
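The fixture scores a single text pair; vLLM's ScoreRequest also accepts lists of strings, which the mock engine below zips into pairs. A hypothetical batched variant, shown only for illustration:

# Hypothetical batched request: two pairs scored in one call.
batched_score_request = ScoreRequest(
    model=MOCK_MODEL_ID,
    text_1=[
        "What is the capital of France?",
        "What is the capital of France?",
    ],
    text_2=[
        "The capital of France is Paris.",
        "France is a country in Western Europe.",
    ],
)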

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine.py

Lines changed: 15 additions & 0 deletions
@@ -81,3 +81,18 @@ async def test_embedding_mock_engine(

         async for response in engine.embeddings(request):
             LLMResponseValidator.validate_embedding_response(response, dimensions)
+
+    @pytest.mark.asyncio
+    async def test_score_mock_engine(self, mock_llm_config, mock_score_request):
+        """Test score API for text similarity."""
+        # Create and start the engine
+        engine = MockVLLMEngine(mock_llm_config)
+        await engine.start()
+
+        # Create score request
+        request = mock_score_request
+
+        print("\n\n_____ SCORE _____\n\n")
+
+        async for response in engine.score(request):
+            LLMResponseValidator.validate_score_response(response)

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_server.py

Lines changed: 28 additions & 0 deletions
@@ -152,6 +152,34 @@ async def test_embedding_llm_server(
         # Validate embedding response
         LLMResponseValidator.validate_embedding_response(chunks[0], dimensions)

+    @pytest.mark.asyncio
+    async def test_score_llm_server(
+        self,
+        serve_handle,
+        mock_llm_config,
+        mock_score_request,
+    ):
+        """Test score API from LLMServer perspective."""
+
+        # Create score request
+        request = mock_score_request
+
+        print("\n\n_____ SCORE SERVER _____\n\n")
+
+        # Get the response
+        batched_chunks = serve_handle.score.remote(request)
+
+        # Collect responses (should be just one)
+        chunks = []
+        async for batch in batched_chunks:
+            chunks.append(batch)
+
+        # Check that we got one response
+        assert len(chunks) == 1
+
+        # Validate score response
+        LLMResponseValidator.validate_score_response(chunks[0])
+
     @pytest.mark.asyncio
     async def test_check_health(self, mock_llm_config):
         """Test health check functionality."""

python/ray/llm/tests/serve/mocks/mock_vllm_engine.py

Lines changed: 37 additions & 0 deletions
@@ -13,6 +13,8 @@
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
+    ScoreRequest,
+    ScoreResponse,
 )
 from ray.llm._internal.serve.configs.server_models import (
     DiskMultiplexConfig,

@@ -135,6 +137,41 @@ async def embeddings(
         )
         yield response

+    async def score(
+        self, request: ScoreRequest
+    ) -> AsyncGenerator[Union[str, ScoreResponse, ErrorResponse], None]:
+        """Mock score generation for text pairs."""
+        if not self.started:
+            raise RuntimeError("Engine not started")
+
+        # Extract text_1 and text_2 from the request
+        text_1 = getattr(request, "text_1", "")
+        text_2 = getattr(request, "text_2", "")
+
+        # Convert to lists if they aren't already
+        text_1_list = text_1 if isinstance(text_1, list) else [text_1]
+        text_2_list = text_2 if isinstance(text_2, list) else [text_2]
+
+        # Generate mock scores for each pair
+        score_data = []
+        for i, (t1, t2) in enumerate(zip(text_1_list, text_2_list)):
+            # Generate a random score (can be any float value)
+            score = random.uniform(-10.0, 10.0)
+
+            score_data.append({"object": "score", "score": score, "index": i})
+
+        # Create the response
+        response = ScoreResponse(
+            object="list",
+            data=score_data,
+            model=getattr(request, "model", "mock-model"),
+            usage={
+                "prompt_tokens": len(str(text_1).split()) + len(str(text_2).split()),
+                "total_tokens": len(str(text_1).split()) + len(str(text_2).split()),
+            },
+        )
+        yield response
+
     async def _generate_chat_response(
         self, request: ChatCompletionRequest, prompt_text: str, max_tokens: int
     ) -> AsyncGenerator[Union[str, ChatCompletionResponse], None]:

python/ray/llm/tests/serve/utils/testing_utils.py

Lines changed: 14 additions & 0 deletions
@@ -11,6 +11,7 @@
     ChatCompletionResponse,
     CompletionResponse,
     EmbeddingResponse,
+    ScoreResponse,
 )

@@ -94,3 +95,16 @@ def validate_embedding_response(
         # Check dimensions if specified
         if expected_dimensions:
             assert len(response.data[0].embedding) == expected_dimensions
+
+    @staticmethod
+    def validate_score_response(response: ScoreResponse):
+        """Validate score responses."""
+        assert isinstance(response, ScoreResponse)
+        assert response.object == "list"
+        assert len(response.data) >= 1
+
+        # Validate each score data element
+        for i, score_data in enumerate(response.data):
+            assert score_data.object == "score"
+            assert isinstance(score_data.score, float)
+            assert score_data.index == i  # Index should match position in list
