Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 70 additions & 19 deletions tests/unittests/agents/test_parallel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ class _TestingAgent(BaseAgent):
delay: float = 0
"""The delay before the agent generates an event."""

@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
await asyncio.sleep(self.delay)
yield Event(
def event(self, ctx: InvocationContext):
return Event(
author=self.name,
branch=ctx.branch,
invocation_id=ctx.invocation_id,
Expand All @@ -47,6 +43,13 @@ async def _run_async_impl(
),
)

@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
await asyncio.sleep(self.delay)
yield self.event(ctx)


async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
Expand Down Expand Up @@ -137,26 +140,19 @@ async def test_run_async_branches(request: pytest.FixtureRequest):
assert events[2].branch != events[0].branch


class _TestingAgentWithMultipleEvents(BaseAgent):
class _TestingAgentWithMultipleEvents(_TestingAgent):
"""Mock agent for testing."""

@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
for _ in range(0, 3):
event = Event(
author=self.name,
branch=ctx.branch,
invocation_id=ctx.invocation_id,
content=types.Content(
parts=[types.Part(text=f'Hello, async {self.name}!')]
),
)
yield event
# Check that the event was processed by the consumer.
assert event.custom_metadata is not None
assert event.custom_metadata['processed']
event = self.event(ctx)
yield event
# Check that the event was processed by the consumer.
assert event.custom_metadata is not None
assert event.custom_metadata['processed']


@pytest.mark.asyncio
Expand Down Expand Up @@ -186,3 +182,58 @@ async def test_generating_one_event_per_agent_at_once(
async for event in agen:
event.custom_metadata = {'processed': True}
# Asserts on event are done in _TestingAgentWithMultipleEvents.


class _TestingAgentWithException(_TestingAgent):
"""Mock agent for testing."""

@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
yield self.event(ctx)
raise Exception()


class _TestingAgentInfiniteEvents(_TestingAgent):
"""Mock agent for testing."""

@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
while True:
yield self.event(ctx)


@pytest.mark.asyncio
async def test_stop_agent_if_sub_agent_fails(
request: pytest.FixtureRequest,
):
# This test is to verify that the parallel agent and subagents will all stop
# processing and throw exception to top level runner in case of exception.
agent1 = _TestingAgentWithException(
name=f'{request.function.__name__}_test_agent_1'
)
agent2 = _TestingAgentInfiniteEvents(
name=f'{request.function.__name__}_test_agent_2'
)
parallel_agent = ParallelAgent(
name=f'{request.function.__name__}_test_parallel_agent',
sub_agents=[
agent1,
agent2,
],
)
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, parallel_agent
)

agen = parallel_agent.run_async(parent_ctx)
# We expect to receive an exception from one of subagents.
# The exception should be propagated to root agent and other subagents.
# Otherwise we'll have an infinite loop.
with pytest.raises(Exception):
async for _ in agen:
# The infinite agent could iterate a few times depending on scheduling.
pass