|
2 | 2 |
|
3 | 3 | import asyncio
|
4 | 4 | import dataclasses
|
| 5 | +import inspect |
| 6 | +from asyncio import Task |
5 | 7 | from collections import defaultdict, deque
|
6 | 8 | from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
|
7 | 9 | from contextlib import asynccontextmanager, contextmanager
|
@@ -740,7 +742,6 @@ async def process_function_tools( # noqa: C901
|
740 | 742 | deferred_tool_results: dict[str, DeferredToolResult] = {}
|
741 | 743 | if build_run_context(ctx).tool_call_approved and ctx.deps.tool_call_results is not None:
|
742 | 744 | deferred_tool_results = ctx.deps.tool_call_results
|
743 |
| - |
744 | 745 | # Deferred tool calls are "run" as well, by reading their value from the tool call results
|
745 | 746 | calls_to_run.extend(tool_calls_by_kind['external'])
|
746 | 747 | calls_to_run.extend(tool_calls_by_kind['unapproved'])
|
@@ -819,47 +820,65 @@ async def _call_tools(
|
819 | 820 | for call in tool_calls:
|
820 | 821 | yield _messages.FunctionToolCallEvent(call)
|
821 | 822 |
|
822 |
| - # Run all tool tasks in parallel |
823 | 823 | with tracer.start_as_current_span(
|
824 | 824 | 'running tools',
|
825 | 825 | attributes={
|
826 | 826 | 'tools': [call.tool_name for call in tool_calls],
|
827 | 827 | 'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
|
828 | 828 | },
|
829 | 829 | ):
|
830 |
| - tasks = [ |
831 |
| - asyncio.create_task( |
832 |
| - _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits), |
833 |
| - name=call.tool_name, |
834 |
| - ) |
835 |
| - for call in tool_calls |
836 |
| - ] |
837 |
| - |
838 |
| - pending = tasks |
839 |
| - while pending: |
840 |
| - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) |
841 |
| - for task in done: |
842 |
| - index = tasks.index(task) |
843 |
| - try: |
844 |
| - tool_part, tool_user_part = task.result() |
845 |
| - except exceptions.CallDeferred: |
846 |
| - deferred_calls_by_index[index] = 'external' |
847 |
| - except exceptions.ApprovalRequired: |
848 |
| - deferred_calls_by_index[index] = 'unapproved' |
849 |
| - else: |
850 |
| - yield _messages.FunctionToolResultEvent(tool_part) |
851 | 830 |
|
852 |
| - tool_parts_by_index[index] = tool_part |
853 |
| - if tool_user_part: |
854 |
| - user_parts_by_index[index] = tool_user_part |
| 831 | + async def handle_call_or_result( |
| 832 | + coro_or_task: Awaitable[ |
| 833 | + tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None] |
| 834 | + ] |
| 835 | + | Task[tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]], |
| 836 | + index: int, |
| 837 | + ) -> _messages.HandleResponseEvent | None: |
| 838 | + try: |
| 839 | + tool_part, tool_user_part = ( |
| 840 | + (await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result() |
| 841 | + ) |
| 842 | + except exceptions.CallDeferred: |
| 843 | + deferred_calls_by_index[index] = 'external' |
| 844 | + except exceptions.ApprovalRequired: |
| 845 | + deferred_calls_by_index[index] = 'unapproved' |
| 846 | + else: |
| 847 | + tool_parts_by_index[index] = tool_part |
| 848 | + if tool_user_part: |
| 849 | + user_parts_by_index[index] = tool_user_part |
| 850 | + |
| 851 | + return _messages.FunctionToolResultEvent(tool_part) |
| 852 | + |
| 853 | + if tool_manager.should_call_sequentially(tool_calls): |
| 854 | + for index, call in enumerate(tool_calls): |
| 855 | + if event := await handle_call_or_result( |
| 856 | + _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits), |
| 857 | + index, |
| 858 | + ): |
| 859 | + yield event |
| 860 | + |
| 861 | + else: |
| 862 | + tasks = [ |
| 863 | + asyncio.create_task( |
| 864 | + _call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits), |
| 865 | + name=call.tool_name, |
| 866 | + ) |
| 867 | + for call in tool_calls |
| 868 | + ] |
| 869 | + |
| 870 | + pending = tasks |
| 871 | + while pending: |
| 872 | + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) |
| 873 | + for task in done: |
| 874 | + index = tasks.index(task) |
| 875 | + if event := await handle_call_or_result(coro_or_task=task, index=index): |
| 876 | + yield event |
855 | 877 |
|
856 | 878 | # We append the results at the end, rather than as they are received, to retain a consistent ordering
|
857 | 879 | # This is mostly just to simplify testing
|
858 |
| - for k in sorted(tool_parts_by_index): |
859 |
| - output_parts.append(tool_parts_by_index[k]) |
860 |
| - |
861 |
| - for k in sorted(user_parts_by_index): |
862 |
| - output_parts.append(user_parts_by_index[k]) |
| 880 | + output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)]) |
| 881 | + output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)]) |
863 | 882 |
|
864 | 883 | for k in sorted(deferred_calls_by_index):
|
865 | 884 | output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
|
|
0 commit comments