Skip to content

Commit eb956fd

Browse files
jawoszekcopybara-github
authored andcommitted
chore: add test for parallel agent to verify correct ordering of agents
PiperOrigin-RevId: 796949367
1 parent 018db79 commit eb956fd

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

tests/unittests/agents/test_parallel_agent.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,63 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
135135
# Sub-agents should have different branches.
136136
assert events[2].branch != events[1].branch
137137
assert events[2].branch != events[0].branch
138+
139+
140+
class _TestingAgentWithMultipleEvents(BaseAgent):
141+
"""Mock agent for testing."""
142+
143+
processed_events: int = 0
144+
145+
@override
146+
async def _run_async_impl(
147+
self, ctx: InvocationContext
148+
) -> AsyncGenerator[Event, None]:
149+
event = Event(
150+
author=self.name,
151+
branch=ctx.branch,
152+
invocation_id=ctx.invocation_id,
153+
content=types.Content(
154+
parts=[types.Part(text=f'Hello, async {self.name}!')]
155+
),
156+
)
157+
for _ in range(0, 3):
158+
yield event
159+
self.processed_events += 1
160+
161+
162+
@pytest.mark.asyncio
163+
async def test_generating_one_event_per_agent_at_once(
164+
request: pytest.FixtureRequest,
165+
):
166+
# This test is to verify that the parallel agent won't generate more than one
167+
# event per agent at a time.
168+
agent1 = _TestingAgentWithMultipleEvents(
169+
name=f'{request.function.__name__}_test_agent_1'
170+
)
171+
agent2 = _TestingAgentWithMultipleEvents(
172+
name=f'{request.function.__name__}_test_agent_2'
173+
)
174+
parallel_agent = ParallelAgent(
175+
name=f'{request.function.__name__}_test_parallel_agent',
176+
sub_agents=[
177+
agent1,
178+
agent2,
179+
],
180+
)
181+
parent_ctx = await _create_parent_invocation_context(
182+
request.function.__name__, parallel_agent
183+
)
184+
185+
agen = parallel_agent.run_async(parent_ctx)
186+
# No event is generated yet.
187+
assert agent1.processed_events + agent2.processed_events == 0
188+
current_iter = 0
189+
async for _ in agen:
190+
processed_events = agent1.processed_events + agent2.processed_events
191+
# Depending on coroutine scheduling, agent increments the counter before
192+
# or after next iteration, so we have two possible/expected values of sum.
193+
assert (
194+
processed_events == current_iter or processed_events == current_iter - 1
195+
)
196+
# No more than one event is generated per agent after each iteration.
197+
current_iter += 1

0 commit comments

Comments
 (0)