
Commit 5b312a8

feat(responses): improve streaming for function calls (#3124)

Emit streaming events for function calls

## Test Plan

Improved the test case
1 parent d6ae547 commit 5b312a8
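
For context, this change makes function calls stream the way assistant text already does: each new call first yields a response.output_item.added event, each argument fragment yields response.function_call_arguments.delta, and completion yields response.function_call_arguments.done followed by response.output_item.done. Below is a minimal consumer sketch, assuming an OpenAI-compatible client pointed at a Llama Stack server; the base URL, model id, and get_weather tool are illustrative, not part of this commit.

from openai import OpenAI

# Endpoint and model are placeholders; adjust for your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

stream = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="What is the weather in Tokyo?",
    tools=[
        {
            "type": "function",
            "name": "get_weather",
            "description": "Return the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    ],
    stream=True,
)

for event in stream:
    if event.type == "response.output_item.added":
        # Emitted once per function call, before any argument deltas.
        print("started:", event.item.name)
    elif event.type == "response.function_call_arguments.delta":
        # Argument JSON arrives incrementally as string fragments.
        print("delta:", event.delta)
    elif event.type == "response.function_call_arguments.done":
        # The full accumulated argument string for the call.
        print("arguments:", event.arguments)
    elif event.type == "response.output_item.done":
        # The completed function-call output item.
        print("done:", event.item.id)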

File tree

3 files changed: +250 −33 lines changed

llama_stack/providers/inline/agents/meta_reference/openai_responses.py

Lines changed: 127 additions & 25 deletions
@@ -33,6 +33,10 @@
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
     OpenAIResponseObjectStreamResponseCreated,
+    OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
+    OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
+    OpenAIResponseObjectStreamResponseOutputItemAdded,
+    OpenAIResponseObjectStreamResponseOutputItemDone,
     OpenAIResponseObjectStreamResponseOutputTextDelta,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
@@ -73,7 +77,9 @@
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
-from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+from llama_stack.providers.utils.inference.openai_compat import (
+    convert_tooldef_to_openai_tool,
+)
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 
 logger = get_logger(name=__name__, category="openai_responses")
@@ -82,7 +88,7 @@
 
 
 async def _convert_response_content_to_chat_content(
-    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
+    content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
 ) -> str | list[OpenAIChatCompletionContentPartParam]:
     """
     Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
@@ -150,7 +156,9 @@ async def _convert_response_input_to_chat_messages(
     return messages
 
 
-async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def _convert_chat_choice_to_response_message(
+    choice: OpenAIChoice,
+) -> OpenAIResponseMessage:
     """
     Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
     """
@@ -172,7 +180,9 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open
     )
 
 
-async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
+async def _convert_response_text_to_chat_response_format(
+    text: OpenAIResponseText,
+) -> OpenAIResponseFormatParam:
     """
     Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
     """
@@ -228,7 +238,9 @@ def __init__(
         self.vector_io_api = vector_io_api
 
     async def _prepend_previous_response(
-        self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
+        self,
+        input: str | list[OpenAIResponseInput],
+        previous_response_id: str | None = None,
     ):
         if previous_response_id:
             previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
@@ -446,6 +458,8 @@ async def _create_streaming_response(
 
         # Create a placeholder message item for delta events
         message_item_id = f"msg_{uuid.uuid4()}"
+        # Track tool call items for streaming events
+        tool_call_item_ids: dict[int, str] = {}
 
         async for chunk in completion_result:
             chat_response_id = chunk.id
@@ -472,18 +486,62 @@ async def _create_streaming_response(
                 if chunk_choice.delta.tool_calls:
                     for tool_call in chunk_choice.delta.tool_calls:
                         response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
-                        if response_tool_call:
-                            # Don't attempt to concatenate arguments if we don't have any new arguments
-                            if tool_call.function.arguments:
-                                # Guard against an initial None argument before we concatenate
-                                response_tool_call.function.arguments = (
-                                    response_tool_call.function.arguments or ""
-                                ) + tool_call.function.arguments
-                        else:
+                        # Create new tool call entry if this is the first chunk for this index
+                        is_new_tool_call = response_tool_call is None
+                        if is_new_tool_call:
                             tool_call_dict: dict[str, Any] = tool_call.model_dump()
                             tool_call_dict.pop("type", None)
                             response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
-                            chat_response_tool_calls[tool_call.index] = response_tool_call
+                            chat_response_tool_calls[tool_call.index] = response_tool_call
+
+                            # Create item ID for this tool call for streaming events
+                            tool_call_item_id = f"fc_{uuid.uuid4()}"
+                            tool_call_item_ids[tool_call.index] = tool_call_item_id
+
+                            # Emit output_item.added event for the new function call
+                            sequence_number += 1
+                            function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
+                                arguments="",  # Will be filled incrementally via delta events
+                                call_id=tool_call.id or "",
+                                name=tool_call.function.name if tool_call.function else "",
+                                id=tool_call_item_id,
+                                status="in_progress",
+                            )
+                            yield OpenAIResponseObjectStreamResponseOutputItemAdded(
+                                response_id=response_id,
+                                item=function_call_item,
+                                output_index=len(output_messages),
+                                sequence_number=sequence_number,
+                            )
+
+                        # Stream function call arguments as they arrive
+                        if tool_call.function and tool_call.function.arguments:
+                            tool_call_item_id = tool_call_item_ids[tool_call.index]
+                            sequence_number += 1
+                            yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
+                                delta=tool_call.function.arguments,
+                                item_id=tool_call_item_id,
+                                output_index=len(output_messages),
+                                sequence_number=sequence_number,
+                            )
+
+                            # Accumulate arguments for final response (only for subsequent chunks)
+                            if not is_new_tool_call:
+                                response_tool_call.function.arguments = (
+                                    response_tool_call.function.arguments or ""
+                                ) + tool_call.function.arguments
+
+        # Emit function_call_arguments.done events for completed tool calls
+        for tool_call_index in sorted(chat_response_tool_calls.keys()):
+            tool_call_item_id = tool_call_item_ids[tool_call_index]
+            final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
+            sequence_number += 1
+            yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(
+                arguments=final_arguments,
+                item_id=tool_call_item_id,
+                output_index=len(output_messages),
+                sequence_number=sequence_number,
+            )
 
         # Convert collected chunks to complete response
         if chat_response_tool_calls:
@@ -532,18 +590,56 @@ async def _create_streaming_response(
                 tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx)
                 if tool_call_log:
                     output_messages.append(tool_call_log)
+
+                    # Emit output_item.done event for completed non-function tool call
+                    # Find the item_id for this tool call
+                    matching_item_id = None
+                    for index, item_id in tool_call_item_ids.items():
+                        response_tool_call = chat_response_tool_calls.get(index)
+                        if response_tool_call and response_tool_call.id == tool_call.id:
+                            matching_item_id = item_id
+                            break
+
+                    if matching_item_id:
+                        sequence_number += 1
+                        yield OpenAIResponseObjectStreamResponseOutputItemDone(
+                            response_id=response_id,
+                            item=tool_call_log,
+                            output_index=len(output_messages) - 1,
+                            sequence_number=sequence_number,
+                        )
+
                 if tool_response_message:
                     next_turn_messages.append(tool_response_message)
 
             for tool_call in function_tool_calls:
-                output_messages.append(
-                    OpenAIResponseOutputMessageFunctionToolCall(
-                        arguments=tool_call.function.arguments or "",
-                        call_id=tool_call.id,
-                        name=tool_call.function.name or "",
-                        id=f"fc_{uuid.uuid4()}",
-                        status="completed",
-                    )
+                # Find the item_id for this tool call from our tracking dictionary
+                matching_item_id = None
+                for index, item_id in tool_call_item_ids.items():
+                    response_tool_call = chat_response_tool_calls.get(index)
+                    if response_tool_call and response_tool_call.id == tool_call.id:
+                        matching_item_id = item_id
+                        break
+
+                # Use existing item_id or create new one if not found
+                final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
+
+                function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
+                    arguments=tool_call.function.arguments or "",
+                    call_id=tool_call.id,
+                    name=tool_call.function.name or "",
+                    id=final_item_id,
+                    status="completed",
+                )
+                output_messages.append(function_call_item)
+
+                # Emit output_item.done event for completed function call
+                sequence_number += 1
+                yield OpenAIResponseObjectStreamResponseOutputItemDone(
+                    response_id=response_id,
+                    item=function_call_item,
+                    output_index=len(output_messages) - 1,
+                    sequence_number=sequence_number,
                 )
 
             if not function_tool_calls and not non_function_tool_calls:
@@ -779,7 +875,8 @@ async def _execute_tool_call(
             )
         elif function.name == "knowledge_search":
             response_file_search_tool = next(
-                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
+                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
+                None,
             )
             if response_file_search_tool:
                 # Use vector_stores.search API instead of knowledge_search tool
@@ -798,7 +895,9 @@
             error_exc = e
 
         if function.name in ctx.mcp_tool_to_server:
-            from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
+            from llama_stack.apis.agents.openai_responses import (
+                OpenAIResponseOutputMessageMCPCall,
+            )
 
             message = OpenAIResponseOutputMessageMCPCall(
                 id=tool_call_id,
@@ -850,7 +949,10 @@
             if isinstance(result.content, str):
                 content = result.content
             elif isinstance(result.content, list):
-                from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+                from llama_stack.apis.common.content_types import (
+                    ImageContentItem,
+                    TextContentItem,
+                )
 
                 content = []
                 for item in result.content:
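
Per the "Test Plan" note above ("Improved the test case"), the ordering invariant these events establish can be asserted roughly as follows. This is a sketch with illustrative names, not the repository's actual test code.

# `events` is assumed to be the list of event objects collected from a
# streaming response that produced exactly one function call.
def check_function_call_stream(events) -> None:
    types = [e.type for e in events]
    added = types.index("response.output_item.added")
    first_delta = types.index("response.function_call_arguments.delta")
    args_done = types.index("response.function_call_arguments.done")
    item_done = types.index("response.output_item.done")
    # Lifecycle: item added -> argument deltas -> arguments done -> item done.
    assert added < first_delta < args_done < item_done
    # Sequence numbers must be strictly increasing across the whole stream.
    seqs = [e.sequence_number for e in events if hasattr(e, "sequence_number")]
    assert all(a < b for a, b in zip(seqs, seqs[1:]))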
