
Commit 040a895

JunyiXu-nv and LinPoly committed
[TRTLLM-7779][feat] Support multiple postprocess workers for chat completions API (#7508)
Signed-off-by: Junyi Xu <[email protected]> Co-authored-by: Pengyun Lin <[email protected]>
1 parent 56d3547 commit 040a895

File tree

4 files changed: +261 -131 lines changed
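The change moves chat-completions post-processing out of the server's request handler: instead of the server awaiting the generator and formatting OpenAI responses itself, a post-processor function plus its arguments are attached to each request, so postprocess worker processes can run the formatting and hand back ready-made results via `outputs[0]._postprocess_result`. Below is a minimal sketch of that contract; the classes here are simplified stand-ins, not the actual TensorRT-LLM types, and only the shape of the pattern mirrors the diff.

# Minimal sketch of the postprocess contract, using hypothetical stand-in
# classes (NOT the real TensorRT-LLM PostprocParams/PostprocArgs types).
from dataclasses import dataclass
from typing import Any, Callable, List, Optional


@dataclass
class PostprocArgs:
    model: str
    request_id: Optional[str] = None
    num_prompt_tokens: Optional[int] = None


@dataclass
class PostprocParams:
    post_processor: Callable[[Any, PostprocArgs], List[str]]
    postproc_args: PostprocArgs


def chat_stream_post_processor(result: Any, args: PostprocArgs) -> List[str]:
    # The real handlers turn raw token ids into OpenAI-format SSE chunks;
    # this placeholder only shows the call shape.
    return [f'data: {{"model": "{args.model}"}}\n\n']


# Either a postprocess worker or the server itself can run the processor:
params = PostprocParams(post_processor=chat_stream_post_processor,
                        postproc_args=PostprocArgs(model="gpt-oss"))
print(params.post_processor(None, params.postproc_args))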

tensorrt_llm/serve/harmony_adapter.py

Lines changed: 83 additions & 73 deletions
@@ -6,23 +6,23 @@
 import time
 import traceback
 import uuid
-from typing import Any, AsyncGenerator, Literal
+from typing import Any, List, Literal
 
 from openai_harmony import (Author, Conversation, DeveloperContent,
                             HarmonyEncodingName, HarmonyError, Message,
                             ReasoningEffort, Role, StreamableParser,
                             SystemContent, TextContent, ToolDescription,
                             load_harmony_encoding)
 
-from tensorrt_llm.llmapi import RequestOutput
 from tensorrt_llm.logger import logger
 
 # yapf: disable
-from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
+from .openai_protocol import (ChatCompletionMessageParam,
                               ChatCompletionResponse,
                               ChatCompletionResponseChoice,
                               ChatCompletionResponseStreamChoice,
-                              ChatCompletionStreamResponse, ChatMessage,
+                              ChatCompletionStreamResponse,
+                              ChatCompletionToolsParam, ChatMessage,
                               DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                               UsageInfo)
 
@@ -1485,36 +1485,72 @@ def _is_tool_call_allowed(self, tool_call: dict[str, Any],
         return True
 
 
-async def handle_streaming_response(
-    harmony_adapter: HarmonyAdapter,
-    generator: RequestOutput,
-    request_id: str,
-    request: ChatCompletionRequest,
-) -> AsyncGenerator[str, None]:
-    """Handle streaming response with harmony format."""
+_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
+
+
+def get_harmony_adapter():
+    global _SERVE_HARMONY_ADAPTER
+    if _SERVE_HARMONY_ADAPTER is None:
+        _SERVE_HARMONY_ADAPTER = HarmonyAdapter()
+
+    return _SERVE_HARMONY_ADAPTER
+
+
+def handle_streaming_response(tools: List[ChatCompletionToolsParam],
+                              tool_choice: str, outputs: List, model: str,
+                              request_id: str, done: bool,
+                              num_prompt_tokens: int):
     first_iteration = True
-    async for res in generator:
-        output = res.outputs[0]
+    output = outputs[0]
 
-        # Convert tools to dictionary format for harmony adapter (standard pattern)
-        tools_dict = None
-        if request.tools:
-            tools_dict = [tool.model_dump() for tool in request.tools]
+    # Convert tools to dictionary format for harmony adapter (standard pattern)
+    tools_dict = None
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]
 
-        # Get tool_choice from request - if "none", don't pass tools to parser
-        tool_choice = getattr(request, 'tool_choice', None)
-        if tool_choice == "none":
-            tools_for_parser = None
-        else:
-            tools_for_parser = tools_dict
+    # Get tool_choice from request - if "none", don't pass tools to parser
+    if tool_choice == "none":
+        tools_for_parser = None
+    else:
+        tools_for_parser = tools_dict
 
-        # Create OpenAI streaming responses
-        try:
+    # Create OpenAI streaming responses
+    try:
+        res = []
+        if done:
+            # Clean up state
+            harmony_adapter.cleanup_stream_state(request_id)
+
+            usage_info = _create_usage_info(num_prompt_tokens, outputs)
+
+            # Send final message with finish_reason
+            final_response = ChatCompletionStreamResponse(
+                model=model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(),
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
+                ],
+            )
+
+            final_response_json = final_response.model_dump_json(
+                exclude_none=True)
+            final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                             model=model,
+                                                             usage=usage_info)
+            final_usage_json = final_usage_chunk.model_dump_json(
+                exclude_none=True)
+            res.append(f"data: {final_response_json}\n\n")
+            res.append(f"data: {final_usage_json}\n\n")
+        else:
             responses = harmony_adapter.create_openai_streaming_response(
                 request_id=request_id,
                 tokens=output.token_ids_diff,
                 available_tools=tools_for_parser,
-                model_name=request.model,
+                model_name=model,
                 tool_choice=tool_choice)
             # Send first response after receiving the first output
             if first_iteration:
@@ -1525,64 +1561,44 @@ async def handle_streaming_response(
                     delta=first_delta)
 
                 first_response = ChatCompletionStreamResponse(
-                    model=request.model,
+                    model=model,
                     choices=[choice],
                 )
 
                 response_json = first_response.model_dump_json(
                     exclude_none=True)
-                yield f"data: {response_json}\n\n"
+                res.append(f"data: {response_json}\n\n")
 
-            for response in responses:
-                yield response
+            res.extend(responses)
 
-        except Exception as e:
-            logger.error(f"Failed to create OpenAI streaming response: {e}")
-            logger.debug(f"Streaming error details: {traceback.format_exc()}")
-            # Clean up state
-            harmony_adapter.cleanup_stream_state(request_id)
-            raise e
-
-    # Clean up state
-    harmony_adapter.cleanup_stream_state(request_id)
+        return res
 
-    # Send final message with finish_reason
-    output = generator.outputs[0]
-    final_response = ChatCompletionStreamResponse(
-        model=request.model,
-        choices=[
-            ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(),
-                finish_reason=output.finish_reason,
-                stop_reason=output.stop_reason)
-        ])
+    except Exception as e:
+        logger.error(f"Failed to create OpenAI streaming response: {e}")
+        logger.debug(f"Streaming error details: {traceback.format_exc()}")
+        # Clean up state
+        harmony_adapter.cleanup_stream_state(request_id)
+        raise e
 
-    yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
-    yield "data: [DONE]\n\n"
 
-
-async def handle_non_streaming_response(
-        harmony_adapter: HarmonyAdapter, promise: RequestOutput,
-        request: ChatCompletionRequest) -> ChatCompletionResponse:
+def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
+                                  tool_choice: str, outputs: List, model: str,
+                                  num_prompt_tokens: int):
     """Handle non-streaming response with harmony format."""
-    # Get final result
-    await promise
-
     # Parse harmony output to OpenAI format
     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
-    if request.tools:
-        tools_dict = [tool.model_dump() for tool in request.tools]
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]
 
     # Get tool_choice from request - if "none", don't pass tools to parser
-    tool_choice = getattr(request, 'tool_choice', None)
    if tool_choice == "none":
        tools_for_parser = None
    else:
        tools_for_parser = tools_dict
 
-    output = promise.outputs[0]
+    output = outputs[0]
     parsed_output = harmony_adapter.harmony_output_to_openai(
         output.token_ids, tools_for_parser, tool_choice)
 
@@ -1597,11 +1613,11 @@ async def handle_non_streaming_response(
                                              output.finish_reason)
 
     # Create usage info from metrics (RequestOutput doesn't have usage in v1)
-    usage_info = _create_usage_info(promise)
+    usage_info = _create_usage_info(num_prompt_tokens, outputs)
 
     # Create response
     response = ChatCompletionResponse(
-        model=request.model,
+        model=model,
         choices=[
             ChatCompletionResponseChoice(
                 index=0,
@@ -1613,7 +1629,6 @@ async def handle_non_streaming_response(
     # Optional: Log if harmony parsing failed (for debugging)
     if parsed_output.get('_harmony_parsing_failed'):
         logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
-        logger.debug(f"request\n\n{request}")
     logger.debug(f"response\n\n{response}\n")
 
     return response
@@ -1646,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
     return reason
 
 
-def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
+def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
     """Create usage info from RequestOutput following serving_chat.py pattern."""
-    # Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
-    assert final_res.prompt_token_ids is not None
-    num_prompt_tokens = len(final_res.prompt_token_ids)
-
     # Calculate completion tokens from all outputs
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in final_res.outputs)
+    num_generated_tokens = sum(len(output.token_ids) for output in outputs)
 
     # Create usage info
     usage = UsageInfo(prompt_tokens=num_prompt_tokens,
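With this refactor, `handle_streaming_response` and `handle_non_streaming_response` no longer own the request loop: they are plain synchronous functions over a batch of outputs, which is what lets post-processors running in separate workers call them. A hedged sketch of the final streaming call (`done=True`) is shown below; it assumes tensorrt_llm and openai_harmony are installed, uses a SimpleNamespace stand-in for a real generation output, and assumes cleaning up an unknown stream state is a no-op.

# Hedged sketch, not part of the commit: drive the refactored streaming
# handler once with done=True to emit the finish_reason and usage chunks.
# SimpleNamespace stands in for a real generation output object.
from types import SimpleNamespace

from tensorrt_llm.serve.harmony_adapter import handle_streaming_response

final_output = SimpleNamespace(token_ids=[1, 2, 3],  # 3 generated tokens
                               finish_reason="stop",
                               stop_reason=None)

sse_chunks = handle_streaming_response(tools=None, tool_choice=None,
                                       outputs=[final_output], model="gpt-oss",
                                       request_id="req-0", done=True,
                                       num_prompt_tokens=8)
for chunk in sse_chunks:
    print(chunk, end="")  # final delta chunk, then the usage chunk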

tensorrt_llm/serve/openai_server.py

Lines changed: 50 additions & 12 deletions
@@ -45,9 +45,10 @@
                                                 UsageInfo,
                                                 to_llm_disaggregated_params)
 from tensorrt_llm.serve.postprocess_handlers import (
-    ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
-    chat_stream_post_processor, completion_response_post_processor,
-    completion_stream_post_processor)
+    ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
+    chat_harmony_post_processor, chat_harmony_streaming_post_processor,
+    chat_response_post_processor, chat_stream_post_processor,
+    completion_response_post_processor, completion_stream_post_processor)
 from tensorrt_llm.serve.responses_utils import ConversationHistoryStore
 from tensorrt_llm.serve.responses_utils import \
     create_response as responses_api_create_response
@@ -58,8 +59,7 @@
 from tensorrt_llm.version import __version__ as VERSION
 
 from .._utils import nvtx_mark, set_prometheus_multiproc_dir
-from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
-                              handle_streaming_response,
+from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter,
                               maybe_transform_reasoning_effort)
 
 # yapf: enable
@@ -118,7 +118,11 @@ def __init__(self,
 
         # gpt-oss
         self.harmony_adapter: HarmonyAdapter | None = None
-        self.use_harmony = self.model_config.model_type == "gpt_oss"
+        disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
+        if disable_harmony:
+            self.use_harmony = False
+        else:
+            self.use_harmony = (self.model_config.model_type == "gpt_oss")
 
     @asynccontextmanager
     async def lifespan(app: FastAPI):
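The new DISABLE_HARMONY_ADAPTER environment variable is an escape hatch: when it is set to "1" the server uses the regular chat completions path even for gpt_oss models. A minimal sketch of the toggle (hedged; the flag is read once in `__init__` per the hunk above, so it must be set before the server object is constructed):

import os

# Any value other than "1" (or leaving the variable unset) keeps the
# harmony path enabled for gpt_oss models.
os.environ["DISABLE_HARMONY_ADAPTER"] = "1"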
@@ -712,11 +716,35 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques
         Chat Completion API with harmony format support.
         Supports both streaming and non-streaming modes.
         """
+
+        async def create_harmony_response(
+                promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
+            await promise.aresult()
+            if self.postproc_worker_enabled:
+                chat_response = promise.outputs[0]._postprocess_result
+            else:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                chat_response = post_processor(promise, args)
+
+            return chat_response
+
+        async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
+            if not self.postproc_worker_enabled:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+
+            async for res in promise:
+                pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
+                # await self._extract_metrics(res)
+                for pp_res in pp_results:
+                    yield pp_res
+
+            yield "data: [DONE]\n\n"
+
         try:
             # Initialize HarmonyAdapter
             # NOTE: WAR for Disagg failure, may affect perf if no warmup
             if not self.harmony_adapter:
-                self.harmony_adapter = HarmonyAdapter()
+                self.harmony_adapter = get_harmony_adapter()
             # Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
             tools_dict = None
             if request.tools:
@@ -751,27 +779,37 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques
                 vocab_size=self.tokenizer.tokenizer.vocab_size)
             sampling_params.detokenize = False  # Harmony adapter handles detokenization
 
+            postproc_args = ChatCompletionPostprocArgs.from_request(request)
+            postproc_params = PostprocParams(
+                post_processor=chat_harmony_streaming_post_processor
+                if request.stream else chat_harmony_post_processor,
+                postproc_args=postproc_args,
+            )
+
             # Generate
             promise = self.llm.generate_async(
                 inputs=harmony_tokens,
                 sampling_params=sampling_params,
+                _postproc_params=postproc_params if self.postproc_worker_enabled else None,
                 streaming=bool(request.stream),
                 lora_request=request.lora_request,
             )
+            postproc_args.request_id = promise.request_id
+
+            if not self.postproc_worker_enabled:
+                postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)
+
             # Disconnect cancellation
             asyncio.create_task(self.await_disconnected(raw_request, promise))
 
             # Handle streaming
             if request.stream:
                 return StreamingResponse(
-                    handle_streaming_response(
-                        self.harmony_adapter, promise,
-                        str(promise.request_id), request,
-                    ),
+                    content=create_streaming_generator(promise, postproc_params),
                     media_type="text/event-stream"
                 )
             else:
-                response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
+                response = await create_harmony_response(promise, postproc_params)
                 return JSONResponse(response.model_dump())
 
         except Exception as e:
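Because post-processing now yields ready-to-send SSE strings (including a trailing usage chunk whose choices list is empty), a standard OpenAI client can consume the streaming endpoint unchanged. A hedged client-side check, assuming a server is already running at http://localhost:8000 and using a placeholder model name:

# Hedged usage sketch (not part of the commit): stream a chat completion
# from a running trtllm-serve instance via the OpenAI Python client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
stream = client.chat.completions.create(
    model="gpt-oss",  # placeholder; use the served model name
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in stream:
    # The final usage chunk carries an empty choices list, hence the guard.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()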
