 import time
 import traceback
 import uuid
-from typing import Any, AsyncGenerator, Literal
+from typing import Any, List, Literal

 from openai_harmony import (Author, Conversation, DeveloperContent,
                             HarmonyEncodingName, HarmonyError, Message,
                             ReasoningEffort, Role, StreamableParser,
                             SystemContent, TextContent, ToolDescription,
                             load_harmony_encoding)

-from tensorrt_llm.llmapi import RequestOutput
 from tensorrt_llm.logger import logger

 # yapf: disable
-from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
+from .openai_protocol import (ChatCompletionMessageParam,
                               ChatCompletionResponse,
                               ChatCompletionResponseChoice,
                               ChatCompletionResponseStreamChoice,
-                              ChatCompletionStreamResponse, ChatMessage,
+                              ChatCompletionStreamResponse,
+                              ChatCompletionToolsParam, ChatMessage,
                               DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                               UsageInfo)
@@ -57,7 +57,8 @@ def __init__(self,
             # Normal case: filter based on available tools
             self.should_filter_tools = True
             self.available_tools = {
-                tool.get("function", {}).get("name", "")
+                tool.get("function", {}).get("name", "") if tool.get(
+                    "name", None) is None else tool.get("name")
                 for tool in available_tools
             }
             self.available_tools.discard("")
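
This widens tool-name extraction to cover both payload shapes: the Chat Completions layout, where the name sits nested under `function`, and the flat Responses-API layout with a top-level `name`. A minimal sketch of what the updated expression does (the example tool dicts are illustrative, not taken from this PR):

```python
def extract_tool_name(tool: dict) -> str:
    # Mirrors the updated set-comprehension expression: prefer a top-level
    # "name", fall back to the nested Chat Completions layout.
    return (tool.get("function", {}).get("name", "")
            if tool.get("name", None) is None else tool.get("name"))

chat_tool = {"type": "function", "function": {"name": "get_weather"}}  # nested
responses_tool = {"type": "function", "name": "get_weather"}           # flat

assert extract_tool_name(chat_tool) == "get_weather"
assert extract_tool_name(responses_tool) == "get_weather"
```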
@@ -78,6 +79,9 @@ def __init__(self,

         logger.debug("Created HarmonyStreamState for request %s", request_id)

+    def get_parser(self) -> StreamableParser:
+        return self.parser
+
     def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
         """
         Process a batch of tokens while maintaining parsing state.
@@ -125,6 +129,42 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:

         return deltas

+    def process_token_batch_to_messages(self,
+                                        tokens: list[int]) -> list[Message]:
+        """
+        Process a batch of tokens while maintaining parsing state.
+        Returns OpenAI Messages for the Responses API.
+        """
+        self.tokens_processed += len(tokens)
+
+        for token in tokens:
+            # Store previous state for transition detection
+            prev_channel = self.parser.current_channel
+            prev_recipient = self.parser.current_recipient
+
+            # Process the token
+            self.parser.process(token)
+
+            # Detect channel/recipient transitions AFTER processing each token
+            channel_changed = prev_channel != self.parser.current_channel
+            recipient_changed = prev_recipient != self.parser.current_recipient
+
+            if channel_changed or recipient_changed:
+                # Mark any active tool calls as completed if we're leaving a tool call
+                if prev_channel == "commentary" and prev_recipient and "functions." in str(
+                        prev_recipient):
+                    func_name = str(prev_recipient).split("functions.")[-1]
+                    for tool_id, tool_info in self.tool_calls.items():
+                        if tool_info["name"] == func_name and tool_info.get(
+                                "active", True):
+                            tool_info["active"] = False
+
+                # Reset channel state for the new channel
+                self.channel_started = False
+                self.current_channel_state = None
+
+        return self.parser.messages
+
     def _create_closing_token_delta(self) -> dict[str, Any] | None:
         """Create closing token delta for channel transition."""
         if not self.current_channel_state or not self.channel_started:
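
Note that `process_token_batch_to_messages` returns the parser's cumulative `parser.messages`, not just the messages completed by this batch. A hedged sketch of the call pattern that implies (`state` stands in for a `HarmonyStreamState`; the token batches would come from the generation loop):

```python
def drain_to_messages(state, token_batches):
    """Feed incremental token batches, keep the latest message snapshot.

    Hedged sketch: `state` is assumed to be a HarmonyStreamState and each
    batch a list[int] of Harmony tokens from one generation step.
    """
    messages = []
    for batch in token_batches:
        # Each call advances the persistent parser; the return value is the
        # cumulative parser.messages, so overwrite rather than extend.
        messages = state.process_token_batch_to_messages(batch)
    return messages
```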
@@ -317,6 +357,9 @@ def __init__(
             "<|constrain|>": 200009,
         }

+    def get_stream_state(self, request_id: str) -> HarmonyStreamState | None:
+        return self._stream_states.get(request_id, None)
+
     def get_stop_tokens(self) -> list[int]:
         """
         Return the list of stop token IDs for Harmony format.
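
Together with `get_parser()` above, this accessor lets callers inspect in-flight parse state without reaching into private attributes. A small hypothetical helper, not part of this PR (`adapter` is assumed to be a `HarmonyAdapter`):

```python
def peek_current_channel(adapter, request_id: str) -> str | None:
    # Returns the Harmony channel the request's parser is currently in,
    # or None when no stream state exists for that request id.
    state = adapter.get_stream_state(request_id)
    if state is None:
        return None
    return state.get_parser().current_channel
```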
@@ -1214,6 +1257,42 @@ def stateful_stream_harmony_tokens_to_openai_deltas(
             # Return empty deltas to continue processing
             return []

+    def stateful_stream_harmony_tokens_to_openai_messages(
+            self,
+            request_id: str,
+            tokens: list[int],
+            available_tools: list[dict[str, Any]] | None = None,
+            tool_choice: str | None = None) -> list[Message]:
+        """
+        Process tokens using stateful parsing.
+
+        This method maintains persistent state across multiple calls for the
+        same request, ensuring proper channel transitions and tool call handling.
+
+        Args:
+            request_id: Request ID to maintain state per request
+            tokens: New tokens from this iteration
+            available_tools: Available tools for filtering
+            tool_choice: Tool choice from the request
+
+        Returns:
+            List of OpenAI Messages
+        """
+        stream_state = self._stream_states.get(request_id, None)
+        if stream_state is None:
+            stream_state = self.create_stream_state(request_id, available_tools,
+                                                    tool_choice)
+
+        try:
+            messages = stream_state.process_token_batch_to_messages(tokens)
+            return messages
+        except (HarmonyError, UnicodeDecodeError, ValueError):
+            logger.error(
+                f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}",
+            )
+            logger.debug(f"Problematic streaming tokens: {tokens}")
+
+            return []
+
     def create_openai_streaming_response(
             self,
             request_id: str,
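
A hedged sketch of the per-step call pattern this method supports; `adapter` is assumed to be the shared `HarmonyAdapter`. Note the design choice visible above: parse failures are logged and an empty list is returned, so one bad batch does not kill the stream:

```python
def stream_step(adapter, request_id: str, new_tokens: list[int]):
    # One call per generation step; state is keyed by request_id, so the
    # parser survives across steps. On parse errors the adapter logs and
    # returns [] instead of raising, letting later batches proceed.
    return adapter.stateful_stream_harmony_tokens_to_openai_messages(
        request_id, new_tokens, available_tools=None, tool_choice=None)
```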
@@ -1406,36 +1485,72 @@ def _is_tool_call_allowed(self, tool_call: dict[str, Any],
     return True


-async def handle_streaming_response(
-    harmony_adapter: HarmonyAdapter,
-    generator: RequestOutput,
-    request_id: str,
-    request: ChatCompletionRequest,
-) -> AsyncGenerator[str, None]:
-    """Handle streaming response with harmony format."""
+_SERVE_HARMONY_ADAPTER: HarmonyAdapter | None = None
+
+
+def get_harmony_adapter():
+    global _SERVE_HARMONY_ADAPTER
+    if _SERVE_HARMONY_ADAPTER is None:
+        _SERVE_HARMONY_ADAPTER = HarmonyAdapter()
+
+    return _SERVE_HARMONY_ADAPTER
+
+
+def handle_streaming_response(tools: List[ChatCompletionToolsParam],
+                              tool_choice: str, outputs: List, model: str,
+                              request_id: str, done: bool,
+                              num_prompt_tokens: int):
     first_iteration = True
-    async for res in generator:
-        output = res.outputs[0]
+    output = outputs[0]
+
+    # Convert tools to dictionary format for harmony adapter (standard pattern)
+    tools_dict = None
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]

-        # Convert tools to dictionary format for harmony adapter (standard pattern)
-        tools_dict = None
-        if request.tools:
-            tools_dict = [tool.model_dump() for tool in request.tools]
+    # If tool_choice is "none", don't pass tools to the parser
+    if tool_choice == "none":
+        tools_for_parser = None
+    else:
+        tools_for_parser = tools_dict

-        # Get tool_choice from request - if "none", don't pass tools to parser
-        tool_choice = getattr(request, 'tool_choice', None)
-        if tool_choice == "none":
-            tools_for_parser = None
-        else:
-            tools_for_parser = tools_dict
+    # Create OpenAI streaming responses
+    try:
+        res = []
+        if done:
+            # Clean up state
+            harmony_adapter.cleanup_stream_state(request_id)

-        # Create OpenAI streaming responses
-        try:
+            usage_info = _create_usage_info(num_prompt_tokens, outputs)
+
+            # Send final message with finish_reason
+            final_response = ChatCompletionStreamResponse(
+                model=model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(),
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
+                ],
+            )
+
+            final_response_json = final_response.model_dump_json(
+                exclude_none=True)
+            final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                             model=model,
+                                                             usage=usage_info)
+            final_usage_json = final_usage_chunk.model_dump_json(
+                exclude_none=True)
+            res.append(f"data: {final_response_json}\n\n")
+            res.append(f"data: {final_usage_json}\n\n")
+        else:
             responses = harmony_adapter.create_openai_streaming_response(
                 request_id=request_id,
                 tokens=output.token_ids_diff,
                 available_tools=tools_for_parser,
-                model_name=request.model,
+                model_name=model,
                 tool_choice=tool_choice)

             # Send first response after receiving the first output
             if first_iteration:
@@ -1446,64 +1561,44 @@ async def handle_streaming_response(
                         delta=first_delta)

                 first_response = ChatCompletionStreamResponse(
-                    model=request.model,
+                    model=model,
                     choices=[choice],
                 )

                 response_json = first_response.model_dump_json(
                     exclude_none=True)
-                yield f"data: {response_json}\n\n"
-
-            for response in responses:
-                yield response
-
-        except Exception as e:
-            logger.error(f"Failed to create OpenAI streaming response: {e}")
-            logger.debug(f"Streaming error details: {traceback.format_exc()}")
-            # Clean up state
-            harmony_adapter.cleanup_stream_state(request_id)
-            raise e
+                res.append(f"data: {response_json}\n\n")

-    # Clean up state
-    harmony_adapter.cleanup_stream_state(request_id)
+            res.extend(responses)

-    # Send final message with finish_reason
-    output = generator.outputs[0]
-    final_response = ChatCompletionStreamResponse(
-        model=request.model,
-        choices=[
-            ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(),
-                finish_reason=output.finish_reason,
-                stop_reason=output.stop_reason)
-        ])
+        return res

-    yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
-    yield "data: [DONE]\n\n"
+    except Exception as e:
+        logger.error(f"Failed to create OpenAI streaming response: {e}")
+        logger.debug(f"Streaming error details: {traceback.format_exc()}")
+        # Clean up state
+        harmony_adapter.cleanup_stream_state(request_id)
+        raise e

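
The refactor turns the async SSE generator into a synchronous function that returns a list of chunks, with the caller owning the generation loop and signalling the last step via `done=True`. A hedged sketch of what that caller loop might look like (the `res.outputs` field and the loop shape are assumptions; the PR does not show the call site). Note also that the `[DONE]` sentinel is no longer emitted here, so the caller presumably appends it:

```python
async def stream_chat_sse(generator, tools, tool_choice, model, request_id,
                          num_prompt_tokens):
    res = None
    async for res in generator:
        # Incremental step: parse newly generated tokens into SSE chunks.
        for chunk in handle_streaming_response(
                tools, tool_choice, res.outputs, model, request_id,
                done=False, num_prompt_tokens=num_prompt_tokens):
            yield chunk
    if res is not None:
        # Final step: done=True cleans up stream state and emits the
        # finish_reason chunk plus a usage chunk.
        for chunk in handle_streaming_response(
                tools, tool_choice, res.outputs, model, request_id,
                done=True, num_prompt_tokens=num_prompt_tokens):
            yield chunk
    # The old implementation yielded the [DONE] sentinel itself; after this
    # refactor the caller is responsible for it.
    yield "data: [DONE]\n\n"
```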
-async def handle_non_streaming_response(
-        harmony_adapter: HarmonyAdapter, promise: RequestOutput,
-        request: ChatCompletionRequest) -> ChatCompletionResponse:
+def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
+                                  tool_choice: str, outputs: List, model: str,
+                                  num_prompt_tokens: int):
     """Handle non-streaming response with harmony format."""
-    # Get final result
-    await promise
-
     # Parse harmony output to OpenAI format
     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
-    if request.tools:
-        tools_dict = [tool.model_dump() for tool in request.tools]
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]

     # Get tool_choice from request - if "none", don't pass tools to parser
-    tool_choice = getattr(request, 'tool_choice', None)
     if tool_choice == "none":
         tools_for_parser = None
     else:
         tools_for_parser = tools_dict

-    output = promise.outputs[0]
+    output = outputs[0]
     parsed_output = harmony_adapter.harmony_output_to_openai(
         output.token_ids, tools_for_parser, tool_choice)
@@ -1518,11 +1613,11 @@ async def handle_non_streaming_response(
         output.finish_reason)

     # Create usage info from metrics (RequestOutput doesn't have usage in v1)
-    usage_info = _create_usage_info(promise)
+    usage_info = _create_usage_info(num_prompt_tokens, outputs)

     # Create response
     response = ChatCompletionResponse(
-        model=request.model,
+        model=model,
         choices=[
             ChatCompletionResponseChoice(
                 index=0,
@@ -1534,7 +1629,6 @@ async def handle_non_streaming_response(
     # Optional: Log if harmony parsing failed (for debugging)
     if parsed_output.get('_harmony_parsing_failed'):
         logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
-        logger.debug(f"request\n\n{request}")
     logger.debug(f"response\n\n{response}\n")

     return response
@@ -1567,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
     return reason


-def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
+def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
     """Create usage info following the serving_chat.py pattern."""
-    # Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
-    assert final_res.prompt_token_ids is not None
-    num_prompt_tokens = len(final_res.prompt_token_ids)
-
     # Calculate completion tokens from all outputs
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in final_res.outputs)
+    num_generated_tokens = sum(len(output.token_ids) for output in outputs)

     # Create usage info
     usage = UsageInfo(prompt_tokens=num_prompt_tokens,
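
With the new signature the caller passes the prompt length directly instead of a `RequestOutput`. A tiny hedged check of the arithmetic (`_FakeOutput` is a stand-in; the exact `UsageInfo` fields beyond those visible above are assumed):

```python
from dataclasses import dataclass, field

@dataclass
class _FakeOutput:
    token_ids: list = field(default_factory=lambda: [1, 2, 3])

usage = _create_usage_info(num_prompt_tokens=12, outputs=[_FakeOutput()])
# Expected (assuming UsageInfo totals the two counts): prompt_tokens=12,
# completion_tokens=3, total_tokens=15
```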