diff --git a/openai_agents/run_customer_service_client.py b/openai_agents/run_customer_service_client.py index f129d0d4..44a7de8b 100644 --- a/openai_agents/run_customer_service_client.py +++ b/openai_agents/run_customer_service_client.py @@ -32,12 +32,13 @@ async def main(): # Query the workflow for the chat history # If the workflow is not open, start a new one start = False + history = [] try: history = await handle.query( CustomerServiceWorkflow.get_chat_history, reject_condition=QueryRejectCondition.NOT_OPEN, ) - except WorkflowQueryRejectedError as e: + except WorkflowQueryRejectedError: start = True except RPCError as e: if e.status == RPCStatusCode.NOT_FOUND: @@ -64,7 +65,7 @@ async def main(): CustomerServiceWorkflow.process_user_message, message_input ) history.extend(new_history) - print(*new_history, sep="\n") + print(*new_history[1:], sep="\n") except WorkflowUpdateFailedError: print("** Stale conversation. Reloading...") length = len(history) diff --git a/openai_agents/run_hello_world_workflow.py b/openai_agents/run_hello_world_workflow.py index ee5dee90..566d525a 100644 --- a/openai_agents/run_hello_world_workflow.py +++ b/openai_agents/run_hello_world_workflow.py @@ -4,7 +4,6 @@ from temporalio.contrib.pydantic import pydantic_data_converter from openai_agents.workflows.hello_world_workflow import HelloWorldAgent -from openai_agents.workflows.research_bot_workflow import ResearchWorkflow async def main(): diff --git a/openai_agents/run_worker.py b/openai_agents/run_worker.py index de2693f4..9dbaef91 100644 --- a/openai_agents/run_worker.py +++ b/openai_agents/run_worker.py @@ -7,6 +7,7 @@ from temporalio.contrib.openai_agents import ( ModelActivity, ModelActivityParameters, + OpenAIAgentsTracingInterceptor, set_open_ai_agent_temporal_overrides, ) from temporalio.contrib.pydantic import pydantic_data_converter @@ -46,6 +47,7 @@ async def main(): ModelActivity().invoke_model_activity, get_weather, ], + interceptors=[OpenAIAgentsTracingInterceptor()], ) await worker.run() diff --git a/openai_agents/workflows/customer_service.py b/openai_agents/workflows/customer_service.py index 6a08f4ed..88f6e3cd 100644 --- a/openai_agents/workflows/customer_service.py +++ b/openai_agents/workflows/customer_service.py @@ -1,5 +1,7 @@ from __future__ import annotations as _annotations +from typing import Dict, Tuple + from agents import Agent, RunContextWrapper, function_tool, handoff from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX from pydantic import BaseModel @@ -23,19 +25,20 @@ class AirlineAgentContext(BaseModel): description_override="Lookup frequently asked questions.", ) async def faq_lookup_tool(question: str) -> str: - if "bag" in question or "baggage" in question: + question_lower = question.lower() + if "bag" in question_lower or "baggage" in question_lower: return ( "You are allowed to bring one bag on the plane. " "It must be under 50 pounds and 22 inches x 14 inches x 9 inches." ) - elif "seats" in question or "plane" in question: + elif "seats" in question_lower or "plane" in question_lower: return ( "There are 120 seats on the plane. " "There are 22 business class seats and 98 economy seats. " "Exit rows are rows 4 and 16. " "Rows 5-8 are Economy Plus, with extra legroom. " ) - elif "wifi" in question: + elif "wifi" in question_lower: return "We have free wifi on the plane, join Airline-Wifi" return "I'm sorry, I don't know the answer to that question." @@ -74,7 +77,9 @@ async def on_seat_booking_handoff( ### AGENTS -def init_agents() -> Agent[AirlineAgentContext]: +def init_agents() -> Tuple[ + Agent[AirlineAgentContext], Dict[str, Agent[AirlineAgentContext]] +]: """ Initialize the agents for the airline customer service workflow. :return: triage agent @@ -121,7 +126,9 @@ def init_agents() -> Agent[AirlineAgentContext]: faq_agent.handoffs.append(triage_agent) seat_booking_agent.handoffs.append(triage_agent) - return triage_agent + return triage_agent, { + agent.name: agent for agent in [faq_agent, seat_booking_agent, triage_agent] + } class ProcessUserMessageInput(BaseModel): diff --git a/openai_agents/workflows/customer_service_workflow.py b/openai_agents/workflows/customer_service_workflow.py index 3186c5ab..ea6fff8a 100644 --- a/openai_agents/workflows/customer_service_workflow.py +++ b/openai_agents/workflows/customer_service_workflow.py @@ -1,7 +1,10 @@ from __future__ import annotations as _annotations +import asyncio +from datetime import timedelta + from agents import ( - Agent, + HandoffCallItem, HandoffOutputItem, ItemHelpers, MessageOutputItem, @@ -12,6 +15,7 @@ TResponseInputItem, trace, ) +from pydantic import BaseModel, dataclasses from temporalio import workflow from openai_agents.workflows.customer_service import ( @@ -21,65 +25,128 @@ ) +@dataclasses.dataclass +class CustomerServiceWorkflowState: + printed_history: list[str] + current_agent_name: str + context: AirlineAgentContext + input_items: list[dict] # Store as plain dictionaries to avoid serialization issues + + @workflow.defn class CustomerServiceWorkflow: @workflow.init - def __init__(self, input_items: list[TResponseInputItem] | None = None): + def __init__( + self, customer_service_state: CustomerServiceWorkflowState | None = None + ): self.run_config = RunConfig() - self.chat_history: list[str] = [] - self.current_agent: Agent[AirlineAgentContext] = init_agents() - self.context = AirlineAgentContext() - self.input_items = [] if input_items is None else input_items - @workflow.run - async def run(self, input_items: list[TResponseInputItem] | None = None): - await workflow.wait_condition( - lambda: workflow.info().is_continue_as_new_suggested() - and workflow.all_handlers_finished() + starting_agent, self.agent_map = init_agents() + self.current_agent = ( + self.agent_map[customer_service_state.current_agent_name] + if customer_service_state + else starting_agent + ) + self.context = ( + customer_service_state.context + if customer_service_state + else AirlineAgentContext() ) - workflow.continue_as_new(self.input_items) + + self.printed_history: list[str] = ( + customer_service_state.printed_history if customer_service_state else [] + ) + + self.input_items = ( + customer_service_state.input_items if customer_service_state else [] + ) + + # Communication channels + self.user_input_queue: asyncio.Queue[str] = asyncio.Queue() + self.update_condition: asyncio.Condition = asyncio.Condition() + + @workflow.run + async def run( + self, customer_service_state: CustomerServiceWorkflowState | None = None + ): + while True: + with trace("Customer service", group_id=workflow.info().workflow_id): + user_input = await self.user_input_queue.get() + self.input_items.append({"content": user_input, "role": "user"}) + result = await Runner.run( + self.current_agent, + self.input_items, + context=self.context, + run_config=self.run_config, + ) + self.printed_history.append(f"Enter your message: {user_input}") + for new_item in result.new_items: + agent_name = new_item.agent.name + if isinstance(new_item, MessageOutputItem): + self.printed_history.append( + f"{agent_name}: {ItemHelpers.text_message_output(new_item)}" + ) + elif isinstance(new_item, HandoffOutputItem): + self.printed_history.append( + f"Handed off from {new_item.source_agent.name} to {new_item.target_agent.name}" + ) + elif isinstance(new_item, HandoffCallItem): + self.printed_history.append( + f"{agent_name}: Handed off to tool {new_item.raw_item.name}" + ) + elif isinstance(new_item, ToolCallItem): + self.printed_history.append(f"{agent_name}: Calling a tool") + elif isinstance(new_item, ToolCallOutputItem): + self.printed_history.append( + f"{agent_name}: Tool call output: {new_item.output}" + ) + else: + self.printed_history.append( + f"{agent_name}: Skipping item: {new_item.__class__.__name__}" + ) + self.input_items = result.to_input_list() + self.current_agent = result.last_agent + async with self.update_condition: + self.update_condition.notify_all() + + if workflow.info().is_continue_as_new_suggested(): + await workflow.wait_condition( + lambda: workflow.all_handlers_finished(), + timeout=timedelta(seconds=10), + timeout_summary="Continue as new timeout - deadlock avoidance", + ) + + # Convert input_items to plain dictionaries for serialization + serializable_input_items = [] + for item in self.input_items: + if hasattr(item, "model_dump"): + # Convert Pydantic objects to dictionaries + serializable_input_items.append(item.model_dump()) + else: + # Already a plain Python object + serializable_input_items.append(item) + workflow.continue_as_new( + CustomerServiceWorkflowState( + printed_history=self.printed_history, + current_agent_name=self.current_agent.name, + context=self.context, + input_items=serializable_input_items, + ) + ) @workflow.query def get_chat_history(self) -> list[str]: - return self.chat_history + return self.printed_history @workflow.update async def process_user_message(self, input: ProcessUserMessageInput) -> list[str]: - length = len(self.chat_history) - self.chat_history.append(f"User: {input.user_input}") - with trace("Customer service", group_id=workflow.info().workflow_id): - self.input_items.append({"content": input.user_input, "role": "user"}) - result = await Runner.run( - self.current_agent, - self.input_items, - context=self.context, - run_config=self.run_config, + length = len(self.printed_history) + self.user_input_queue.put_nowait(input.user_input) + async with self.update_condition: + await self.update_condition.wait_for( + lambda: len(self.printed_history) > length ) - - for new_item in result.new_items: - agent_name = new_item.agent.name - if isinstance(new_item, MessageOutputItem): - self.chat_history.append( - f"{agent_name}: {ItemHelpers.text_message_output(new_item)}" - ) - elif isinstance(new_item, HandoffOutputItem): - self.chat_history.append( - f"Handed off from {new_item.source_agent.name} to {new_item.target_agent.name}" - ) - elif isinstance(new_item, ToolCallItem): - self.chat_history.append(f"{agent_name}: Calling a tool") - elif isinstance(new_item, ToolCallOutputItem): - self.chat_history.append( - f"{agent_name}: Tool call output: {new_item.output}" - ) - else: - self.chat_history.append( - f"{agent_name}: Skipping item: {new_item.__class__.__name__}" - ) - self.input_items = result.to_input_list() - self.current_agent = result.last_agent - workflow.set_current_details("\n\n".join(self.chat_history)) - return self.chat_history[length:] + return self.printed_history[length:] @process_user_message.validator def validate_process_user_message(self, input: ProcessUserMessageInput) -> None: @@ -87,5 +154,5 @@ def validate_process_user_message(self, input: ProcessUserMessageInput) -> None: raise ValueError("User input cannot be empty.") if len(input.user_input) > 1000: raise ValueError("User input is too long. Please limit to 1000 characters.") - if input.chat_length != len(self.chat_history): + if input.chat_length != len(self.printed_history): raise ValueError("Stale chat history. Please refresh the chat.") diff --git a/openai_agents/workflows/research_agents/research_manager.py b/openai_agents/workflows/research_agents/research_manager.py index 19bdd224..356da1d7 100644 --- a/openai_agents/workflows/research_agents/research_manager.py +++ b/openai_agents/workflows/research_agents/research_manager.py @@ -6,7 +6,7 @@ with workflow.unsafe.imports_passed_through(): # TODO: Restore progress updates - from agents import RunConfig, Runner, custom_span, gen_trace_id, trace + from agents import RunConfig, Runner, custom_span, trace from openai_agents.workflows.research_agents.planner_agent import ( WebSearchItem, @@ -28,8 +28,7 @@ def __init__(self): self.writer_agent = new_writer_agent() async def run(self, query: str) -> str: - trace_id = gen_trace_id() - with trace("Research trace", trace_id=trace_id): + with trace("Research trace"): search_plan = await self._plan_searches(query) search_results = await self._perform_searches(search_plan) report = await self._write_report(query, search_results) diff --git a/openai_agents/workflows/tools_workflow.py b/openai_agents/workflows/tools_workflow.py index c9f80e9f..e8b28f75 100644 --- a/openai_agents/workflows/tools_workflow.py +++ b/openai_agents/workflows/tools_workflow.py @@ -2,7 +2,7 @@ from datetime import timedelta -from agents import Agent, Runner +from agents import Agent, Runner, trace from temporalio import workflow from temporalio.contrib import openai_agents as temporal_agents @@ -13,15 +13,16 @@ class ToolsWorkflow: @workflow.run async def run(self, question: str) -> str: - agent = Agent( - name="Hello world", - instructions="You are a helpful agent.", - tools=[ - temporal_agents.workflow.activity_as_tool( - get_weather, start_to_close_timeout=timedelta(seconds=10) - ) - ], - ) + with trace("Activity as tool"): + agent = Agent( + name="Hello world", + instructions="You are a helpful agent.", + tools=[ + temporal_agents.workflow.activity_as_tool( + get_weather, start_to_close_timeout=timedelta(seconds=10) + ) + ], + ) - result = await Runner.run(agent, input=question) - return result.final_output + result = await Runner.run(agent, input=question) + return result.final_output