Skip to content

Commit e5afab1

Browse files
jawoszek and copybara-github
authored and committed
fix: rework parallel_agent.py to always aclose async generators
See #1670 (comment) PiperOrigin-RevId: 795173267
1 parent 81a53b5 commit e5afab1

File tree

2 files changed

+217
-37
lines changed

2 files changed

+217
-37
lines changed

src/google/adk/agents/parallel_agent.py

Lines changed: 109 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717
from __future__ import annotations
1818

1919
import asyncio
20-
from typing import Any
20+
import sys
2121
from typing import AsyncGenerator
2222
from typing import ClassVar
23-
from typing import Dict
24-
from typing import Type
2523

2624
from typing_extensions import override
2725

@@ -49,6 +47,70 @@ def _create_branch_ctx_for_sub_agent(
4947
return invocation_context
5048

5149

50+
# TODO - remove once Python <3.11 is no longer supported.
51+
async def _merge_agent_run_pre_3_11(
    agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
  """Merges the agent run event generators.

  This version works in Python 3.9 and 3.10 and uses a custom replacement for
  asyncio.TaskGroup for task cancellation and exception handling.

  Each agent is driven by its own task. After an agent produces an event, its
  task blocks until the consumer of this generator has finished processing
  that event, so an agent never runs ahead of the upstream runner.

  Args:
    agent_runs: A list of async generators that yield events from each agent.

  Yields:
    Event: The next event from the merged generator.

  Raises:
    Exception: Whatever exception a sub-agent generator raised. It is
      re-raised here, including when the failing agent is the last one to
      finish.
  """
  sentinel = object()
  queue = asyncio.Queue()

  def propagate_exceptions(tasks):
    # Propagate exceptions and errors from tasks.
    for task in tasks:
      if task.done():
        # Ignore the result (None) of correctly finished tasks and re-raise
        # exceptions and errors.
        task.result()

  # Agents are processed in parallel.
  # Events for each agent are put on queue sequentially.
  async def process_an_agent(events_for_one_agent):
    try:
      async for event in events_for_one_agent:
        resume_signal = asyncio.Event()
        await queue.put((event, resume_signal))
        # Wait for upstream to consume event before generating new events.
        await resume_signal.wait()
    finally:
      # Mark agent as finished. The queue is unbounded, so this put never
      # suspends and the sentinel is posted even on failure or cancellation.
      await queue.put((sentinel, None))

  tasks = []
  try:
    for events_for_one_agent in agent_runs:
      tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))

    sentinel_count = 0
    # Run until all agents finished processing.
    while sentinel_count < len(agent_runs):
      propagate_exceptions(tasks)
      event, resume_signal = await queue.get()
      # Agent finished processing.
      if event is sentinel:
        sentinel_count += 1
      else:
        yield event
        # Signal to agent that event has been processed by runner and it can
        # continue now.
        resume_signal.set()

    # The loop exits as soon as the last sentinel arrives, which skips the
    # check at the top of the loop. Check once more so that a failure in the
    # agent that finished last is not silently dropped.
    propagate_exceptions(tasks)
  finally:
    for task in tasks:
      task.cancel()
    if tasks:
      # Wait until every task has actually stopped, as asyncio.TaskGroup does
      # in the 3.11+ variant. Otherwise a task could still be executing a
      # sub-agent generator when the caller tries to close it.
      # return_exceptions=True also retrieves the tasks' exceptions, so the
      # error being propagated (if any) is not masked.
      await asyncio.gather(*tasks, return_exceptions=True)
112+
113+
52114
async def _merge_agent_run(
53115
agent_runs: list[AsyncGenerator[Event, None]],
54116
) -> AsyncGenerator[Event, None]:
@@ -63,30 +125,37 @@ async def _merge_agent_run(
63125
Yields:
64126
Event: The next event from the merged generator.
65127
"""
66-
tasks = [
67-
asyncio.create_task(events_for_one_agent.__anext__())
68-
for events_for_one_agent in agent_runs
69-
]
70-
pending_tasks = set(tasks)
71-
72-
while pending_tasks:
73-
done, pending_tasks = await asyncio.wait(
74-
pending_tasks, return_when=asyncio.FIRST_COMPLETED
75-
)
76-
for task in done:
77-
try:
78-
yield task.result()
79-
80-
# Find the generator that produced this event and move it on.
81-
for i, original_task in enumerate(tasks):
82-
if task == original_task:
83-
new_task = asyncio.create_task(agent_runs[i].__anext__())
84-
tasks[i] = new_task
85-
pending_tasks.add(new_task)
86-
break # stop iterating once found
87-
88-
except StopAsyncIteration:
89-
continue
128+
sentinel = object()
129+
queue = asyncio.Queue()
130+
131+
# Agents are processed in parallel.
132+
# Events for each agent are put on queue sequentially.
133+
async def process_an_agent(events_for_one_agent):
134+
try:
135+
async for event in events_for_one_agent:
136+
resume_signal = asyncio.Event()
137+
await queue.put((event, resume_signal))
138+
# Wait for upstream to consume event before generating new events.
139+
await resume_signal.wait()
140+
finally:
141+
# Mark agent as finished.
142+
await queue.put((sentinel, None))
143+
144+
async with asyncio.TaskGroup() as tg:
145+
for events_for_one_agent in agent_runs:
146+
tg.create_task(process_an_agent(events_for_one_agent))
147+
148+
sentinel_count = 0
149+
# Run until all agents finished processing.
150+
while sentinel_count < len(agent_runs):
151+
event, resume_signal = await queue.get()
152+
# Agent finished processing.
153+
if event is sentinel:
154+
sentinel_count += 1
155+
else:
156+
yield event
157+
# Signal to agent that it should generate next event.
158+
resume_signal.set()
90159

91160

92161
class ParallelAgent(BaseAgent):
@@ -112,10 +181,19 @@ async def _run_async_impl(
112181
)
113182
for sub_agent in self.sub_agents
114183
]
115-
116-
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
117-
async for event in agen:
118-
yield event
184+
try:
185+
# TODO remove if once Python <3.11 is no longer supported.
186+
if sys.version_info >= (3, 11):
187+
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
188+
async for event in agen:
189+
yield event
190+
else:
191+
async with Aclosing(_merge_agent_run_pre_3_11(agent_runs)) as agen:
192+
async for event in agen:
193+
yield event
194+
finally:
195+
for sub_agent_run in agent_runs:
196+
await sub_agent_run.aclose()
119197

120198
@override
121199
async def _run_live_impl(

tests/unittests/agents/test_parallel_agent.py

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,8 @@ class _TestingAgent(BaseAgent):
3333
delay: float = 0
3434
"""The delay before the agent generates an event."""
3535

36-
@override
37-
async def _run_async_impl(
38-
self, ctx: InvocationContext
39-
) -> AsyncGenerator[Event, None]:
40-
await asyncio.sleep(self.delay)
41-
yield Event(
36+
def event(self, ctx: InvocationContext):
37+
return Event(
4238
author=self.name,
4339
branch=ctx.branch,
4440
invocation_id=ctx.invocation_id,
@@ -47,6 +43,13 @@ async def _run_async_impl(
4743
),
4844
)
4945

46+
@override
async def _run_async_impl(
    self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
  """Waits for `self.delay` seconds, then yields one event from `self.event`."""
  pause = self.delay
  await asyncio.sleep(pause)
  single_event = self.event(ctx)
  yield single_event
52+
5053

5154
async def _create_parent_invocation_context(
5255
test_name: str, agent: BaseAgent
@@ -135,3 +138,102 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
135138
# Sub-agents should have different branches.
136139
assert events[2].branch != events[1].branch
137140
assert events[2].branch != events[0].branch
141+
142+
143+
class _TestingAgentWithMultipleEvents(_TestingAgent):
  """Mock agent that emits three events and checks each one was consumed."""

  @override
  async def _run_async_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    """Yields three events, asserting after each that it was marked processed."""
    for _ in range(3):
      produced = self.event(ctx)
      yield produced
      # Execution resumes here only after the consumer asked for the next
      # item; by then it must have tagged the event it was given.
      assert produced.custom_metadata is not None
      assert produced.custom_metadata['processed']
156+
157+
158+
@pytest.mark.asyncio
async def test_generating_one_event_per_agent_at_once(
    request: pytest.FixtureRequest,
):
  """Verifies the parallel agent emits at most one pending event per agent.

  Each sub-agent asserts, after every yield, that the event it just produced
  was marked as processed by the consumer loop below.
  """
  test_name = request.function.__name__
  first_agent = _TestingAgentWithMultipleEvents(
      name=f'{test_name}_test_agent_1'
  )
  second_agent = _TestingAgentWithMultipleEvents(
      name=f'{test_name}_test_agent_2'
  )
  parallel_agent = ParallelAgent(
      name=f'{test_name}_test_parallel_agent',
      sub_agents=[first_agent, second_agent],
  )
  parent_ctx = await _create_parent_invocation_context(
      test_name, parallel_agent
  )

  # Asserts on the events are done in _TestingAgentWithMultipleEvents.
  async for received in parallel_agent.run_async(parent_ctx):
    received.custom_metadata = {'processed': True}
185+
186+
187+
class _TestingAgentWithException(_TestingAgent):
  """Mock agent that emits a single event and then fails."""

  @override
  async def _run_async_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    """Yields one event, then raises a bare `Exception`."""
    only_event = self.event(ctx)
    yield only_event
    raise Exception()
196+
197+
198+
class _TestingAgentInfiniteEvents(_TestingAgent):
  """Mock agent that never stops emitting events on its own."""

  @override
  async def _run_async_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    """Yields a fresh event forever; ends only if stopped from outside."""
    while True:
      next_event = self.event(ctx)
      yield next_event
207+
208+
209+
@pytest.mark.asyncio
async def test_stop_agent_if_sub_agent_fails(
    request: pytest.FixtureRequest,
):
  """Verifies a sub-agent failure stops the parallel agent and its siblings.

  One sub-agent raises after its first event while the other would emit
  events forever. The exception has to reach the top-level consumer and stop
  the other sub-agent; otherwise this test would loop forever.
  """
  test_name = request.function.__name__
  failing_agent = _TestingAgentWithException(
      name=f'{test_name}_test_agent_1'
  )
  endless_agent = _TestingAgentInfiniteEvents(
      name=f'{test_name}_test_agent_2'
  )
  parallel_agent = ParallelAgent(
      name=f'{test_name}_test_parallel_agent',
      sub_agents=[failing_agent, endless_agent],
  )
  parent_ctx = await _create_parent_invocation_context(
      test_name, parallel_agent
  )

  event_stream = parallel_agent.run_async(parent_ctx)
  # We expect to receive an exception from one of the sub-agents.
  with pytest.raises(Exception):
    async for _ in event_stream:
      # The infinite agent could iterate a few times depending on scheduling.
      pass

0 commit comments

Comments
 (0)