17
17
from __future__ import annotations
18
18
19
19
import asyncio
20
- from typing import Any
20
+ import sys
21
21
from typing import AsyncGenerator
22
22
from typing import ClassVar
23
- from typing import Dict
24
- from typing import Type
25
23
26
24
from typing_extensions import override
27
25
@@ -49,6 +47,70 @@ def _create_branch_ctx_for_sub_agent(
49
47
return invocation_context
50
48
51
49
50
+ # TODO - remove once Python <3.11 is no longer supported.
51
+ async def _merge_agent_run_pre_3_11 (
52
+ agent_runs : list [AsyncGenerator [Event , None ]],
53
+ ) -> AsyncGenerator [Event , None ]:
54
+ """Merges the agent run event generator.
55
+ This version works in Python 3.9 and 3.10 and uses custom replacement for
56
+ asyncio.TaskGroup for tasks cancellation and exception handling.
57
+
58
+ This implementation guarantees for each agent, it won't move on until the
59
+ generated event is processed by upstream runner.
60
+
61
+ Args:
62
+ agent_runs: A list of async generators that yield events from each agent.
63
+
64
+ Yields:
65
+ Event: The next event from the merged generator.
66
+ """
67
+ sentinel = object ()
68
+ queue = asyncio .Queue ()
69
+
70
+ def propagate_exceptions (tasks ):
71
+ # Propagate exceptions and errors from tasks.
72
+ for task in tasks :
73
+ if task .done ():
74
+ # Ignore the result (None) of correctly finished tasks and re-raise
75
+ # exceptions and errors.
76
+ task .result ()
77
+
78
+ # Agents are processed in parallel.
79
+ # Events for each agent are put on queue sequentially.
80
+ async def process_an_agent (events_for_one_agent ):
81
+ try :
82
+ async for event in events_for_one_agent :
83
+ resume_signal = asyncio .Event ()
84
+ await queue .put ((event , resume_signal ))
85
+ # Wait for upstream to consume event before generating new events.
86
+ await resume_signal .wait ()
87
+ finally :
88
+ # Mark agent as finished.
89
+ await queue .put ((sentinel , None ))
90
+
91
+ tasks = []
92
+ try :
93
+ for events_for_one_agent in agent_runs :
94
+ tasks .append (asyncio .create_task (process_an_agent (events_for_one_agent )))
95
+
96
+ sentinel_count = 0
97
+ # Run until all agents finished processing.
98
+ while sentinel_count < len (agent_runs ):
99
+ propagate_exceptions (tasks )
100
+ event , resume_signal = await queue .get ()
101
+ # Agent finished processing.
102
+ if event is sentinel :
103
+ sentinel_count += 1
104
+ else :
105
+ yield event
106
+ # Signal to agent that event has been processed by runner and it can
107
+ # continue now.
108
+ resume_signal .set ()
109
+ finally :
110
+ for task in tasks :
111
+ task .cancel ()
112
+
113
+
52
114
async def _merge_agent_run (
53
115
agent_runs : list [AsyncGenerator [Event , None ]],
54
116
) -> AsyncGenerator [Event , None ]:
@@ -63,30 +125,37 @@ async def _merge_agent_run(
63
125
Yields:
64
126
Event: The next event from the merged generator.
65
127
"""
66
- tasks = [
67
- asyncio .create_task (events_for_one_agent .__anext__ ())
68
- for events_for_one_agent in agent_runs
69
- ]
70
- pending_tasks = set (tasks )
71
-
72
- while pending_tasks :
73
- done , pending_tasks = await asyncio .wait (
74
- pending_tasks , return_when = asyncio .FIRST_COMPLETED
75
- )
76
- for task in done :
77
- try :
78
- yield task .result ()
79
-
80
- # Find the generator that produced this event and move it on.
81
- for i , original_task in enumerate (tasks ):
82
- if task == original_task :
83
- new_task = asyncio .create_task (agent_runs [i ].__anext__ ())
84
- tasks [i ] = new_task
85
- pending_tasks .add (new_task )
86
- break # stop iterating once found
87
-
88
- except StopAsyncIteration :
89
- continue
128
+ sentinel = object ()
129
+ queue = asyncio .Queue ()
130
+
131
+ # Agents are processed in parallel.
132
+ # Events for each agent are put on queue sequentially.
133
+ async def process_an_agent (events_for_one_agent ):
134
+ try :
135
+ async for event in events_for_one_agent :
136
+ resume_signal = asyncio .Event ()
137
+ await queue .put ((event , resume_signal ))
138
+ # Wait for upstream to consume event before generating new events.
139
+ await resume_signal .wait ()
140
+ finally :
141
+ # Mark agent as finished.
142
+ await queue .put ((sentinel , None ))
143
+
144
+ async with asyncio .TaskGroup () as tg :
145
+ for events_for_one_agent in agent_runs :
146
+ tg .create_task (process_an_agent (events_for_one_agent ))
147
+
148
+ sentinel_count = 0
149
+ # Run until all agents finished processing.
150
+ while sentinel_count < len (agent_runs ):
151
+ event , resume_signal = await queue .get ()
152
+ # Agent finished processing.
153
+ if event is sentinel :
154
+ sentinel_count += 1
155
+ else :
156
+ yield event
157
+ # Signal to agent that it should generate next event.
158
+ resume_signal .set ()
90
159
91
160
92
161
class ParallelAgent (BaseAgent ):
@@ -112,10 +181,19 @@ async def _run_async_impl(
112
181
)
113
182
for sub_agent in self .sub_agents
114
183
]
115
-
116
- async with Aclosing (_merge_agent_run (agent_runs )) as agen :
117
- async for event in agen :
118
- yield event
184
+ try :
185
+ # TODO remove if once Python <3.11 is no longer supported.
186
+ if sys .version_info >= (3 , 11 ):
187
+ async with Aclosing (_merge_agent_run (agent_runs )) as agen :
188
+ async for event in agen :
189
+ yield event
190
+ else :
191
+ async with Aclosing (_merge_agent_run_pre_3_11 (agent_runs )) as agen :
192
+ async for event in agen :
193
+ yield event
194
+ finally :
195
+ for sub_agent_run in agent_runs :
196
+ await sub_agent_run .aclose ()
119
197
120
198
@override
121
199
async def _run_live_impl (
0 commit comments