Skip to content

Commit b3578c0

Browse files
jawoszekcopybara-github
authored andcommitted
chore: add test for parallel agent to verify correct handling of exceptions
PiperOrigin-RevId: 797309525
1 parent 81a53b5 commit b3578c0

File tree

1 file changed

+108
-6
lines changed

1 file changed

+108
-6
lines changed

tests/unittests/agents/test_parallel_agent.py

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,8 @@ class _TestingAgent(BaseAgent):
3333
delay: float = 0
3434
"""The delay before the agent generates an event."""
3535

36-
@override
37-
async def _run_async_impl(
38-
self, ctx: InvocationContext
39-
) -> AsyncGenerator[Event, None]:
40-
await asyncio.sleep(self.delay)
41-
yield Event(
36+
def event(self, ctx: InvocationContext):
37+
return Event(
4238
author=self.name,
4339
branch=ctx.branch,
4440
invocation_id=ctx.invocation_id,
@@ -47,6 +43,13 @@ async def _run_async_impl(
4743
),
4844
)
4945

46+
@override
47+
async def _run_async_impl(
48+
self, ctx: InvocationContext
49+
) -> AsyncGenerator[Event, None]:
50+
await asyncio.sleep(self.delay)
51+
yield self.event(ctx)
52+
5053

5154
async def _create_parent_invocation_context(
5255
test_name: str, agent: BaseAgent
@@ -135,3 +138,102 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
135138
# Sub-agents should have different branches.
136139
assert events[2].branch != events[1].branch
137140
assert events[2].branch != events[0].branch
141+
142+
143+
class _TestingAgentWithMultipleEvents(_TestingAgent):
144+
"""Mock agent for testing."""
145+
146+
@override
147+
async def _run_async_impl(
148+
self, ctx: InvocationContext
149+
) -> AsyncGenerator[Event, None]:
150+
for _ in range(0, 3):
151+
event = self.event(ctx)
152+
yield event
153+
# Check that the event was processed by the consumer.
154+
assert event.custom_metadata is not None
155+
assert event.custom_metadata['processed']
156+
157+
158+
@pytest.mark.asyncio
159+
async def test_generating_one_event_per_agent_at_once(
160+
request: pytest.FixtureRequest,
161+
):
162+
# This test is to verify that the parallel agent won't generate more than one
163+
# event per agent at a time.
164+
agent1 = _TestingAgentWithMultipleEvents(
165+
name=f'{request.function.__name__}_test_agent_1'
166+
)
167+
agent2 = _TestingAgentWithMultipleEvents(
168+
name=f'{request.function.__name__}_test_agent_2'
169+
)
170+
parallel_agent = ParallelAgent(
171+
name=f'{request.function.__name__}_test_parallel_agent',
172+
sub_agents=[
173+
agent1,
174+
agent2,
175+
],
176+
)
177+
parent_ctx = await _create_parent_invocation_context(
178+
request.function.__name__, parallel_agent
179+
)
180+
181+
agen = parallel_agent.run_async(parent_ctx)
182+
async for event in agen:
183+
event.custom_metadata = {'processed': True}
184+
# Asserts on event are done in _TestingAgentWithMultipleEvents.
185+
186+
187+
class _TestingAgentWithException(_TestingAgent):
188+
"""Mock agent for testing."""
189+
190+
@override
191+
async def _run_async_impl(
192+
self, ctx: InvocationContext
193+
) -> AsyncGenerator[Event, None]:
194+
yield self.event(ctx)
195+
raise Exception()
196+
197+
198+
class _TestingAgentInfiniteEvents(_TestingAgent):
199+
"""Mock agent for testing."""
200+
201+
@override
202+
async def _run_async_impl(
203+
self, ctx: InvocationContext
204+
) -> AsyncGenerator[Event, None]:
205+
while True:
206+
yield self.event(ctx)
207+
208+
209+
@pytest.mark.asyncio
210+
async def test_stop_agent_if_sub_agent_fails(
211+
request: pytest.FixtureRequest,
212+
):
213+
# This test is to verify that the parallel agent and subagents will all stop
214+
# processing and throw exception to top level runner in case of exception.
215+
agent1 = _TestingAgentWithException(
216+
name=f'{request.function.__name__}_test_agent_1'
217+
)
218+
agent2 = _TestingAgentInfiniteEvents(
219+
name=f'{request.function.__name__}_test_agent_2'
220+
)
221+
parallel_agent = ParallelAgent(
222+
name=f'{request.function.__name__}_test_parallel_agent',
223+
sub_agents=[
224+
agent1,
225+
agent2,
226+
],
227+
)
228+
parent_ctx = await _create_parent_invocation_context(
229+
request.function.__name__, parallel_agent
230+
)
231+
232+
agen = parallel_agent.run_async(parent_ctx)
233+
# We expect to receive an exception from one of subagents.
234+
# The exception should be propagated to root agent and other subagents.
235+
# Otherwise we'll have an infinite loop.
236+
with pytest.raises(Exception):
237+
async for _ in agen:
238+
# The infinite agent could iterate a few times depending on scheduling.
239+
pass

0 commit comments

Comments
 (0)