Skip to content

Commit 7e48cc4

Browse files
refactor(agents): migrate to OpenAI chat completions API (#3323)
1 parent 426dc54 commit 7e48cc4

32 files changed

+12226
-15
lines changed

llama_stack/providers/inline/agents/meta_reference/agent_instance.py

Lines changed: 91 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@
5050
CompletionMessage,
5151
Inference,
5252
Message,
53+
OpenAIAssistantMessageParam,
54+
OpenAIDeveloperMessageParam,
55+
OpenAIMessageParam,
56+
OpenAISystemMessageParam,
57+
OpenAIToolMessageParam,
58+
OpenAIUserMessageParam,
5359
SamplingParams,
5460
StopReason,
5561
SystemMessage,
@@ -68,6 +74,11 @@
6874
BuiltinTool,
6975
ToolCall,
7076
)
77+
from llama_stack.providers.utils.inference.openai_compat import (
78+
convert_message_to_openai_dict_new,
79+
convert_openai_chat_completion_stream,
80+
convert_tooldef_to_openai_tool,
81+
)
7182
from llama_stack.providers.utils.kvstore import KVStore
7283
from llama_stack.providers.utils.telemetry import tracing
7384

@@ -177,12 +188,12 @@ async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
177188
return messages
178189

179190
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
191+
turn_id = str(uuid.uuid4())
180192
span = tracing.get_current_span()
181193
if span:
182194
span.set_attribute("session_id", request.session_id)
183195
span.set_attribute("agent_id", self.agent_id)
184196
span.set_attribute("request", request.model_dump_json())
185-
turn_id = str(uuid.uuid4())
186197
span.set_attribute("turn_id", turn_id)
187198
if self.agent_config.name:
188199
span.set_attribute("agent_name", self.agent_config.name)
@@ -505,26 +516,93 @@ async def _run(
505516

506517
tool_calls = []
507518
content = ""
508-
stop_reason = None
519+
stop_reason: StopReason | None = None
509520

510521
async with tracing.span("inference") as span:
511522
if self.agent_config.name:
512523
span.set_attribute("agent_name", self.agent_config.name)
513-
async for chunk in await self.inference_api.chat_completion(
514-
self.agent_config.model,
515-
input_messages,
516-
tools=self.tool_defs,
517-
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
524+
525+
def _serialize_nested(value):
526+
"""Recursively serialize nested Pydantic models to dicts."""
527+
from pydantic import BaseModel
528+
529+
if isinstance(value, BaseModel):
530+
return value.model_dump(mode="json")
531+
elif isinstance(value, dict):
532+
return {k: _serialize_nested(v) for k, v in value.items()}
533+
elif isinstance(value, list):
534+
return [_serialize_nested(item) for item in value]
535+
else:
536+
return value
537+
538+
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
539+
# Serialize any nested Pydantic models to plain dicts
540+
openai_msg = _serialize_nested(openai_msg)
541+
542+
role = openai_msg.get("role")
543+
if role == "user":
544+
return OpenAIUserMessageParam(**openai_msg)
545+
elif role == "system":
546+
return OpenAISystemMessageParam(**openai_msg)
547+
elif role == "assistant":
548+
return OpenAIAssistantMessageParam(**openai_msg)
549+
elif role == "tool":
550+
return OpenAIToolMessageParam(**openai_msg)
551+
elif role == "developer":
552+
return OpenAIDeveloperMessageParam(**openai_msg)
553+
else:
554+
raise ValueError(f"Unknown message role: {role}")
555+
556+
# Convert messages to OpenAI format
557+
openai_messages: list[OpenAIMessageParam] = [
558+
_add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
559+
]
560+
561+
# Convert tool definitions to OpenAI format
562+
openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
563+
564+
# Extract tool_choice from tool_config for OpenAI compatibility
565+
# Note: tool_choice can only be provided when tools are also provided
566+
tool_choice = None
567+
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
568+
tc = self.agent_config.tool_config.tool_choice
569+
tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
570+
# Convert tool_choice to OpenAI format
571+
if tool_choice_str in ("auto", "none", "required"):
572+
tool_choice = tool_choice_str
573+
else:
574+
# It's a specific tool name, wrap it in the proper format
575+
tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
576+
577+
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
578+
temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
579+
top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
580+
max_tokens = getattr(sampling_params, "max_tokens", None)
581+
582+
# Use OpenAI chat completion
583+
openai_stream = await self.inference_api.openai_chat_completion(
584+
model=self.agent_config.model,
585+
messages=openai_messages,
586+
tools=openai_tools if openai_tools else None,
587+
tool_choice=tool_choice,
518588
response_format=self.agent_config.response_format,
589+
temperature=temperature,
590+
top_p=top_p,
591+
max_tokens=max_tokens,
519592
stream=True,
520-
sampling_params=sampling_params,
521-
tool_config=self.agent_config.tool_config,
522-
):
593+
)
594+
595+
# Convert OpenAI stream back to Llama Stack format
596+
response_stream = convert_openai_chat_completion_stream(
597+
openai_stream, enable_incremental_tool_calls=True
598+
)
599+
600+
async for chunk in response_stream:
523601
event = chunk.event
524602
if event.event_type == ChatCompletionResponseEventType.start:
525603
continue
526604
elif event.event_type == ChatCompletionResponseEventType.complete:
527-
stop_reason = StopReason.end_of_turn
605+
stop_reason = event.stop_reason or StopReason.end_of_turn
528606
continue
529607

530608
delta = event.delta
@@ -533,7 +611,7 @@ async def _run(
533611
tool_calls.append(delta.tool_call)
534612
elif delta.parse_status == ToolCallParseStatus.failed:
535613
# If we cannot parse the tools, set the content to the unparsed raw text
536-
content = delta.tool_call
614+
content = str(delta.tool_call)
537615
if stream:
538616
yield AgentTurnResponseStreamChunk(
539617
event=AgentTurnResponseEvent(
@@ -560,9 +638,7 @@ async def _run(
560638
else:
561639
raise ValueError(f"Unexpected delta type {type(delta)}")
562640

563-
if event.stop_reason is not None:
564-
stop_reason = event.stop_reason
565-
span.set_attribute("stop_reason", stop_reason)
641+
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
566642
span.set_attribute(
567643
"input",
568644
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),

0 commit comments

Comments
 (0)