5050 CompletionMessage ,
5151 Inference ,
5252 Message ,
53+ OpenAIAssistantMessageParam ,
54+ OpenAIDeveloperMessageParam ,
55+ OpenAIMessageParam ,
56+ OpenAISystemMessageParam ,
57+ OpenAIToolMessageParam ,
58+ OpenAIUserMessageParam ,
5359 SamplingParams ,
5460 StopReason ,
5561 SystemMessage ,
6874 BuiltinTool ,
6975 ToolCall ,
7076)
77+ from llama_stack .providers .utils .inference .openai_compat import (
78+ convert_message_to_openai_dict_new ,
79+ convert_openai_chat_completion_stream ,
80+ convert_tooldef_to_openai_tool ,
81+ )
7182from llama_stack .providers .utils .kvstore import KVStore
7283from llama_stack .providers .utils .telemetry import tracing
7384
@@ -177,12 +188,12 @@ async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
177188 return messages
178189
179190 async def create_and_execute_turn (self , request : AgentTurnCreateRequest ) -> AsyncGenerator :
191+ turn_id = str (uuid .uuid4 ())
180192 span = tracing .get_current_span ()
181193 if span :
182194 span .set_attribute ("session_id" , request .session_id )
183195 span .set_attribute ("agent_id" , self .agent_id )
184196 span .set_attribute ("request" , request .model_dump_json ())
185- turn_id = str (uuid .uuid4 ())
186197 span .set_attribute ("turn_id" , turn_id )
187198 if self .agent_config .name :
188199 span .set_attribute ("agent_name" , self .agent_config .name )
@@ -505,26 +516,93 @@ async def _run(
505516
506517 tool_calls = []
507518 content = ""
508- stop_reason = None
519+ stop_reason : StopReason | None = None
509520
510521 async with tracing .span ("inference" ) as span :
511522 if self .agent_config .name :
512523 span .set_attribute ("agent_name" , self .agent_config .name )
513- async for chunk in await self .inference_api .chat_completion (
514- self .agent_config .model ,
515- input_messages ,
516- tools = self .tool_defs ,
517- tool_prompt_format = self .agent_config .tool_config .tool_prompt_format ,
524+
def _serialize_nested(value):
    """Recursively serialize nested Pydantic models to plain JSON-safe dicts.

    Walks dicts and lists, converting any Pydantic ``BaseModel`` found at any
    depth via ``model_dump(mode="json")``; all other values pass through
    unchanged.
    """
    # Local import so pydantic is only touched when this helper actually runs.
    # Importing once here (instead of inside the recursion, as before) avoids
    # re-running the import machinery on every recursive call.
    from pydantic import BaseModel

    def _walk(v):
        if isinstance(v, BaseModel):
            return v.model_dump(mode="json")
        if isinstance(v, dict):
            return {k: _walk(item) for k, item in v.items()}
        if isinstance(v, list):
            return [_walk(item) for item in v]
        return v

    return _walk(value)
537+
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
    """Wrap a plain OpenAI-format message dict in its role-specific param model.

    Nested Pydantic models are first flattened to plain dicts so the typed
    constructor receives only JSON-compatible values.

    Raises:
        ValueError: if the message's ``role`` is not one of the known roles.
    """
    # Serialize any nested Pydantic models to plain dicts
    openai_msg = _serialize_nested(openai_msg)

    # Dispatch table: message role -> typed OpenAI message param class.
    param_cls_by_role = {
        "user": OpenAIUserMessageParam,
        "system": OpenAISystemMessageParam,
        "assistant": OpenAIAssistantMessageParam,
        "tool": OpenAIToolMessageParam,
        "developer": OpenAIDeveloperMessageParam,
    }

    role = openai_msg.get("role")
    param_cls = param_cls_by_role.get(role)
    if param_cls is None:
        raise ValueError(f"Unknown message role: {role}")
    return param_cls(**openai_msg)
555+
556+ # Convert messages to OpenAI format
557+ openai_messages : list [OpenAIMessageParam ] = [
558+ _add_type (await convert_message_to_openai_dict_new (message )) for message in input_messages
559+ ]
560+
561+ # Convert tool definitions to OpenAI format
562+ openai_tools = [convert_tooldef_to_openai_tool (x ) for x in (self .tool_defs or [])]
563+
564+ # Extract tool_choice from tool_config for OpenAI compatibility
565+ # Note: tool_choice can only be provided when tools are also provided
566+ tool_choice = None
567+ if openai_tools and self .agent_config .tool_config and self .agent_config .tool_config .tool_choice :
568+ tc = self .agent_config .tool_config .tool_choice
569+ tool_choice_str = tc .value if hasattr (tc , "value" ) else str (tc )
570+ # Convert tool_choice to OpenAI format
571+ if tool_choice_str in ("auto" , "none" , "required" ):
572+ tool_choice = tool_choice_str
573+ else :
574+ # It's a specific tool name, wrap it in the proper format
575+ tool_choice = {"type" : "function" , "function" : {"name" : tool_choice_str }}
576+
577+ # Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
578+ temperature = getattr (getattr (sampling_params , "strategy" , None ), "temperature" , None )
579+ top_p = getattr (getattr (sampling_params , "strategy" , None ), "top_p" , None )
580+ max_tokens = getattr (sampling_params , "max_tokens" , None )
581+
582+ # Use OpenAI chat completion
583+ openai_stream = await self .inference_api .openai_chat_completion (
584+ model = self .agent_config .model ,
585+ messages = openai_messages ,
586+ tools = openai_tools if openai_tools else None ,
587+ tool_choice = tool_choice ,
518588 response_format = self .agent_config .response_format ,
589+ temperature = temperature ,
590+ top_p = top_p ,
591+ max_tokens = max_tokens ,
519592 stream = True ,
520- sampling_params = sampling_params ,
521- tool_config = self .agent_config .tool_config ,
522- ):
593+ )
594+
595+ # Convert OpenAI stream back to Llama Stack format
596+ response_stream = convert_openai_chat_completion_stream (
597+ openai_stream , enable_incremental_tool_calls = True
598+ )
599+
600+ async for chunk in response_stream :
523601 event = chunk .event
524602 if event .event_type == ChatCompletionResponseEventType .start :
525603 continue
526604 elif event .event_type == ChatCompletionResponseEventType .complete :
527- stop_reason = StopReason .end_of_turn
605+ stop_reason = event . stop_reason or StopReason .end_of_turn
528606 continue
529607
530608 delta = event .delta
@@ -533,7 +611,7 @@ async def _run(
533611 tool_calls .append (delta .tool_call )
534612 elif delta .parse_status == ToolCallParseStatus .failed :
535613 # If we cannot parse the tools, set the content to the unparsed raw text
536- content = delta .tool_call
614+ content = str ( delta .tool_call )
537615 if stream :
538616 yield AgentTurnResponseStreamChunk (
539617 event = AgentTurnResponseEvent (
@@ -560,9 +638,7 @@ async def _run(
560638 else :
561639 raise ValueError (f"Unexpected delta type { type (delta )} " )
562640
563- if event .stop_reason is not None :
564- stop_reason = event .stop_reason
565- span .set_attribute ("stop_reason" , stop_reason )
641+ span .set_attribute ("stop_reason" , stop_reason or StopReason .end_of_turn )
566642 span .set_attribute (
567643 "input" ,
568644 json .dumps ([json .loads (m .model_dump_json ()) for m in input_messages ]),
0 commit comments