Commit ac0df0a

JunyiXu-nv, LinPoly, and litaotju authored

[None][feat] Cherry-pick Responses API and multiple postprocess workers support for chat harmony (#7600)

Signed-off-by: Junyi Xu <[email protected]>
Co-authored-by: Pengyun Lin <[email protected]>
Co-authored-by: Tao Li @ NVIDIA <[email protected]>
1 parent d60dad6 commit ac0df0a
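
Context for the diff below: the harmony chat handlers used to be async generators tied to a live RequestOutput, which pinned them to the main server process. This commit rewrites them as plain synchronous functions over already-collected outputs, so they can run in multiple postprocess workers. A minimal sketch of how a caller might drive the new streaming entry point, assuming only the handler signature from this diff; the FakeOutput stand-in and the stream_request driver are illustrative assumptions, not part of the commit:

# Sketch: driving the refactored synchronous handler from a worker loop.
# Only handle_streaming_response comes from this commit; the rest is assumed.
from dataclasses import dataclass, field
from typing import Optional

from tensorrt_llm.serve.harmony_adapter import handle_streaming_response


@dataclass
class FakeOutput:
    # Mirrors only the attributes the handler reads from outputs[0]
    token_ids: list = field(default_factory=list)        # all generated tokens
    token_ids_diff: list = field(default_factory=list)   # new tokens this step
    finish_reason: Optional[str] = None
    stop_reason: Optional[str] = None


def stream_request(request_id: str, model: str, steps: list[list[int]],
                   num_prompt_tokens: int) -> list[str]:
    """Collect the SSE chunks a postprocess worker would emit for one request."""
    chunks: list[str] = []
    output = FakeOutput()
    for new_tokens in steps:
        output.token_ids.extend(new_tokens)
        output.token_ids_diff = new_tokens
        # done=False: incremental tokens are parsed into delta chunks
        chunks.extend(
            handle_streaming_response(tools=None, tool_choice=None,
                                      outputs=[output], model=model,
                                      request_id=request_id, done=False,
                                      num_prompt_tokens=num_prompt_tokens))
    output.finish_reason = "stop"
    # done=True: emits the finish_reason chunk plus a usage chunk and
    # cleans up the per-request parser state
    chunks.extend(
        handle_streaming_response(tools=None, tool_choice=None,
                                  outputs=[output], model=model,
                                  request_id=request_id, done=True,
                                  num_prompt_tokens=num_prompt_tokens))
    return chunks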

File tree

9 files changed: +1764, -136 lines

tensorrt_llm/serve/harmony_adapter.py

Lines changed: 163 additions & 74 deletions
@@ -6,23 +6,23 @@
 import time
 import traceback
 import uuid
-from typing import Any, AsyncGenerator, Literal
+from typing import Any, List, Literal
 
 from openai_harmony import (Author, Conversation, DeveloperContent,
                             HarmonyEncodingName, HarmonyError, Message,
                             ReasoningEffort, Role, StreamableParser,
                             SystemContent, TextContent, ToolDescription,
                             load_harmony_encoding)
 
-from tensorrt_llm.llmapi import RequestOutput
 from tensorrt_llm.logger import logger
 
 # yapf: disable
-from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
+from .openai_protocol import (ChatCompletionMessageParam,
                               ChatCompletionResponse,
                               ChatCompletionResponseChoice,
                               ChatCompletionResponseStreamChoice,
-                              ChatCompletionStreamResponse, ChatMessage,
+                              ChatCompletionStreamResponse,
+                              ChatCompletionToolsParam, ChatMessage,
                               DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                               UsageInfo)
 
@@ -57,7 +57,8 @@ def __init__(self,
             # Normal case: filter based on available tools
             self.should_filter_tools = True
             self.available_tools = {
-                tool.get("function", {}).get("name", "")
+                tool.get("function", {}).get("name", "") if tool.get(
+                    "name", None) is None else tool.get("name")
                 for tool in available_tools
             }
             self.available_tools.discard("")
@@ -78,6 +79,9 @@ def __init__(self,
 
         logger.debug("Created HarmonyStreamState for request %s", request_id)
 
+    def get_parser(self) -> StreamableParser:
+        return self.parser
+
     def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
         """
         Process a batch of tokens while maintaining parsing state.
@@ -125,6 +129,42 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
 
         return deltas
 
+    def process_token_batch_to_messages(self,
+                                        tokens: list[int]) -> list[Message]:
+        """
+        Process a batch of tokens while maintaining parsing state.
+        Returns OpenAI Messages for Responses API
+        """
+        self.tokens_processed += len(tokens)
+
+        for token in tokens:
+            # Store previous state for transition detection
+            prev_channel = self.parser.current_channel
+            prev_recipient = self.parser.current_recipient
+
+            # Process the token
+            self.parser.process(token)
+
+            # Detect channel/recipient transitions AFTER processing each token
+            channel_changed = prev_channel != self.parser.current_channel
+            recipient_changed = prev_recipient != self.parser.current_recipient
+
+            if channel_changed or recipient_changed:
+                # Mark any active tool calls as completed if we're leaving a tool call
+                if prev_channel == "commentary" and prev_recipient and "functions." in str(
+                        prev_recipient):
+                    func_name = str(prev_recipient).split("functions.")[-1]
+                    for tool_id, tool_info in self.tool_calls.items():
+                        if tool_info["name"] == func_name and tool_info.get(
+                                "active", True):
+                            tool_info["active"] = False
+
+                # Reset channel state for new channel
+                self.channel_started = False
+                self.current_channel_state = None
+
+        return self.parser.messages
+
     def _create_closing_token_delta(self) -> dict[str, Any] | None:
         """Create closing token delta for channel transition."""
         if not self.current_channel_state or not self.channel_started:
@@ -317,6 +357,9 @@ def __init__(
             "<|constrain|>": 200009,
         }
 
+    def get_stream_state(self, request_id: str) -> HarmonyStreamState | None:
+        return self._stream_states.get(request_id, None)
+
     def get_stop_tokens(self) -> list[int]:
         """
         Return the list of stop token IDs for Harmony format.
@@ -1214,6 +1257,42 @@ def stateful_stream_harmony_tokens_to_openai_deltas(
             # Return empty deltas to continue processing
             return []
 
+    def stateful_stream_harmony_tokens_to_openai_messages(
+            self,
+            request_id: str,
+            tokens: list[int],
+            available_tools: list[dict[str, Any]] | None = None,
+            tool_choice: str | None = None) -> list[Message]:
+        """
+        Process tokens using stateful parsing.
+
+        This method maintains persistent state across multiple calls for the same request,
+        ensuring proper channel transitions and tool call handling.
+
+        Args:
+            request_id: Request ID to maintain state per request
+            tokens: New tokens from this iteration
+            available_tools: Available tools for filtering
+
+        Returns:
+            List of OpenAI Messages
+        """
+        stream_state = self._stream_states.get(request_id, None)
+        if stream_state is None:
+            stream_state = self.create_stream_state(request_id, available_tools,
+                                                    tool_choice)
+
+        try:
+            messages = stream_state.process_token_batch_to_messages(tokens)
+            return messages
+        except (HarmonyError, UnicodeDecodeError, ValueError):
+            logger.error(
+                f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}",
+            )
+            logger.debug(f"Problematic streaming tokens: {tokens}")
+
+            return []
+
     def create_openai_streaming_response(
         self,
         request_id: str,
@@ -1406,36 +1485,72 @@ def _is_tool_call_allowed(self, tool_call: dict[str, Any],
         return True
 
 
-async def handle_streaming_response(
-    harmony_adapter: HarmonyAdapter,
-    generator: RequestOutput,
-    request_id: str,
-    request: ChatCompletionRequest,
-) -> AsyncGenerator[str, None]:
-    """Handle streaming response with harmony format."""
+_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
+
+
+def get_harmony_adapter():
+    global _SERVE_HARMONY_ADAPTER
+    if _SERVE_HARMONY_ADAPTER is None:
+        _SERVE_HARMONY_ADAPTER = HarmonyAdapter()
+
+    return _SERVE_HARMONY_ADAPTER
+
+
+def handle_streaming_response(tools: List[ChatCompletionToolsParam],
+                              tool_choice: str, outputs: List, model: str,
+                              request_id: str, done: bool,
+                              num_prompt_tokens: int):
     first_iteration = True
-    async for res in generator:
-        output = res.outputs[0]
+    output = outputs[0]
+
+    # Convert tools to dictionary format for harmony adapter (standard pattern)
+    tools_dict = None
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]
 
-        # Convert tools to dictionary format for harmony adapter (standard pattern)
-        tools_dict = None
-        if request.tools:
-            tools_dict = [tool.model_dump() for tool in request.tools]
+    # Get tool_choice from request - if "none", don't pass tools to parser
+    if tool_choice == "none":
+        tools_for_parser = None
+    else:
+        tools_for_parser = tools_dict
 
-        # Get tool_choice from request - if "none", don't pass tools to parser
-        tool_choice = getattr(request, 'tool_choice', None)
-        if tool_choice == "none":
-            tools_for_parser = None
-        else:
-            tools_for_parser = tools_dict
+    # Create OpenAI streaming responses
+    try:
+        res = []
+        if done:
+            # Clean up state
+            harmony_adapter.cleanup_stream_state(request_id)
 
-        # Create OpenAI streaming responses
-        try:
+            usage_info = _create_usage_info(num_prompt_tokens, outputs)
+
+            # Send final message with finish_reason
+            final_response = ChatCompletionStreamResponse(
+                model=model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(),
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
+                ],
+            )
+
+            final_response_json = final_response.model_dump_json(
+                exclude_none=True)
+            final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                             model=model,
+                                                             usage=usage_info)
+            final_usage_json = final_usage_chunk.model_dump_json(
+                exclude_none=True)
+            res.append(f"data: {final_response_json}\n\n")
+            res.append(f"data: {final_usage_json}\n\n")
+        else:
             responses = harmony_adapter.create_openai_streaming_response(
                 request_id=request_id,
                 tokens=output.token_ids_diff,
                 available_tools=tools_for_parser,
-                model_name=request.model,
+                model_name=model,
                 tool_choice=tool_choice)
             # Send first response after receiving the first output
             if first_iteration:
@@ -1446,64 +1561,44 @@ async def handle_streaming_response(
                     delta=first_delta)
 
                 first_response = ChatCompletionStreamResponse(
-                    model=request.model,
+                    model=model,
                     choices=[choice],
                 )
 
                 response_json = first_response.model_dump_json(
                     exclude_none=True)
-                yield f"data: {response_json}\n\n"
-
-            for response in responses:
-                yield response
-
-        except Exception as e:
-            logger.error(f"Failed to create OpenAI streaming response: {e}")
-            logger.debug(f"Streaming error details: {traceback.format_exc()}")
-            # Clean up state
-            harmony_adapter.cleanup_stream_state(request_id)
-            raise e
+                res.append(f"data: {response_json}\n\n")
 
-    # Clean up state
-    harmony_adapter.cleanup_stream_state(request_id)
+            res.extend(responses)
 
-    # Send final message with finish_reason
-    output = generator.outputs[0]
-    final_response = ChatCompletionStreamResponse(
-        model=request.model,
-        choices=[
-            ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(),
-                finish_reason=output.finish_reason,
-                stop_reason=output.stop_reason)
-        ])
+        return res
 
-    yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
-    yield "data: [DONE]\n\n"
+    except Exception as e:
+        logger.error(f"Failed to create OpenAI streaming response: {e}")
+        logger.debug(f"Streaming error details: {traceback.format_exc()}")
+        # Clean up state
+        harmony_adapter.cleanup_stream_state(request_id)
+        raise e
 
 
-async def handle_non_streaming_response(
-        harmony_adapter: HarmonyAdapter, promise: RequestOutput,
-        request: ChatCompletionRequest) -> ChatCompletionResponse:
+def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
+                                  tool_choice: str, outputs: List, model: str,
+                                  num_prompt_tokens: int):
     """Handle non-streaming response with harmony format."""
-    # Get final result
-    await promise
-
     # Parse harmony output to OpenAI format
     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
-    if request.tools:
-        tools_dict = [tool.model_dump() for tool in request.tools]
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]
 
     # Get tool_choice from request - if "none", don't pass tools to parser
-    tool_choice = getattr(request, 'tool_choice', None)
     if tool_choice == "none":
         tools_for_parser = None
    else:
         tools_for_parser = tools_dict
 
-    output = promise.outputs[0]
+    output = outputs[0]
     parsed_output = harmony_adapter.harmony_output_to_openai(
         output.token_ids, tools_for_parser, tool_choice)
 
@@ -1518,11 +1613,11 @@ async def handle_non_streaming_response(
                                              output.finish_reason)
 
     # Create usage info from metrics (RequestOutput doesn't have usage in v1)
-    usage_info = _create_usage_info(promise)
+    usage_info = _create_usage_info(num_prompt_tokens, outputs)
 
     # Create response
     response = ChatCompletionResponse(
-        model=request.model,
+        model=model,
         choices=[
             ChatCompletionResponseChoice(
                 index=0,
@@ -1534,7 +1629,6 @@
     # Optional: Log if harmony parsing failed (for debugging)
     if parsed_output.get('_harmony_parsing_failed'):
         logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
-        logger.debug(f"request\n\n{request}")
         logger.debug(f"response\n\n{response}\n")
 
     return response
@@ -1567,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
     return reason
 
 
-def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
+def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
     """Create usage info from RequestOutput following serving_chat.py pattern."""
-    # Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
-    assert final_res.prompt_token_ids is not None
-    num_prompt_tokens = len(final_res.prompt_token_ids)
-
     # Calculate completion tokens from all outputs
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in final_res.outputs)
+    num_generated_tokens = sum(len(output.token_ids) for output in outputs)
 
     # Create usage info
     usage = UsageInfo(prompt_tokens=num_prompt_tokens,
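
A note on the module-level _SERVE_HARMONY_ADAPTER introduced above: it is a lazy, process-local singleton, so each postprocess worker builds its own HarmonyAdapter (and the underlying harmony encoding) once, on first use, instead of receiving one through function arguments. A small illustrative check of that behavior, not part of the commit:

from tensorrt_llm.serve.harmony_adapter import get_harmony_adapter

# Repeated calls within one process return the same adapter instance;
# separate worker processes each build their own on first call.
adapter_a = get_harmony_adapter()
adapter_b = get_harmony_adapter()
assert adapter_a is adapter_b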
