From 56d35475b89bb0ca1b03c0c2ec7ecce3a4185110 Mon Sep 17 00:00:00 2001 From: JunyiXu-nv Date: Tue, 2 Sep 2025 19:08:22 +0800 Subject: [PATCH 1/3] [TRTLLM-7208][feat] Implement basic functionalities for Responses API (#7341) Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com> --- tensorrt_llm/serve/harmony_adapter.py | 81 +- tensorrt_llm/serve/openai_protocol.py | 211 ++++- tensorrt_llm/serve/openai_server.py | 107 ++- tensorrt_llm/serve/responses_utils.py | 848 ++++++++++++++++++ tests/integration/defs/test_e2e.py | 7 + .../test_lists/test-db/l0_h100.yml | 1 + .../llmapi/apps/_test_openai_responses.py | 241 +++++ 7 files changed, 1493 insertions(+), 3 deletions(-) create mode 100644 tensorrt_llm/serve/responses_utils.py create mode 100644 tests/unittest/llmapi/apps/_test_openai_responses.py diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py index c4703873de9..a46e7c5ed45 100644 --- a/tensorrt_llm/serve/harmony_adapter.py +++ b/tensorrt_llm/serve/harmony_adapter.py @@ -57,7 +57,8 @@ def __init__(self, # Normal case: filter based on available tools self.should_filter_tools = True self.available_tools = { - tool.get("function", {}).get("name", "") + tool.get("function", {}).get("name", "") if tool.get( + "name", None) is None else tool.get("name") for tool in available_tools } self.available_tools.discard("") @@ -78,6 +79,9 @@ def __init__(self, logger.debug("Created HarmonyStreamState for request %s", request_id) + def get_parser(self) -> StreamableParser: + return self.parser + def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]: """ Process a batch of tokens while maintaining parsing state. @@ -125,6 +129,42 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]: return deltas + def process_token_batch_to_messages(self, + tokens: list[int]) -> list[Message]: + """ + Process a batch of tokens while maintaining parsing state. + Returns OpenAI Messages for Responses API + """ + self.tokens_processed += len(tokens) + + for token in tokens: + # Store previous state for transition detection + prev_channel = self.parser.current_channel + prev_recipient = self.parser.current_recipient + + # Process the token + self.parser.process(token) + + # Detect channel/recipient transitions AFTER processing each token + channel_changed = prev_channel != self.parser.current_channel + recipient_changed = prev_recipient != self.parser.current_recipient + + if channel_changed or recipient_changed: + # Mark any active tool calls as completed if we're leaving a tool call + if prev_channel == "commentary" and prev_recipient and "functions." in str( + prev_recipient): + func_name = str(prev_recipient).split("functions.")[-1] + for tool_id, tool_info in self.tool_calls.items(): + if tool_info["name"] == func_name and tool_info.get( + "active", True): + tool_info["active"] = False + + # Reset channel state for new channel + self.channel_started = False + self.current_channel_state = None + + return self.parser.messages + def _create_closing_token_delta(self) -> dict[str, Any] | None: """Create closing token delta for channel transition.""" if not self.current_channel_state or not self.channel_started: @@ -317,6 +357,9 @@ def __init__( "<|constrain|>": 200009, } + def get_stream_state(self, request_id: str) -> HarmonyStreamState | None: + return self._stream_states.get(request_id, None) + def get_stop_tokens(self) -> list[int]: """ Return the list of stop token IDs for Harmony format. 
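Illustrative sketch (not part of the patch) of how the stateful per-request parsing added above can be driven by a caller: one HarmonyStreamState is kept per request_id, so each iteration only passes the newly generated token ids. The token_batches iterable below is a hypothetical stand-in for the incremental outputs of a streaming generation.

from tensorrt_llm.serve.harmony_adapter import HarmonyAdapter

adapter = HarmonyAdapter()
request_id = "resp_example"

for token_batch in token_batches:  # hypothetical stream of new token ids
    # State persists across calls for the same request_id, so only the
    # incremental tokens need to be supplied on each call.
    messages = adapter.stateful_stream_harmony_tokens_to_openai_messages(
        request_id, token_batch, available_tools=None, tool_choice=None)
    parser = adapter.get_stream_state(request_id).get_parser()
    # current_channel/current_recipient expose where the parser is mid-stream.
    print(parser.current_channel, len(messages))

adapter.cleanup_stream_state(request_id)  # release the per-request state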
@@ -1214,6 +1257,42 @@ def stateful_stream_harmony_tokens_to_openai_deltas( # Return empty deltas to continue processing return [] + def stateful_stream_harmony_tokens_to_openai_messages( + self, + request_id: str, + tokens: list[int], + available_tools: list[dict[str, Any]] | None = None, + tool_choice: str | None = None) -> list[Message]: + """ + Process tokens using stateful parsing. + + This method maintains persistent state across multiple calls for the same request, + ensuring proper channel transitions and tool call handling. + + Args: + request_id: Request ID to maintain state per request + tokens: New tokens from this iteration + available_tools: Available tools for filtering + + Returns: + List of OpenAI Messages + """ + stream_state = self._stream_states.get(request_id, None) + if stream_state is None: + stream_state = self.create_stream_state(request_id, available_tools, + tool_choice) + + try: + messages = stream_state.process_token_batch_to_messages(tokens) + return messages + except (HarmonyError, UnicodeDecodeError, ValueError): + logger.error( + f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}", + ) + logger.debug(f"Problematic streaming tokens: {tokens}") + + return [] + def create_openai_streaming_response( self, request_id: str, diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index f02178bacf5..acfbff14d23 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -11,9 +11,16 @@ ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam from openai.types.chat import \ ChatCompletionMessageParam as OpenAIChatCompletionMessageParam +from openai.types.responses import (ResponseFunctionToolCall, + ResponseInputItemParam, ResponseOutputItem, + ResponsePrompt, ResponseReasoningItem, + ResponseStatus, ResponseTextConfig) +from openai.types.responses.response import ToolChoice +from openai.types.responses.tool import Tool +from openai.types.shared import Metadata, Reasoning from openai_harmony import ReasoningEffort from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated, Required, TypedDict +from typing_extensions import Annotated, Required, TypeAlias, TypedDict from tensorrt_llm.executor.request import LoRARequest from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams @@ -665,6 +672,208 @@ def check_suffix(cls, data): return data +ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, + ResponseReasoningItem, + ResponseFunctionToolCall] + + +class ResponsesRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/responses/create + background: Optional[bool] = False + include: Optional[list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ], + ]] = None + input: Union[str, list[ResponseInputOutputItem]] + instructions: Optional[str] = None + max_output_tokens: Optional[int] = None + max_tool_calls: Optional[int] = None + metadata: Optional[Metadata] = None + model: str + parallel_tool_calls: Optional[bool] = False + previous_response_id: Optional[str] = None + prompt: Optional[ResponsePrompt] = None + reasoning: Optional[Reasoning] = None + service_tier: Literal["auto", "default", "flex", "scale", + "priority"] = 
"auto" + store: Optional[bool] = True + stream: Optional[bool] = False + temperature: Optional[float] = None + text: Optional[ResponseTextConfig] = None + tool_choice: ToolChoice = "auto" + tools: list[Tool] = Field(default_factory=list) + top_logprobs: Optional[int] = 0 + top_p: Optional[float] = None + truncation: Optional[Literal["auto", "disabled"]] = "disabled" + user: Optional[str] = None + + request_id: str = Field( + default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response."), + ) + + _DEFAULT_SAMPLING_PARAMS = { + "temperature": 1.0, + "top_p": 1.0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None, + ) -> SamplingParams: + if self.max_output_tokens is None: + max_tokens = default_max_tokens + else: + max_tokens = min(self.max_output_tokens, default_max_tokens) + + default_sampling_params = default_sampling_params or {} + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + stop_token_ids = default_sampling_params.get("stop_token_ids") + + # Structured output + guided_decoding = None + if self.text is not None and self.text.format is not None: + response_format = self.text.format + if response_format.type == "json_schema": + guided_decoding = GuidedDecodingParams( + json=response_format.schema_) + elif response_format.type == "json_object": + raise NotImplementedError("json_object is not supported") + + return SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + logprobs=self.top_logprobs, + stop_token_ids=stop_token_ids, + guided_decoding=guided_decoding, + ) + + @model_validator(mode="before") + @classmethod + def validate_background(cls, data): + if not data.get("background"): + return data + if not data.get("store", True): + raise ValueError("background can only be used when `store` is true") + return data + + @model_validator(mode="before") + @classmethod + def validate_prompt(cls, data): + if data.get("prompt") is not None: + raise ValueError("prompt template is not supported") + return data + + +class InputTokensDetails(OpenAIBaseModel): + cached_tokens: int + + +class OutputTokensDetails(OpenAIBaseModel): + reasoning_tokens: int + + +class ResponseUsage(OpenAIBaseModel): + input_tokens: int + input_tokens_details: InputTokensDetails + output_tokens: int + output_tokens_details: OutputTokensDetails + total_tokens: int + + +class ResponsesResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}") + created_at: int = Field(default_factory=lambda: int(time.time())) + # error: Optional[ResponseError] = None + # incomplete_details: Optional[IncompleteDetails] = None + instructions: Optional[str] = None + metadata: Optional[Metadata] = None + model: str + object: Literal["response"] = "response" + output: list[ResponseOutputItem] + parallel_tool_calls: bool + temperature: float + tool_choice: ToolChoice + tools: list[Tool] + top_p: float + background: bool + max_output_tokens: int + max_tool_calls: Optional[int] = None + previous_response_id: Optional[str] = None + prompt: Optional[ResponsePrompt] = None + 
reasoning: Optional[Reasoning] = None + service_tier: Literal["auto", "default", "flex", "scale", "priority"] + status: ResponseStatus + text: Optional[ResponseTextConfig] = None + top_logprobs: int + truncation: Literal["auto", "disabled"] + usage: Optional[ResponseUsage] = None + user: Optional[str] = None + + @classmethod + def from_request( + cls, + request: ResponsesRequest, + sampling_params: SamplingParams, + model_name: str, + created_time: int, + output: list[ResponseOutputItem], + status: ResponseStatus, + usage: Optional[ResponseUsage] = None, + ) -> "ResponsesResponse": + return cls( + id=request.request_id, + created_at=created_time, + instructions=request.instructions, + metadata=request.metadata, + model=model_name, + output=output, + parallel_tool_calls=request.parallel_tool_calls, + temperature=sampling_params.temperature, + tool_choice=request.tool_choice, + tools=request.tools, + top_p=sampling_params.top_p, + background=request.background, + max_output_tokens=sampling_params.max_tokens, + max_tool_calls=request.max_tool_calls, + previous_response_id=request.previous_response_id, + prompt=request.prompt, + reasoning=request.reasoning, + service_tier=request.service_tier, + status=status, + text=request.text, + top_logprobs=sampling_params.logprobs, + truncation=request.truncation, + user=request.user, + usage=usage, + ) + + +class ResponsesStreamResponse(OpenAIBaseModel): + response: ResponsesResponse + sequence_number: int + type: Literal["response.created", "response.in_progress", + "response.completed", "response.failed", + "response.incomplete"] + + def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]: if opaque_state is None: return None diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 6b8bcbb14ad..dffcd19b196 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -41,12 +41,20 @@ CompletionResponse, CompletionResponseChoice, ErrorResponse, ModelCard, - ModelList, UsageInfo, + ModelList, ResponsesRequest, + UsageInfo, to_llm_disaggregated_params) from tensorrt_llm.serve.postprocess_handlers import ( ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor, chat_stream_post_processor, completion_response_post_processor, completion_stream_post_processor) +from tensorrt_llm.serve.responses_utils import ConversationHistoryStore +from tensorrt_llm.serve.responses_utils import \ + create_response as responses_api_create_response +from tensorrt_llm.serve.responses_utils import \ + process_streaming_events as responses_api_process_streaming_events +from tensorrt_llm.serve.responses_utils import \ + request_preprocess as responses_api_request_preprocess from tensorrt_llm.version import __version__ as VERSION from .._utils import nvtx_mark, set_prometheus_multiproc_dir @@ -83,6 +91,12 @@ def __init__(self, logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path) self.model_config = None + # Enable response storage for Responses API + self.enable_store = True + if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0: + self.enable_store = False + self.conversation_store = ConversationHistoryStore() + model_dir = Path(model) if model_dir.exists() and model_dir.is_dir(): self.model = model_dir.name @@ -166,6 +180,20 @@ def create_error_response( return JSONResponse(content=error_response.model_dump(), status_code=error_response.code) + def _create_invalid_response_id_error(self, response_id: str) -> Response: + return self.create_error_response( 
+ err_type="InvalidRequestError", + message=(f"Invalid 'response_id': '{response_id}'. " + "Expected an ID that begins with 'resp'."), + ) + + def _create_response_id_not_found_error(self, response_id: str) -> Response: + return self.create_error_response( + err_type="InvalidRequestError", + message=f"Response with id '{response_id}' not found.", + status_code=HTTPStatus.NOT_FOUND, + ) + def register_routes(self): self.app.add_api_route("/health", self.health, methods=["GET"]) self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"]) @@ -182,6 +210,9 @@ def register_routes(self): self.app.add_api_route("/v1/chat/completions", self.openai_chat if not self.use_harmony else self.chat_harmony, methods=["POST"]) + self.app.add_api_route("/v1/responses", + self.openai_responses, + methods=["POST"]) if self.llm.args.return_perf_metrics: # register /prometheus/metrics self.mount_metrics() @@ -748,6 +779,80 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques logger.debug("Error details: %s", traceback.format_exc()) return self.create_error_response(message=str(e), err_type="internal_error") + async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response: + async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]: + async for event_data in responses_api_process_streaming_events( + request=request, + sampling_params=sampling_params, + generator=generator, + harmony_adapter=self.harmony_adapter, + model_name=self.model, + conversation_store=self.conversation_store, + enable_store=self.enable_store + ): + yield event_data + + try: + if not self.use_harmony: + raise NotImplementedError("Responses API only supports harmony format for now") + + # Initialize HarmonyAdapter + # NOTE: WAR for Disagg failure, may affect perf if no warmup + if not self.harmony_adapter: + self.harmony_adapter = HarmonyAdapter() + + if request.background: + logger.warning("Request.background is not supported yet, will fallback to foreground processing.") + + # Get prev response + prev_response = None + if self.enable_store: + prev_response_id = request.previous_response_id + if prev_response_id is not None: + if not prev_response_id.startswith("resp_"): + return self._create_invalid_response_id_error(prev_response_id) + + prev_response = await self.conversation_store.load_response(prev_response_id) + if prev_response is None: + logger.debug(f"response_id {prev_response_id} not found") + return self._create_response_id_not_found_error(prev_response_id) + + input_tokens, sampling_params = await responses_api_request_preprocess( + request, prev_response, self.harmony_adapter, self.conversation_store, self.enable_store) + + promise = self.llm.generate_async( + inputs=input_tokens, + sampling_params=sampling_params, + streaming=request.stream, + ) + + asyncio.create_task(self.await_disconnected(raw_request, promise)) + + if request.stream: + return StreamingResponse( + create_stream_response(promise, request, sampling_params), + media_type="text/event-stream" + ) + else: + return await responses_api_create_response( + generator=promise, + request=request, + sampling_params=sampling_params, + model_name=self.model, + conversation_store=self.conversation_store, + generation_result=None, + enable_store=self.enable_store) + except CppExecutorError: + logger.error(traceback.format_exc()) + # If internal executor error is raised, shutdown the server + signal.raise_signal(signal.SIGINT) + except 
Exception as e: + logger.error(traceback.format_exc()) + return self.create_error_response(str(e)) + + return JSONResponse(content={"detail": "None"}) + + async def __call__(self, host, port): # Store the binding address for server registration self.binding_addr = f"http://{host}:{port}" diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py new file mode 100644 index 00000000000..d4a6af268c4 --- /dev/null +++ b/tensorrt_llm/serve/responses_utils.py @@ -0,0 +1,848 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import json +import os +import time +import uuid +from collections.abc import AsyncGenerator +from copy import copy +from typing import Literal, Optional, OrderedDict, Union + +# yapf: disable +from openai.types.responses import (ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseFunctionToolCall, + ResponseInProgressEvent, ResponseOutputItem, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent) +# yapf: enable +from openai.types.responses.response_function_web_search import ( + ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch) +from openai.types.responses.response_reasoning_item import Content +from openai.types.responses.tool import Tool +from openai_harmony import (Author, Conversation, DeveloperContent, + HarmonyEncodingName, Message, ReasoningEffort, Role, + StreamState, SystemContent, TextContent, + ToolDescription, load_harmony_encoding) + +from tensorrt_llm.llmapi import SamplingParams +from tensorrt_llm.llmapi.llm import RequestOutput +from tensorrt_llm.logger import logger +from tensorrt_llm.serve.openai_protocol import (OpenAIBaseModel, + ResponseInputOutputItem, + ResponsesRequest, + ResponsesResponse) + +from .harmony_adapter import HarmonyAdapter + +REASONING_EFFORT = { + "high": ReasoningEffort.HIGH, + "medium": ReasoningEffort.MEDIUM, + "low": ReasoningEffort.LOW, +} + +ENABLE_RESPONSES_DEBUG_MSG = False + + +def responses_debug_log(msg): + if ENABLE_RESPONSES_DEBUG_MSG: + logger.debug(msg) + + +_harmony_encoding = None + + +def random_uuid(): + return str(uuid.uuid4().hex) + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding( + HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +def decode_tokens(tokens): + return get_encoding().decode(tokens) + + +def parse_response_input( + input_msg: ResponseInputOutputItem, + prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]] +) -> Message: + if not isinstance(input_msg, dict): + input_msg = input_msg.model_dump() + + responses_debug_log(f"------- Parsing input -----------") + responses_debug_log(input_msg) + responses_debug_log("") + + if "type" not in input_msg or input_msg["type"] == "message": + role = input_msg["role"] + content = input_msg["content"] + if role == "system": + # User is trying to set a system message. 
Change it to: + # <|start|>developer<|message|># Instructions + # {instructions}<|end|> + role = "developer" + text_prefix = "Instructions:\n" + else: + text_prefix = "" + if isinstance(content, str): + msg = Message.from_role_and_content(role, text_prefix + content) + elif isinstance(content, list): + contents = [ + TextContent(text=text_prefix + c["text"]) for c in content + ] + msg = Message.from_role_and_contents(role, contents) + else: + logger.warning("Responses API: Invalid input message type") + msg = None + elif input_msg["type"] == "function_call_output": + call_id = input_msg["call_id"] + call_response: Optional[ResponseFunctionToolCall] = None + for prev_response in reversed(prev_responses): + if isinstance(prev_response, ResponseFunctionToolCall + ) and prev_response.call_id == call_id: + call_response = prev_response + break + if call_response is None: + raise ValueError(f"No call message found for {call_id}") + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{call_response.name}"), + input_msg["output"]) + elif input_msg["type"] == "reasoning": + content = input_msg["content"] + assert len(content) == 1 + msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) + elif input_msg["type"] == "function_call": + msg = Message.from_role_and_content(Role.ASSISTANT, + input_msg["arguments"]) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{input_msg['name']}") + msg = msg.with_content_type("json") + else: + raise ValueError(f"Unknown input type: {input_msg['type']}") + return msg + + +class ConversationHistoryStore: + + def __init__(self, resp_capacity: int = 16, max_conversations=32): + self.response_capacity = resp_capacity + self.conversation_capacity = resp_capacity * 4 + self.max_conversations = max_conversations + + self.responses_lock = asyncio.Lock() + self.responses: OrderedDict[str, ResponsesResponse] = OrderedDict() + + self.conversations_lock = asyncio.Lock() + self.conversations: OrderedDict[str, list[Message]] = OrderedDict() + self.response_to_conversation: dict[str, str] = {} + self.conversation_to_response: dict[str, str] = {} + + async def load_response(self, resp_id: str) -> ResponsesResponse: + responses_debug_log(f"ConversationHistoryStore loading resp: {resp_id}") + async with self.responses_lock: + return self.responses.get(resp_id) + + async def store_response(self, + resp: ResponsesResponse, + resp_msgs: Optional[list[Message]] = [], + prev_resp_id: Optional[str] = None) -> None: + resp_id = resp.id + responses_debug_log(f"ConversationHistoryStore storing resp: {resp_id}") + async with self.responses_lock: + self.responses[resp_id] = resp + if len(self.responses) > self.response_capacity: + self._pop_response() + + async with self.conversations_lock: + conversation_id: str + if resp_id in self.response_to_conversation: + conversation_id = self.response_to_conversation[resp_id] + self.conversations[conversation_id].extend(resp_msgs) + elif prev_resp_id is not None: + conversation_id = self.response_to_conversation[prev_resp_id] + self.conversations[conversation_id].extend(resp_msgs) + while len(self.conversations[conversation_id] + ) > self.conversation_capacity: + self._pop_conversation(resp_id) + else: + conversation_id = random_uuid() + self.conversations[conversation_id] = resp_msgs + + responses_debug_log( + f" * storing at conversation id: {conversation_id}") + + self.response_to_conversation[resp_id] = conversation_id + self.conversation_to_response[conversation_id] = resp_id + 
self._update_visited_conversation(conversation_id) + + async def store_messages(self, resp_id: str, msgs: list[Message], + prev_resp_id: Optional[str]): + responses_debug_log(f"ConversationHistoryStore storing msg:") + for msg in msgs: + responses_debug_log(f" -> {msg.to_json()}") + + async with self.conversations_lock: + conversation_id: str + if prev_resp_id is not None: + conversation_id = self.response_to_conversation[prev_resp_id] + else: + conversation_id = random_uuid() + + responses_debug_log( + f" * storing at conversation: {conversation_id}") + self.conversations[conversation_id] = msgs + if len(self.conversations[conversation_id] + ) > self.conversation_capacity: + self._pop_conversation(resp_id) + + self.response_to_conversation[resp_id] = conversation_id + self.conversation_to_response[conversation_id] = resp_id + self._update_visited_conversation(conversation_id) + + async def append_messages(self, resp_id: str, msgs: list[Message]): + responses_debug_log(f"ConversationHistoryStore appending msgs:") + for msg in msgs: + responses_debug_log(f" -> {msg.to_json()}") + + async with self.conversations_lock: + assert resp_id in self.response_to_conversation + conversation_id = self.response_to_conversation[resp_id] + + responses_debug_log( + f" * appending at conversation: {conversation_id}") + self.conversations[conversation_id].extend(msgs) + if len(self.conversations[conversation_id] + ) > self.conversation_capacity: + self._pop_conversation(resp_id) + self._update_visited_conversation(conversation_id) + + async def get_conversation_history(self, resp_id: str) -> list[Message]: + responses_debug_log(f"ConversationHistoryStore getting prev_msgs:") + responses_debug_log(f" -> prev_resp_id: {resp_id}") + async with self.conversations_lock: + if resp_id in self.response_to_conversation: + conversation_id = self.response_to_conversation[resp_id] + self._update_visited_conversation(conversation_id) + return self.conversations.get(conversation_id, []) + + return [] + + def _update_visited_conversation(self, conversation_id) -> None: + if conversation_id not in self.conversations: + return + + self.conversations.move_to_end(conversation_id) + if len(self.conversations) > self.max_conversations: + removed_id, _ = self.conversations.popitem(last=False) + responses_debug_log( + f"ConversationHistoryStore Removing conversation {removed_id}") + removed_resp_id = self.conversation_to_response[removed_id] + # The responses may have been removed due to response capacity + if removed_resp_id in self.response_to_conversation: + self.response_to_conversation.pop(removed_resp_id) + self.conversation_to_response.pop(removed_id) + + def _pop_conversation(self, resp_id) -> None: + conversation_id = self.response_to_conversation.get(resp_id, None) + if conversation_id is None: + return + + conversation = self.conversations[conversation_id] + first_conversation_range = [] + for i, msg in enumerate(conversation): + if msg.author.role == Role.USER: + first_conversation_range.append(i) + elif msg.channel == "final": + first_conversation_range.append(i) + break + del conversation[ + first_conversation_range[0]:first_conversation_range[1] + 1] + + def _pop_response(self) -> None: + responses_debug_log(f"responses type: {type(self.responses)}") + resp_id, _ = self.responses.popitem(last=False) + if resp_id in self.response_to_conversation: + self.response_to_conversation.pop(resp_id) + + +def get_system_message( + model_identity: Optional[str] = None, + reasoning_effort: Optional[Literal["high", "medium", 
"low"]] = None, + start_date: Optional[str] = None, + browser_description: Optional[str] = None, + python_description: Optional[str] = None, +) -> Message: + sys_msg_content = SystemContent.new() + if model_identity is not None: + sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if reasoning_effort is not None: + sys_msg_content = sys_msg_content.with_reasoning_effort( + REASONING_EFFORT[reasoning_effort]) + if start_date: + sys_msg_content = sys_msg_content.with_conversation_start_date( + start_date) + if browser_description is not None: + sys_msg_content = sys_msg_content.with_tools(browser_description) + if python_description is not None: + sys_msg_content = sys_msg_content.with_tools(python_description) + sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) + return sys_msg + + +def get_developer_message(instructions: Optional[str] = None, + tools: Optional[list[Tool]] = None) -> Message: + dev_msg_content = DeveloperContent.new() + if instructions is not None: + dev_msg_content = dev_msg_content.with_instructions(instructions) + if tools is not None: + function_tools = [] + for tool in tools: + if tool.type in ("web_search_preview", "code_interpreter"): + # These are built-in tools that are added to the system message. + pass + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError(f"tool type {tool.type} not supported") + if function_tools: + function_tool_descriptions = [ + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) for tool in function_tools + ] + dev_msg_content = dev_msg_content.with_function_tools( + function_tool_descriptions) + dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) + return dev_msg + + +def get_user_message(content: str) -> Message: + return Message.from_role_and_content(Role.USER, content) + + +def construct_harmony_messages( + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + prev_msgs: list[Message] = [], +) -> list[Message]: + """Construct messages from request input, includes conversation history messages if exists.""" + messages: list[Message] = [] + if prev_response is None: + # New conversation. + reasoning_effort = (request.reasoning.effort + if request.reasoning else None) + sys_msg = get_system_message(reasoning_effort=reasoning_effort, ) + messages.append(sys_msg) + dev_msg = get_developer_message(request.instructions, request.tools) + messages.append(dev_msg) + else: + messages.extend(prev_msgs) + # Append the new input. + # Responses API supports simple text inputs without chat format. + if isinstance(request.input, str): + messages.append(get_user_message(request.input)) + else: + if prev_response is not None: + prev_outputs = copy(prev_response.output) + else: + prev_outputs = [] + for input_msg in request.input: + msg = parse_response_input(input_msg, prev_outputs) + if msg is not None: + messages.append(msg) + # User passes in a a tool call request and its output. We need + # to add the tool call request to prev_outputs so that the + # parse_response_input can find the tool call request when + # parsing the tool call output. 
+ if isinstance(input_msg, ResponseFunctionToolCall): + prev_outputs.append(input_msg) + return messages + + +def render_for_completion(messages: list[Message]) -> list[int]: + conversation = Conversation.from_messages(messages) + responses_debug_log("Rendering conversation:") + responses_debug_log(conversation.to_json()) + token_ids = get_encoding().render_conversation_for_completion( + conversation, Role.ASSISTANT) + return token_ids + + +def parse_output_tokens(tokens: list[int]) -> list[Message]: + return get_encoding().parse_messages_from_completion_tokens( + tokens, role=Role.ASSISTANT) + + +def parse_output_message(message: Message) -> list[ResponseOutputItem]: + """ + Parse a Harmony message into a list of output response items. + """ + if message.author.role != "assistant": + # This is a message from a tool to the assistant (e.g., search result). + # Don't include it in the final output for now. This aligns with + # OpenAI's behavior on models like o4-mini. + return [] + + output_items: list[ResponseOutputItem] = [] + recipient = message.recipient + if recipient is not None and recipient.startswith("browser."): + if len(message.content) != 1: + raise ValueError("Invalid number of contents in browser message") + content = message.content[0] + browser_call = json.loads(content.text) + # TODO: translate to url properly! + if recipient == "browser.search": + action = ActionSearch( + query=f"cursor:{browser_call.get('query', '')}", type="search") + elif recipient == "browser.open": + action = ActionOpenPage(url=f"cursor:{browser_call.get('url', '')}", + type="open_page") + elif recipient == "browser.find": + action = ActionFind(pattern=browser_call["pattern"], + url=f"cursor:{browser_call.get('url', '')}", + type="find") + else: + raise ValueError(f"Unknown browser action: {recipient}") + web_search_item = ResponseFunctionWebSearch( + id=f"ws_{random_uuid()}", + action=action, + status="completed", + type="web_search_call", + ) + output_items.append(web_search_item) + elif message.channel == "analysis": + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[Content(text=content.text, type="reasoning_text")], + status=None, + ) + output_items.append(reasoning_item) + elif message.channel == "commentary": + if message.recipient is None: + pass + elif message.recipient.startswith("functions."): + function_name = message.recipient.split(".")[-1] + for content in message.content: + random_id = random_uuid() + response_item = ResponseFunctionToolCall( + arguments=content.text, + call_id=f"call_{random_id}", + type="function_call", + name=function_name, + id=f"fc_{random_id}", + ) + output_items.append(response_item) + elif message.recipient.startswith( + "python") or message.recipient.startswith("browser"): + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[Content(text=content.text, type="reasoning_text")], + status=None, + ) + output_items.append(reasoning_item) + else: + raise ValueError(f"Unknown recipient: {message.recipient}") + elif message.channel == "final": + contents = [] + for content in message.content: + output_text = ResponseOutputText( + text=content.text, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + contents.append(output_text) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=contents, + role=message.author.role, + 
status="completed", + type="message", + ) + output_items.append(text_item) + else: + raise ValueError(f"Unknown channel: {message.channel}") + return output_items + + +def finish_reason_mapping(finish_reason: str) -> str: + match finish_reason: + case 'stop': + return 'completed' + case 'length': + return 'incomplete' + case 'timeout': + return 'failed' + case 'cancelled': + return 'cancelled' + + raise RuntimeError("Should never reach here!") + + +async def request_preprocess(request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + harmony_adapter: HarmonyAdapter, + conversation_store: ConversationHistoryStore, + enable_store=False): + # TODO: fix default_max_tokens + sampling_params = request.to_sampling_params( + default_max_tokens=int(16384), + default_sampling_params={ + "stop_token_ids": harmony_adapter.get_stop_tokens() + }) + + prev_response_id = request.previous_response_id + + # TODO: better way to enable metrics + if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0: + sampling_params.return_perf_metrics = True + + prev_msgs = [] + if enable_store: + prev_msgs = await conversation_store.get_conversation_history( + prev_response_id) + + responses_debug_log(f"Prev msgs:") + for msg in prev_msgs: + responses_debug_log(f" -> {msg.to_json()}") + + messages = construct_harmony_messages(request, + prev_response, + prev_msgs=prev_msgs) + + if enable_store and request.store: + # Remove reasoning messages to save token usage during multi-turn conversation + msgs_to_store = [msg for msg in messages if msg.channel != "analysis"] + await conversation_store.store_messages(request.request_id, + msgs_to_store, prev_response_id) + + input_tokens = render_for_completion(messages) + + responses_debug_log("======= Complete Inputs to model =======") + responses_debug_log(decode_tokens(input_tokens)) + responses_debug_log("========================================") + return input_tokens, sampling_params + + +async def create_response( + generator, + request: ResponsesRequest, + sampling_params, + model_name: str, + conversation_store: ConversationHistoryStore, + generation_result: RequestOutput = None, + enable_store=False, + create_time: int = None, +) -> ResponsesResponse: + + final_res: Optional[RequestOutput] = None + response_creation_time = create_time if create_time is not None else int( + time.time()) + prev_response_id = request.previous_response_id + + if generation_result is not None: + final_res = generation_result + else: + final_res = await generator + + if final_res is None: + raise RuntimeError("No output generated or provided") + + responses_debug_log("================================================") + responses_debug_log("RAW MODEL OUTPUT:") + responses_debug_log(final_res.outputs) + responses_debug_log("================================================") + + output_messages = parse_output_tokens(final_res.outputs[0].token_ids) + + responses_debug_log(f"output messages: {len(output_messages)}") + for msg in output_messages: + responses_debug_log(f" -> {msg.to_json()}") + + # prepare responses output + output_content = [] + for msg in output_messages: + output_content.extend(parse_output_message(msg)) + + response = ResponsesResponse.from_request( + request=request, + sampling_params=sampling_params, + model_name=model_name, + created_time=response_creation_time, + output=output_content, + status=finish_reason_mapping(final_res.outputs[0].finish_reason), + ) + + if enable_store and request.store: + await conversation_store.store_response(resp=response, + 
resp_msgs=output_messages, + prev_resp_id=prev_response_id) + + responses_debug_log("========== Response ===========") + responses_debug_log(response) + responses_debug_log("===============================") + return response + + +async def process_streaming_events( + request: ResponsesRequest, + sampling_params: SamplingParams, + generator, + harmony_adapter: HarmonyAdapter, + model_name: str, + conversation_store: ConversationHistoryStore, + create_time: int = None, + enable_store=False) -> AsyncGenerator[str, None]: + sequence_number = 0 + response_creation_time = create_time if create_time is not None else int( + time.time()) + final_res: Optional[RequestOutput] = None + + def _send_event(event: OpenAIBaseModel): + nonlocal sequence_number + # Set sequence_number if the event has this attribute + if hasattr(event, 'sequence_number'): + event.sequence_number = sequence_number + sequence_number += 1 + # Get event type from the event's type field if it exists + event_type = getattr(event, 'type', 'unknown') + return (f"event: {event_type}\n" + f"data: {event.model_dump_json(indent=None)}\n\n") + + current_content_index = 0 # FIXME: this number is never changed + current_output_index = 0 + current_item_id = "" # FIXME: this number is never changed + sent_output_item_added = False + + initial_response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=response_creation_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + yield _send_event( + ResponseCreatedEvent( + type="response.created", + sequence_number=-1, + response=initial_response, + )) + yield _send_event( + ResponseInProgressEvent( + type="response.in_progress", + sequence_number=-1, + response=initial_response, + )) + + tools = [tool.model_dump() for tool in request.tools] + stream_request_id = f"responses-api-{request.request_id}" + async for res in generator: + final_res = res + output = res.outputs[0] + + messages = harmony_adapter.stateful_stream_harmony_tokens_to_openai_messages( + stream_request_id, output.token_ids_diff, tools, + request.tool_choice) + stream_state = harmony_adapter.get_stream_state(stream_request_id) + assert stream_state is not None + parser = stream_state.get_parser() + + if parser.state == StreamState.EXPECT_START: + current_output_index += 1 + sent_output_item_added = False + + if len(messages) > 0: + previous_item = messages[-1] + if previous_item.recipient is not None: + # Deal with tool call here + pass + elif previous_item.channel == "analysis": + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + Content( + text=previous_item.content[0].text, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _send_event( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=previous_item.content[0].text, + )) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + )) + elif previous_item.channel == "final": + text_content = ResponseOutputText( + type="output_text", + text=previous_item.content[0].text, + annotations=[], + ) + yield _send_event( + ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, 
+ text=previous_item.content[0].text, + logprobs=[], + item_id=current_item_id, + )) + yield _send_event( + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=text_content, + )) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[text_content], + status="completed", + ), + )) + + if parser.last_content_delta: + if (parser.current_channel == "final" + and parser.current_recipient is None): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + )) + yield _send_event( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + yield _send_event( + ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=parser.last_content_delta, + # TODO, use logprobs from ctx.last_request_output + logprobs=[], + )) + elif (parser.current_channel == "analysis" + and parser.current_recipient is None): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseReasoningItem( + type="reasoning", + id=current_item_id, + summary=[], + status="in_progress", + ), + )) + yield _send_event( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + yield _send_event( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + delta=parser.last_content_delta, + sequence_number=-1, + )) + + # TODO(JunyiXu-nv): support built-in tools(python/browser/code interpreter) + + final_response = await create_response(generator, request, sampling_params, + model_name, conversation_store, + final_res, enable_store, + response_creation_time) + + yield _send_event( + ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 02454cdc607..eefc1667a6a 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1513,6 +1513,13 @@ def test_openai_chat_harmony(llm_root, llm_venv): str(test_root / "_test_openai_chat_harmony.py")]) +def test_openai_responses(llm_root, llm_venv): + test_root = unittest_path() / "llmapi" / "apps" + 
llm_venv.run_cmd( + ["-m", "pytest", + str(test_root / "_test_openai_responses.py")]) + + def test_openai_prometheus(llm_root, llm_venv): test_root = unittest_path() / "llmapi" / "apps" llm_venv.run_cmd( diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 10cac3033ca..8c4139185a8 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -104,6 +104,7 @@ l0_h100: - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B] - test_e2e.py::test_openai_chat_harmony + - test_e2e.py::test_openai_responses - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] # ------------- AutoDeploy tests --------------- - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype diff --git a/tests/unittest/llmapi/apps/_test_openai_responses.py b/tests/unittest/llmapi/apps/_test_openai_responses.py new file mode 100644 index 00000000000..beaa805383c --- /dev/null +++ b/tests/unittest/llmapi/apps/_test_openai_responses.py @@ -0,0 +1,241 @@ +import json + +import openai +import pytest +from openai.types.responses import (ResponseCompletedEvent, + ResponseReasoningTextDeltaEvent, + ResponseTextDeltaEvent) + +from ..test_llm import get_model_path +from .openai_server import RemoteOpenAIServer + +pytestmark = pytest.mark.threadleak(enabled=False) + + +@pytest.fixture(scope="module", ids=["GPT-OSS-20B"]) +def model(): + return "gpt_oss/gpt-oss-20b/" + + +@pytest.fixture(scope="module") +def server(model: str): + model_path = get_model_path(model) + with RemoteOpenAIServer(model_path) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server: RemoteOpenAIServer): + return server.get_async_client() + + +def check_reponse(response, prefix=""): + reasoning_exist, message_exist = False, False + for output in response.output: + if output.type == "reasoning": + reasoning_exist = True + elif output.type == "message": + message_exist = True + + assert reasoning_exist, f"{prefix}Reasoning content not exists!" + assert message_exist, f"{prefix}Message content not exists!" 
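A minimal sketch (not part of the test suite) of exercising the new /v1/responses route directly with the OpenAI SDK against an already running server; the base URL, API key, and model name below are assumptions.

import asyncio
import openai

async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="dummy")
    first = await client.responses.create(model="gpt-oss-20b",
                                          input="What is 1+1?")
    # previous_response_id chains this request onto the stored conversation
    # history kept by the server-side ConversationHistoryStore.
    follow_up = await client.responses.create(
        model="gpt-oss-20b",
        input="Multiply the previous answer by 10.",
        previous_response_id=first.id)
    for item in follow_up.output:
        print(item.type)  # e.g. "reasoning" followed by "message"

asyncio.run(main())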
+ + +def check_tool_calling(response, first_resp=True, prefix=""): + reasoning_exist, tool_call_exist, message_exist = False, False, False + function_call = None + for output in response.output: + if output.type == "reasoning": + reasoning_exist = True + elif output.type == "function_call": + tool_call_exist = True + function_call = output + elif output.type == "message": + message_exist = True + + if first_resp: + assert reasoning_exist and tool_call_exist, f"{prefix}Invalid tool calling 1st response" + assert not message_exist, f"{prefix}Invalid tool calling 1st response" + + return function_call + else: + assert reasoning_exist and message_exist, f"{prefix}Invalid tool calling 2nd response" + assert not tool_call_exist, f"{prefix}Invalid tool calling 2nd response" + + +@pytest.mark.asyncio(loop_scope="module") +async def test_reasoning(client: openai.AsyncOpenAI, model: str): + response = await client.responses.create( + model=model, input="Which one is larger as numeric, 9.9 or 9.11?") + + check_reponse(response, "test_reasoning: ") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str): + for effort in ["low", "medium", "high"]: + response = await client.responses.create( + model=model, + instructions="Use less than 1024 tokens for reasoning", + input="Which one is larger as numeric, 9.9 or 9.11?", + reasoning={"effort": effort}) + check_reponse(response, f"test_reasoning_effort_{effort}: ") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_chat(client: openai.AsyncOpenAI, model: str): + response = await client.responses.create(model=model, + input=[{ + "role": + "developer", + "content": + "Respond in Chinese." + }, { + "role": "user", + "content": "Hello!" + }, { + "role": + "assistant", + "content": + "Hello! How can I help you?" + }, { + "role": "user", + "content": "Tell me a joke." + }]) + check_reponse(response, "test_chat: ") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_multi_turn_chat(client: openai.AsyncOpenAI, model: str): + response = await client.responses.create(model=model, + input="What is the answer of 1+1?") + check_reponse(response, "test_multi_turn_chat_1: ") + + response_2 = await client.responses.create( + model=model, + input="What is the answer of previous question?", + previous_response_id=response.id) + check_reponse(response_2, "test_multi_turn_chat_2: ") + + +def get_current_weather(location: str, format: str = "celsius") -> dict: + return {"sunny": True, "temperature": 20 if format == "celsius" else 68} + + +@pytest.mark.asyncio(loop_scope="module") +async def test_tool_calls(client: openai.AsyncOpenAI, model: str): + tool_get_current_weather = { + "type": "function", + "name": "get_current_weather", + "description": "Gets the current weather in the provided location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "format": { + "type": "string", + "description": "default: celsius", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + } + } + messages = [{"role": "user", "content": "What is the weather like in SF?"}] + response = await client.responses.create( + model=model, + input=messages, + tools=[tool_get_current_weather], + ) + messages.extend(response.output) + function_call = check_tool_calling(response, True, "test_tool_calls: ") + + assert function_call.name == "get_current_weather" + + args = json.loads(function_call.arguments) + answer = get_current_weather(**args) + messages.append({ + "type": "function_call_output", + "call_id": function_call.call_id, + "output": json.dumps(answer), + }) + + response = await client.responses.create(model=model, + input=messages, + tools=[tool_get_current_weather]) + + check_tool_calling(response, False, "test_tool_calls: ") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming(client: openai.AsyncOpenAI, model: str): + stream = await client.responses.create( + model=model, + input="Explain the theory of relativity in brief.", + stream=True, + ) + + reasoning_deltas, message_deltas = list(), list() + async for event in stream: + if isinstance(event, ResponseTextDeltaEvent): + message_deltas.append(event.delta) + elif isinstance(event, ResponseReasoningTextDeltaEvent): + reasoning_deltas.append(event.delta) + + full_response = "".join(message_deltas) + full_reasoning_response = "".join(reasoning_deltas) + assert full_response + assert full_reasoning_response + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str): + tool_get_current_weather = { + "type": "function", + "name": "get_current_weather", + "description": "Gets the current weather in the provided location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "description": "default: celsius", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + } + } + messages = [{"role": "user", "content": "What is the weather like in SF?"}] + stream = await client.responses.create( + model=model, + input=messages, + tools=[tool_get_current_weather], + stream=True, + ) + + function_call = None + reasoning_deltas = list() + async for event in stream: + if isinstance(event, ResponseCompletedEvent): + for output in event.response.output: + if output.type == "function_call": + function_call = output + elif isinstance(event, ResponseReasoningTextDeltaEvent): + reasoning_deltas.append(event.delta) + + reasoning = "".join(reasoning_deltas) + tool_args = json.loads(function_call.arguments) + + assert function_call.name == "get_current_weather", "wrong function calling name" + assert tool_args, "tool args not exists!" + assert reasoning, "reasoning not exists!" 
+ + get_current_weather(**tool_args) From 040a8953fdd4673377cb10c08c8ea60b38161051 Mon Sep 17 00:00:00 2001 From: JunyiXu-nv Date: Mon, 8 Sep 2025 11:11:35 +0800 Subject: [PATCH 2/3] [TRTLLM-7779][feat] Support multiple postprocess workers for chat completions API (#7508) Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com> Co-authored-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com> --- tensorrt_llm/serve/harmony_adapter.py | 156 ++++++++-------- tensorrt_llm/serve/openai_server.py | 62 +++++-- tensorrt_llm/serve/postprocess_handlers.py | 166 +++++++++++++----- .../llmapi/apps/_test_openai_chat_harmony.py | 8 + 4 files changed, 261 insertions(+), 131 deletions(-) diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py index a46e7c5ed45..2949965d729 100644 --- a/tensorrt_llm/serve/harmony_adapter.py +++ b/tensorrt_llm/serve/harmony_adapter.py @@ -6,7 +6,7 @@ import time import traceback import uuid -from typing import Any, AsyncGenerator, Literal +from typing import Any, List, Literal from openai_harmony import (Author, Conversation, DeveloperContent, HarmonyEncodingName, HarmonyError, Message, @@ -14,15 +14,15 @@ SystemContent, TextContent, ToolDescription, load_harmony_encoding) -from tensorrt_llm.llmapi import RequestOutput from tensorrt_llm.logger import logger # yapf: disable -from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest, +from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, + ChatCompletionStreamResponse, + ChatCompletionToolsParam, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, UsageInfo) @@ -1485,36 +1485,72 @@ def _is_tool_call_allowed(self, tool_call: dict[str, Any], return True -async def handle_streaming_response( - harmony_adapter: HarmonyAdapter, - generator: RequestOutput, - request_id: str, - request: ChatCompletionRequest, -) -> AsyncGenerator[str, None]: - """Handle streaming response with harmony format.""" +_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None + + +def get_harmony_adapter(): + global _SERVE_HARMONY_ADAPTER + if _SERVE_HARMONY_ADAPTER is None: + _SERVE_HARMONY_ADAPTER = HarmonyAdapter() + + return _SERVE_HARMONY_ADAPTER + + +def handle_streaming_response(tools: List[ChatCompletionToolsParam], + tool_choice: str, outputs: List, model: str, + request_id: str, done: bool, + num_prompt_tokens: int): first_iteration = True - async for res in generator: - output = res.outputs[0] + output = outputs[0] - # Convert tools to dictionary format for harmony adapter (standard pattern) - tools_dict = None - if request.tools: - tools_dict = [tool.model_dump() for tool in request.tools] + # Convert tools to dictionary format for harmony adapter (standard pattern) + tools_dict = None + harmony_adapter = get_harmony_adapter() + if tools: + tools_dict = [tool.model_dump() for tool in tools] - # Get tool_choice from request - if "none", don't pass tools to parser - tool_choice = getattr(request, 'tool_choice', None) - if tool_choice == "none": - tools_for_parser = None - else: - tools_for_parser = tools_dict + # Get tool_choice from request - if "none", don't pass tools to parser + if tool_choice == "none": + tools_for_parser = None + else: + tools_for_parser = tools_dict - # Create OpenAI streaming responses - try: + # Create OpenAI streaming responses + try: + res = [] + if done: + # Clean up state + 
harmony_adapter.cleanup_stream_state(request_id) + + usage_info = _create_usage_info(num_prompt_tokens, outputs) + + # Send final message with finish_reason + final_response = ChatCompletionStreamResponse( + model=model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(), + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) + ], + ) + + final_response_json = final_response.model_dump_json( + exclude_none=True) + final_usage_chunk = ChatCompletionStreamResponse(choices=[], + model=model, + usage=usage_info) + final_usage_json = final_usage_chunk.model_dump_json( + exclude_none=True) + res.append(f"data: {final_response_json}\n\n") + res.append(f"data: {final_usage_json}\n\n") + else: responses = harmony_adapter.create_openai_streaming_response( request_id=request_id, tokens=output.token_ids_diff, available_tools=tools_for_parser, - model_name=request.model, + model_name=model, tool_choice=tool_choice) # Send first response after receiving the first output if first_iteration: @@ -1525,64 +1561,44 @@ async def handle_streaming_response( delta=first_delta) first_response = ChatCompletionStreamResponse( - model=request.model, + model=model, choices=[choice], ) response_json = first_response.model_dump_json( exclude_none=True) - yield f"data: {response_json}\n\n" + res.append(f"data: {response_json}\n\n") - for response in responses: - yield response + res.extend(responses) - except Exception as e: - logger.error(f"Failed to create OpenAI streaming response: {e}") - logger.debug(f"Streaming error details: {traceback.format_exc()}") - # Clean up state - harmony_adapter.cleanup_stream_state(request_id) - raise e - - # Clean up state - harmony_adapter.cleanup_stream_state(request_id) + return res - # Send final message with finish_reason - output = generator.outputs[0] - final_response = ChatCompletionStreamResponse( - model=request.model, - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(), - finish_reason=output.finish_reason, - stop_reason=output.stop_reason) - ]) + except Exception as e: + logger.error(f"Failed to create OpenAI streaming response: {e}") + logger.debug(f"Streaming error details: {traceback.format_exc()}") + # Clean up state + harmony_adapter.cleanup_stream_state(request_id) + raise e - yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n" - yield "data: [DONE]\n\n" - -async def handle_non_streaming_response( - harmony_adapter: HarmonyAdapter, promise: RequestOutput, - request: ChatCompletionRequest) -> ChatCompletionResponse: +def handle_non_streaming_response(tools: List[ChatCompletionToolsParam], + tool_choice: str, outputs: List, model: str, + num_prompt_tokens: int): """Handle non-streaming response with harmony format.""" - # Get final result - await promise - # Parse harmony output to OpenAI format # Convert tools to dictionary format for harmony adapter (standard pattern) tools_dict = None - if request.tools: - tools_dict = [tool.model_dump() for tool in request.tools] + harmony_adapter = get_harmony_adapter() + if tools: + tools_dict = [tool.model_dump() for tool in tools] # Get tool_choice from request - if "none", don't pass tools to parser - tool_choice = getattr(request, 'tool_choice', None) if tool_choice == "none": tools_for_parser = None else: tools_for_parser = tools_dict - output = promise.outputs[0] + output = outputs[0] parsed_output = harmony_adapter.harmony_output_to_openai( output.token_ids, tools_for_parser, tool_choice) @@ -1597,11 +1613,11 @@ async def 
handle_non_streaming_response( output.finish_reason) # Create usage info from metrics (RequestOutput doesn't have usage in v1) - usage_info = _create_usage_info(promise) + usage_info = _create_usage_info(num_prompt_tokens, outputs) # Create response response = ChatCompletionResponse( - model=request.model, + model=model, choices=[ ChatCompletionResponseChoice( index=0, @@ -1613,7 +1629,6 @@ async def handle_non_streaming_response( # Optional: Log if harmony parsing failed (for debugging) if parsed_output.get('_harmony_parsing_failed'): logger.warning("⚠️ Harmony parsing fell back to raw text decoding") - logger.debug(f"request\n\n{request}") logger.debug(f"response\n\n{response}\n") return response @@ -1646,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any], return reason -def _create_usage_info(final_res: RequestOutput) -> UsageInfo: +def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo: """Create usage info from RequestOutput following serving_chat.py pattern.""" - # Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids - assert final_res.prompt_token_ids is not None - num_prompt_tokens = len(final_res.prompt_token_ids) - # Calculate completion tokens from all outputs - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) + num_generated_tokens = sum(len(output.token_ids) for output in outputs) # Create usage info usage = UsageInfo(prompt_tokens=num_prompt_tokens, diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index dffcd19b196..0c38328e2b6 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -45,9 +45,10 @@ UsageInfo, to_llm_disaggregated_params) from tensorrt_llm.serve.postprocess_handlers import ( - ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor, - chat_stream_post_processor, completion_response_post_processor, - completion_stream_post_processor) + ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs, + chat_harmony_post_processor, chat_harmony_streaming_post_processor, + chat_response_post_processor, chat_stream_post_processor, + completion_response_post_processor, completion_stream_post_processor) from tensorrt_llm.serve.responses_utils import ConversationHistoryStore from tensorrt_llm.serve.responses_utils import \ create_response as responses_api_create_response @@ -58,8 +59,7 @@ from tensorrt_llm.version import __version__ as VERSION from .._utils import nvtx_mark, set_prometheus_multiproc_dir -from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response, - handle_streaming_response, +from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter, maybe_transform_reasoning_effort) # yapf: enale @@ -118,7 +118,11 @@ def __init__(self, # gpt-oss self.harmony_adapter: HarmonyAdapter | None = None - self.use_harmony = self.model_config.model_type == "gpt_oss" + disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1" + if disable_harmony: + self.use_harmony = False + else: + self.use_harmony = (self.model_config.model_type == "gpt_oss") @asynccontextmanager async def lifespan(app: FastAPI): @@ -712,11 +716,35 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques Chat Completion API with harmony format support. Supports both streaming and non-streaming modes. 
""" + + async def create_harmony_response( + promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse: + await promise.aresult() + if self.postproc_worker_enabled: + chat_response =promise.outputs[0]._postprocess_result + else: + post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + chat_response = post_processor(promise, args) + + return chat_response + + async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams): + if not self.postproc_worker_enabled: + post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + + async for res in promise: + pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) + # await self._extract_metrics(res) + for pp_res in pp_results: + yield pp_res + + yield "data: [DONE]\n\n" + try: # Initialize HarmonyAdapter # NOTE: WAR for Disagg failure, may affect perf if no warmup if not self.harmony_adapter: - self.harmony_adapter = HarmonyAdapter() + self.harmony_adapter = get_harmony_adapter() # Convert Pydantic models to dictionaries for JSON serialization (standard pattern) tools_dict = None if request.tools: @@ -751,27 +779,37 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques vocab_size=self.tokenizer.tokenizer.vocab_size) sampling_params.detokenize = False # Harmony adapter handles detokenization + postproc_args = ChatCompletionPostprocArgs.from_request(request) + postproc_params = PostprocParams( + post_processor=chat_harmony_streaming_post_processor + if request.stream else chat_harmony_post_processor, + postproc_args=postproc_args, + ) + # Generate promise = self.llm.generate_async( inputs=harmony_tokens, sampling_params=sampling_params, + _postproc_params=postproc_params if self.postproc_worker_enabled else None, streaming=bool(request.stream), lora_request=request.lora_request, ) + postproc_args.request_id = promise.request_id + + if not self.postproc_worker_enabled: + postproc_args.num_prompt_tokens = len(promise.prompt_token_ids) + # Disconnect cancellation asyncio.create_task(self.await_disconnected(raw_request, promise)) # Handle streaming if request.stream: return StreamingResponse( - handle_streaming_response( - self.harmony_adapter, promise, - str(promise.request_id), request, - ), + content=create_streaming_generator(promise, postproc_params), media_type="text/event-stream" ) else: - response = await handle_non_streaming_response(self.harmony_adapter, promise, request) + response = await create_harmony_response(promise, postproc_params) return JSONResponse(response.model_dump()) except Exception as e: diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index 07db6e27a75..0fbcedb9dac 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -9,6 +9,8 @@ ReasoningParserFactory) from ..llmapi.tokenizer import TransformersTokenizer # yapf: disable +from .harmony_adapter import (handle_non_streaming_response, + handle_streaming_response) from .openai_protocol import (ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, @@ -24,7 +26,8 @@ FunctionCall, StreamOptions, ToolCall, UsageInfo, to_disaggregated_params) -# yapf: enale +# yapf: enable + @dataclass(kw_only=True) class ChatPostprocArgs(PostprocArgs): @@ -57,8 +60,7 @@ def from_request(cls, request: ChatCompletionRequest): ) -def create_logprobs(token_ids: 
List[int], - tokenizer: TransformersTokenizer, +def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer, logprobs: List[float]) -> ChatCompletionLogProbs: assert len(token_ids) == len(logprobs), \ "token_ids and logprobs have different lengths" @@ -75,12 +77,14 @@ def create_logprobs(token_ids: List[int], return chat_logprobs -def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]: +def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, + streaming: bool) -> Tuple[bool, str, str]: reasoning_parser = None if args.reasoning_parser is not None: if output_index not in args.reasoning_parser_dict: - args.reasoning_parser_dict[output_index] = ReasoningParserFactory.create_reasoning_parser( - args.reasoning_parser) + args.reasoning_parser_dict[ + output_index] = ReasoningParserFactory.create_reasoning_parser( + args.reasoning_parser) reasoning_parser = args.reasoning_parser_dict[output_index] in_reasoning = False @@ -97,7 +101,8 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, @nvtx_range_debug("chat_stream_post_processor") -def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]: +def chat_stream_post_processor(rsp: GenerationResultBase, + args: ChatPostprocArgs) -> List[str]: def yield_first_chat(num_tokens: int, idx: int, @@ -128,9 +133,13 @@ def yield_first_chat(num_tokens: int, include_continuous_usage = False if args.first_iteration: for i in range(args.num_choices): - res.append(f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n") + res.append( + f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n" + ) if args.echo and args.last_message_content: - res.append(f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n") + res.append( + f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n" + ) args.first_iteration = False for output in rsp.outputs: @@ -158,14 +167,18 @@ def yield_first_chat(num_tokens: int, delta_message = DeltaMessage( content=delta_text, reasoning_content=reasoning_delta_text) - choice = ChatCompletionResponseStreamChoice(index=i, - delta=delta_message, - finish_reason=None, - avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None)) + choice = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + finish_reason=None, + avg_decoded_tokens_per_iter=getattr(rsp, + 'avg_decoded_tokens_per_iter', + None)) if args.return_logprobs: logprobs = output.logprobs_diff token_ids = output.token_ids_diff - choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs) + choice.logprobs = create_logprobs(token_ids, args.tokenizer, + logprobs) if output.finish_reason is not None: choice.finish_reason = output.finish_reason choice.stop_reason = output.stop_reason @@ -179,57 +192,62 @@ def yield_first_chat(num_tokens: int, res.append(f"data: {data}\n\n") if include_usage and rsp._done: - completion_tokens = sum(output.length - for output in rsp.outputs) + completion_tokens = sum(output.length for output in rsp.outputs) final_usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) - final_usage_chunk = ChatCompletionStreamResponse( - choices=[], model=args.model, usage=final_usage) + final_usage_chunk = ChatCompletionStreamResponse(choices=[], + model=args.model, + usage=final_usage) final_usage_data = 
final_usage_chunk.model_dump_json() res.append(f"data: {final_usage_data}\n\n") return res @nvtx_range_debug("chat_response_post_processor") -def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse: +def chat_response_post_processor( + rsp: GenerationResultBase, + args: ChatPostprocArgs) -> ChatCompletionResponse: choices: List[ChatCompletionResponseChoice] = [] role = args.role for output in rsp.outputs: _, text, reasoning_text = apply_reasoning_parser( args, output.index, output.text, False) - if args.tool_choice and isinstance( - args.tool_choice, - ChatCompletionNamedToolChoiceParam): + if args.tool_choice and isinstance(args.tool_choice, + ChatCompletionNamedToolChoiceParam): message = ChatMessage( role=role, content="", tool_calls=[ ToolCall(function=FunctionCall( - name=args.tool_choice.function.name, - arguments=text)) + name=args.tool_choice.function.name, arguments=text)) ]) else: if text is None: text = "" - message = ChatMessage( - role=role, content=text, reasoning_content=reasoning_text) - disaggregated_params = to_disaggregated_params(output.disaggregated_params) + message = ChatMessage(role=role, + content=text, + reasoning_content=reasoning_text) + disaggregated_params = to_disaggregated_params( + output.disaggregated_params) choice = ChatCompletionResponseChoice( index=output.index, message=message, finish_reason=output.finish_reason, stop_reason=output.stop_reason, disaggregated_params=disaggregated_params, - avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), + avg_decoded_tokens_per_iter=getattr(rsp, + 'avg_decoded_tokens_per_iter', + None), ) if args.return_logprobs: - choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs) + choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, + output.logprobs) choices.append(choice) if args.echo and args.last_message_content: @@ -238,8 +256,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr choice.message.content = full_message num_prompt_tokens = args.num_prompt_tokens - num_generated_tokens = sum( - len(output.token_ids) for output in rsp.outputs) + num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, @@ -275,7 +292,8 @@ def from_request(cls, request: CompletionRequest): @nvtx_range_debug("completion_stream_post_processor") -def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]: +def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, + args: CompletionPostprocArgs) -> List[str]: res: List[str] = [] prompt_tokens = args.num_prompt_tokens if stream_option := args.stream_options: @@ -293,9 +311,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: index=args.prompt_idx * args.num_choices + output.index, text=delta_text if args.detokenize else "", token_ids=None if args.detokenize else output.token_ids_diff, - finish_reason = output.finish_reason, - stop_reason = output.stop_reason, - avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + avg_decoded_tokens_per_iter=getattr(rsp, + 'avg_decoded_tokens_per_iter', + None), ) chunk = CompletionStreamResponse(model=args.model, choices=[choice]) if include_continuous_usage: @@ -306,16 +326,16 @@ def 
completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: res.append(f"data: {data}\n\n") if include_usage and rsp._done: - completion_tokens = sum(output.length - for output in rsp.outputs) + completion_tokens = sum(output.length for output in rsp.outputs) final_usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) - final_usage_chunk = ChatCompletionStreamResponse( - choices=[], model=args.model, usage=final_usage) + final_usage_chunk = ChatCompletionStreamResponse(choices=[], + model=args.model, + usage=final_usage) final_usage_data = final_usage_chunk.model_dump_json() res.append(f"data: {final_usage_data}\n\n") args.first_iteration = False @@ -323,7 +343,9 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: @nvtx_range_debug("completion_response_post_processor") -def completion_response_post_processor(rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse: +def completion_response_post_processor( + rsp: GenerationResult, + args: CompletionPostprocArgs) -> CompletionResponse: prompt_tokens = args.num_prompt_tokens completion_tokens = 0 choices = [] @@ -331,23 +353,75 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo text = output.text if args.echo: text = args.prompt + text - disaggregated_params = to_disaggregated_params(output.disaggregated_params) + disaggregated_params = to_disaggregated_params( + output.disaggregated_params) choice = CompletionResponseChoice( text=text if args.detokenize else "", token_ids=None if args.detokenize else output.token_ids, index=args.prompt_idx * args.num_choices + output.index, disaggregated_params=disaggregated_params, - context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(), + context_logits=None + if rsp.context_logits is None else rsp.context_logits.tolist(), stop_reason=output.stop_reason, finish_reason=output.finish_reason, - avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), + avg_decoded_tokens_per_iter=getattr(rsp, + 'avg_decoded_tokens_per_iter', + None), ) completion_tokens += output.length choices.append(choice) usage = UsageInfo(prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=completion_tokens + prompt_tokens) - response = CompletionResponse(choices=choices, model=args.model, usage=usage) + completion_tokens=completion_tokens, + total_tokens=completion_tokens + prompt_tokens) + response = CompletionResponse(choices=choices, + model=args.model, + usage=usage) + return response + + +@dataclass(kw_only=True) +class ChatCompletionPostprocArgs(PostprocArgs): + model: str + tools: Optional[List[ChatCompletionToolsParam]] + tool_choice: Optional[Union[Literal["none", "auto"], + ChatCompletionNamedToolChoiceParam]] + request_id: Optional[int] = None + + @classmethod + def from_request(cls, request: ChatCompletionRequest): + return cls( + model=request.model, + tools=request.tools, + tool_choice=request.tool_choice, + ) + + +@nvtx_range_debug("chat_harmony_post_processor") +def chat_harmony_post_processor( + rsp: GenerationResult, + args: ChatCompletionPostprocArgs) -> ChatCompletionResponse: + response = handle_non_streaming_response( + tools=args.tools, + tool_choice=args.tool_choice, + outputs=rsp.outputs, + model=args.model, + num_prompt_tokens=args.num_prompt_tokens, + ) + return response + + +@nvtx_range_debug("chat_harmony_streaming_post_processor") +def 
chat_harmony_streaming_post_processor(
+        rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
+    response = handle_streaming_response(
+        tools=args.tools,
+        tool_choice=args.tool_choice,
+        outputs=rsp.outputs,
+        model=args.model,
+        request_id=args.request_id,
+        done=rsp._done,
+        num_prompt_tokens=args.num_prompt_tokens,
+    )
     return response
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
index 0204a04acff..ba6c7d53379 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
@@ -147,6 +147,10 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
     collected_chunks = []
     collected_messages = []
     async for chunk in response:
+        # The last streaming response only contains usage info
+        if len(chunk.choices) <= 0:
+            continue
+
         collected_chunks.append(chunk)
         collected_messages.append(chunk.choices[0].delta)

@@ -198,6 +202,10 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
     reasoning_chunks: list[str] = []
     tool_arg_chunks: list[str] = []
     async for chunk in response:
+        # The last streaming response only contains usage info
+        if len(chunk.choices) <= 0:
+            continue
+
         delta = chunk.choices[0].delta
         if hasattr(delta, "tool_calls") and delta.tool_calls:
             function = delta.tool_calls[0].function

From b38bd4afd15dc7dd5ca1f0f348bde340e104a167 Mon Sep 17 00:00:00 2001
From: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
Date: Mon, 8 Sep 2025 03:30:34 +0000
Subject: [PATCH 3/3] Enable multiple postprocess workers tests for chat
 completions API

Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
---
 .../llmapi/apps/_test_openai_chat_harmony.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
index ba6c7d53379..575cd2f0f13 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
@@ -14,10 +14,18 @@ def model():
     return "gpt_oss/gpt-oss-20b/"


+@pytest.fixture(scope="module",
+                params=[0, 2],
+                ids=["disable_processpool", "enable_processpool"])
+def num_postprocess_workers(request):
+    return request.param
+
+
 @pytest.fixture(scope="module")
-def server(model: str):
+def server(model: str, num_postprocess_workers: int):
     model_path = get_model_path(model)
-    with RemoteOpenAIServer(model_path) as remote_server:
+    args = ["--num_postprocess_workers", f"{num_postprocess_workers}"]
+    with RemoteOpenAIServer(model_path, args) as remote_server:
         yield remote_server
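
For context, the sketch below shows the end-to-end flow that the parametrized fixture above exercises. It is illustrative only and not part of the patch: it assumes a server launched roughly the way the fixture does (for example `trtllm-serve <model_path> --num_postprocess_workers 2`), serving on localhost:8000 with a dummy API key, and the model name and prompt are placeholders. The empty-`choices` guard mirrors the test changes above, since the final SSE chunk carries only usage information.

import asyncio

import openai


async def main() -> None:
    # Assumed endpoint of a locally running trtllm-serve instance.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="dummy")

    # With --num_postprocess_workers > 0, the SSE chunks below are produced by
    # chat_harmony_streaming_post_processor inside the postprocess worker
    # processes rather than in the server's main event loop.
    stream = await client.chat.completions.create(
        model="gpt_oss/gpt-oss-20b/",  # placeholder; mirrors the `model` fixture
        messages=[{
            "role": "user",
            "content": "Say hello in one sentence."
        }],
        stream=True,
    )
    async for chunk in stream:
        # The final chunk carries only usage information, so choices is empty.
        if not chunk.choices:
            continue
        print(chunk.choices[0].delta.content or "", end="")
    print()


if __name__ == "__main__":
    asyncio.run(main())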