237 changes: 163 additions & 74 deletions tensorrt_llm/serve/harmony_adapter.py
@@ -6,23 +6,23 @@
import time
import traceback
import uuid
from typing import Any, AsyncGenerator, Literal
from typing import Any, List, Literal

from openai_harmony import (Author, Conversation, DeveloperContent,
HarmonyEncodingName, HarmonyError, Message,
ReasoningEffort, Role, StreamableParser,
SystemContent, TextContent, ToolDescription,
load_harmony_encoding)

from tensorrt_llm.llmapi import RequestOutput
from tensorrt_llm.logger import logger

# yapf: disable
from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
from .openai_protocol import (ChatCompletionMessageParam,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage,
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
UsageInfo)

@@ -57,7 +57,8 @@ def __init__(self,
# Normal case: filter based on available tools
self.should_filter_tools = True
self.available_tools = {
tool.get("function", {}).get("name", "")
tool.get("function", {}).get("name", "") if tool.get(
"name", None) is None else tool.get("name")
for tool in available_tools
}
self.available_tools.discard("")
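For illustration, a hedged sketch of the two tool shapes this lookup now accepts — the nested Chat Completions form and a flat form with a top-level "name" (the tool name below is hypothetical):

# Nested Chat Completions style: the name sits under "function".
chat_tool = {"type": "function", "function": {"name": "get_weather"}}
# Flat style (e.g. Responses-style tool definitions): the name is top-level.
flat_tool = {"type": "function", "name": "get_weather"}

def _extracted_name(tool: dict) -> str:
    # Mirrors the comprehension above: prefer a top-level "name", else the nested one.
    return (tool.get("function", {}).get("name", "")
            if tool.get("name", None) is None else tool.get("name"))

assert _extracted_name(chat_tool) == _extracted_name(flat_tool) == "get_weather"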
@@ -78,6 +79,9 @@ def __init__(self,

logger.debug("Created HarmonyStreamState for request %s", request_id)

def get_parser(self) -> StreamableParser:
return self.parser

def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
"""
Process a batch of tokens while maintaining parsing state.
Expand Down Expand Up @@ -125,6 +129,42 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:

return deltas

def process_token_batch_to_messages(self,
tokens: list[int]) -> list[Message]:
"""
Process a batch of tokens while maintaining parsing state.
Returns OpenAI Messages for the Responses API.
"""
self.tokens_processed += len(tokens)

for token in tokens:
# Store previous state for transition detection
prev_channel = self.parser.current_channel
prev_recipient = self.parser.current_recipient

# Process the token
self.parser.process(token)

# Detect channel/recipient transitions AFTER processing each token
channel_changed = prev_channel != self.parser.current_channel
recipient_changed = prev_recipient != self.parser.current_recipient

if channel_changed or recipient_changed:
# Mark any active tool calls as completed if we're leaving a tool call
if prev_channel == "commentary" and prev_recipient and "functions." in str(
prev_recipient):
func_name = str(prev_recipient).split("functions.")[-1]
for tool_id, tool_info in self.tool_calls.items():
if tool_info["name"] == func_name and tool_info.get(
"active", True):
tool_info["active"] = False

# Reset channel state for new channel
self.channel_started = False
self.current_channel_state = None

return self.parser.messages
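
A minimal usage sketch for the per-request state object (the request id and token ids below are placeholders; real batches come from the generation loop):

adapter = get_harmony_adapter()
state = adapter.create_stream_state("req-1", None, None)  # request_id, tools, tool_choice
messages = state.process_token_batch_to_messages([200005, 35644, 200008])  # placeholder ids
# `messages` is cumulative: every Message the parser has completed so far;
# `state.get_parser()` exposes the underlying StreamableParser if needed.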

def _create_closing_token_delta(self) -> dict[str, Any] | None:
"""Create closing token delta for channel transition."""
if not self.current_channel_state or not self.channel_started:
Expand Down Expand Up @@ -317,6 +357,9 @@ def __init__(
"<|constrain|>": 200009,
}

def get_stream_state(self, request_id: str) -> HarmonyStreamState | None:
return self._stream_states.get(request_id, None)

def get_stop_tokens(self) -> list[int]:
"""
Return the list of stop token IDs for Harmony format.
Expand Down Expand Up @@ -1214,6 +1257,42 @@ def stateful_stream_harmony_tokens_to_openai_deltas(
# Return empty deltas to continue processing
return []

def stateful_stream_harmony_tokens_to_openai_messages(
self,
request_id: str,
tokens: list[int],
available_tools: list[dict[str, Any]] | None = None,
tool_choice: str | None = None) -> list[Message]:
"""
Process tokens using stateful parsing.

This method maintains persistent state across multiple calls for the same request,
ensuring proper channel transitions and tool call handling.

Args:
request_id: Request ID to maintain state per request
tokens: New tokens from this iteration
available_tools: Available tools for filtering
tool_choice: Tool choice setting forwarded to the per-request stream state

Returns:
List of OpenAI Messages
"""
stream_state = self._stream_states.get(request_id, None)
if stream_state is None:
stream_state = self.create_stream_state(request_id, available_tools,
tool_choice)

try:
messages = stream_state.process_token_batch_to_messages(tokens)
return messages
except (HarmonyError, UnicodeDecodeError, ValueError):
logger.error(
f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}",
)
logger.debug(f"Problematic streaming tokens: {tokens}")

return []
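
A hedged sketch of how a Responses-style caller might drive this wrapper per iteration (the generation loop and variable names are assumptions, not part of this change):

adapter = get_harmony_adapter()
request_id = "req-1"
tools_dict = None  # or [tool.model_dump() for tool in request.tools]
for res in generation_results:  # hypothetical: each step exposes output.token_ids_diff
    messages = adapter.stateful_stream_harmony_tokens_to_openai_messages(
        request_id, res.outputs[0].token_ids_diff,
        available_tools=tools_dict, tool_choice=None)
    # Returns every Message parsed so far; [] if this batch failed to parse.
adapter.cleanup_stream_state(request_id)  # drop the per-request parser state when done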

def create_openai_streaming_response(
self,
request_id: str,
Expand Down Expand Up @@ -1406,36 +1485,72 @@ def _is_tool_call_allowed(self, tool_call: dict[str, Any],
return True


async def handle_streaming_response(
harmony_adapter: HarmonyAdapter,
generator: RequestOutput,
request_id: str,
request: ChatCompletionRequest,
) -> AsyncGenerator[str, None]:
"""Handle streaming response with harmony format."""
_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None


def get_harmony_adapter():
global _SERVE_HARMONY_ADAPTER
if _SERVE_HARMONY_ADAPTER is None:
_SERVE_HARMONY_ADAPTER = HarmonyAdapter()

return _SERVE_HARMONY_ADAPTER


def handle_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
request_id: str, done: bool,
num_prompt_tokens: int):
first_iteration = True
async for res in generator:
output = res.outputs[0]
output = outputs[0]

# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
harmony_adapter = get_harmony_adapter()
if tools:
tools_dict = [tool.model_dump() for tool in tools]

# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
# Get tool_choice from request - if "none", don't pass tools to parser
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict

# Get tool_choice from request - if "none", don't pass tools to parser
tool_choice = getattr(request, 'tool_choice', None)
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict
# Create OpenAI streaming responses
try:
res = []
if done:
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)

# Create OpenAI streaming responses
try:
usage_info = _create_usage_info(num_prompt_tokens, outputs)

# Send final message with finish_reason
final_response = ChatCompletionStreamResponse(
model=model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
],
)

final_response_json = final_response.model_dump_json(
exclude_none=True)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=model,
usage=usage_info)
final_usage_json = final_usage_chunk.model_dump_json(
exclude_none=True)
res.append(f"data: {final_response_json}\n\n")
res.append(f"data: {final_usage_json}\n\n")
else:
responses = harmony_adapter.create_openai_streaming_response(
request_id=request_id,
tokens=output.token_ids_diff,
available_tools=tools_for_parser,
model_name=request.model,
model_name=model,
tool_choice=tool_choice)
# Send first response after receiving the first output
if first_iteration:
@@ -1446,64 +1561,44 @@ async def handle_streaming_response(
delta=first_delta)

first_response = ChatCompletionStreamResponse(
model=request.model,
model=model,
choices=[choice],
)

response_json = first_response.model_dump_json(
exclude_none=True)
yield f"data: {response_json}\n\n"

for response in responses:
yield response

except Exception as e:
logger.error(f"Failed to create OpenAI streaming response: {e}")
logger.debug(f"Streaming error details: {traceback.format_exc()}")
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
raise e
res.append(f"data: {response_json}\n\n")

# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
res.extend(responses)

# Send final message with finish_reason
output = generator.outputs[0]
final_response = ChatCompletionStreamResponse(
model=request.model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
])
return res

yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Failed to create OpenAI streaming response: {e}")
logger.debug(f"Streaming error details: {traceback.format_exc()}")
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
raise e
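
A sketch of how a serving loop might drive the reshaped handler: once per generation step with done=False, then a final call with done=True to emit the finish chunk and usage (the loop, variable names, and the caller-side [DONE] terminator are assumptions):

chunks = []
for res in generation_results:  # hypothetical streaming iterator
    chunks.extend(handle_streaming_response(tools, tool_choice, res.outputs,
                                            model, request_id, done=False,
                                            num_prompt_tokens=num_prompt_tokens))
chunks.extend(handle_streaming_response(tools, tool_choice, final_res.outputs,
                                        model, request_id, done=True,
                                        num_prompt_tokens=num_prompt_tokens))
chunks.append("data: [DONE]\n\n")  # SSE terminator, assumed to be emitted by the caller now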


async def handle_non_streaming_response(
harmony_adapter: HarmonyAdapter, promise: RequestOutput,
request: ChatCompletionRequest) -> ChatCompletionResponse:
def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
num_prompt_tokens: int):
"""Handle non-streaming response with harmony format."""
# Get final result
await promise

# Parse harmony output to OpenAI format
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
harmony_adapter = get_harmony_adapter()
if tools:
tools_dict = [tool.model_dump() for tool in tools]

# Get tool_choice from request - if "none", don't pass tools to parser
tool_choice = getattr(request, 'tool_choice', None)
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict

output = promise.outputs[0]
output = outputs[0]
parsed_output = harmony_adapter.harmony_output_to_openai(
output.token_ids, tools_for_parser, tool_choice)

Expand All @@ -1518,11 +1613,11 @@ async def handle_non_streaming_response(
output.finish_reason)

# Create usage info from metrics (RequestOutput doesn't have usage in v1)
usage_info = _create_usage_info(promise)
usage_info = _create_usage_info(num_prompt_tokens, outputs)

# Create response
response = ChatCompletionResponse(
model=request.model,
model=model,
choices=[
ChatCompletionResponseChoice(
index=0,
Expand All @@ -1534,7 +1629,6 @@ async def handle_non_streaming_response(
# Optional: Log if harmony parsing failed (for debugging)
if parsed_output.get('_harmony_parsing_failed'):
logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
logger.debug(f"request\n\n{request}")
logger.debug(f"response\n\n{response}\n")

return response
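
And the non-streaming counterpart, invoked once generation has completed (a sketch; `promise` and `request` are assumed caller-side objects):

response = handle_non_streaming_response(
    tools=request.tools,
    tool_choice=getattr(request, "tool_choice", None),
    outputs=promise.outputs,
    model=request.model,
    num_prompt_tokens=len(promise.prompt_token_ids))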
Expand Down Expand Up @@ -1567,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
return reason


def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
"""Create usage info from RequestOutput following serving_chat.py pattern."""
# Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)

# Calculate completion tokens from all outputs
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
num_generated_tokens = sum(len(output.token_ids) for output in outputs)

# Create usage info
usage = UsageInfo(prompt_tokens=num_prompt_tokens,