Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 109 additions & 31 deletions src/google/adk/agents/parallel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
from __future__ import annotations

import asyncio
from typing import Any
import sys
from typing import AsyncGenerator
from typing import ClassVar
from typing import Dict
from typing import Type

from typing_extensions import override

Expand Down Expand Up @@ -49,6 +47,70 @@ def _create_branch_ctx_for_sub_agent(
return invocation_context


# TODO - remove once Python <3.11 is no longer supported.
async def _merge_agent_run_pre_3_11(
agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
"""Merges the agent run event generator.
This version works in Python 3.9 and 3.10 and uses custom replacement for
asyncio.TaskGroup for tasks cancellation and exception handling.

This implementation guarantees for each agent, it won't move on until the
generated event is processed by upstream runner.

Args:
agent_runs: A list of async generators that yield events from each agent.

Yields:
Event: The next event from the merged generator.
"""
sentinel = object()
queue = asyncio.Queue()

def propagate_exceptions(tasks):
# Propagate exceptions and errors from tasks.
for task in tasks:
if task.done():
# Ignore the result (None) of correctly finished tasks and re-raise
# exceptions and errors.
task.result()

# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
async def process_an_agent(events_for_one_agent):
try:
async for event in events_for_one_agent:
resume_signal = asyncio.Event()
await queue.put((event, resume_signal))
# Wait for upstream to consume event before generating new events.
await resume_signal.wait()
finally:
# Mark agent as finished.
await queue.put((sentinel, None))

tasks = []
try:
for events_for_one_agent in agent_runs:
tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))

sentinel_count = 0
# Run until all agents finished processing.
while sentinel_count < len(agent_runs):
propagate_exceptions(tasks)
event, resume_signal = await queue.get()
# Agent finished processing.
if event is sentinel:
sentinel_count += 1
else:
yield event
# Signal to agent that event has been processed by runner and it can
# continue now.
resume_signal.set()
finally:
for task in tasks:
task.cancel()


async def _merge_agent_run(
agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
Expand All @@ -63,30 +125,37 @@ async def _merge_agent_run(
Yields:
Event: The next event from the merged generator.
"""
tasks = [
asyncio.create_task(events_for_one_agent.__anext__())
for events_for_one_agent in agent_runs
]
pending_tasks = set(tasks)

while pending_tasks:
done, pending_tasks = await asyncio.wait(
pending_tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in done:
try:
yield task.result()

# Find the generator that produced this event and move it on.
for i, original_task in enumerate(tasks):
if task == original_task:
new_task = asyncio.create_task(agent_runs[i].__anext__())
tasks[i] = new_task
pending_tasks.add(new_task)
break # stop iterating once found

except StopAsyncIteration:
continue
sentinel = object()
queue = asyncio.Queue()

# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
async def process_an_agent(events_for_one_agent):
try:
async for event in events_for_one_agent:
resume_signal = asyncio.Event()
await queue.put((event, resume_signal))
# Wait for upstream to consume event before generating new events.
await resume_signal.wait()
finally:
# Mark agent as finished.
await queue.put((sentinel, None))

async with asyncio.TaskGroup() as tg:
for events_for_one_agent in agent_runs:
tg.create_task(process_an_agent(events_for_one_agent))

sentinel_count = 0
# Run until all agents finished processing.
while sentinel_count < len(agent_runs):
event, resume_signal = await queue.get()
# Agent finished processing.
if event is sentinel:
sentinel_count += 1
else:
yield event
# Signal to agent that it should generate next event.
resume_signal.set()


class ParallelAgent(BaseAgent):
Expand All @@ -112,10 +181,19 @@ async def _run_async_impl(
)
for sub_agent in self.sub_agents
]

async with Aclosing(_merge_agent_run(agent_runs)) as agen:
async for event in agen:
yield event
try:
# TODO remove if once Python <3.11 is no longer supported.
if sys.version_info >= (3, 11):
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
async for event in agen:
yield event
else:
async with Aclosing(_merge_agent_run_pre_3_11(agent_runs)) as agen:
async for event in agen:
yield event
finally:
for sub_agent_run in agent_runs:
await sub_agent_run.aclose()

@override
async def _run_live_impl(
Expand Down