|
1 | 1 | """Tool execution functionality for the event loop."""
|
2 | 2 |
|
3 | 3 | import logging
|
| 4 | +import queue |
| 5 | +import threading |
4 | 6 | import time
|
5 |
| -from concurrent.futures import TimeoutError |
6 |
| -from typing import Any, Callable, List, Optional, Tuple |
| 7 | +from typing import Any, Callable, Generator, Optional, cast |
7 | 8 |
|
8 | 9 | from opentelemetry import trace
|
9 | 10 |
|
|
19 | 20 |
|
20 | 21 | def run_tools(
|
21 | 22 | handler: Callable[[ToolUse], ToolResult],
|
22 |
| - tool_uses: List[ToolUse], |
| 23 | + tool_uses: list[ToolUse], |
23 | 24 | event_loop_metrics: EventLoopMetrics,
|
24 |
| - request_state: Any, |
25 |
| - invalid_tool_use_ids: List[str], |
26 |
| - tool_results: List[ToolResult], |
| 25 | + invalid_tool_use_ids: list[str], |
| 26 | + tool_results: list[ToolResult], |
27 | 27 | cycle_trace: Trace,
|
28 | 28 | parent_span: Optional[trace.Span] = None,
|
29 | 29 | parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None,
|
30 |
| -) -> bool: |
| 30 | +) -> Generator[dict[str, Any], None, None]: |
31 | 31 | """Execute tools either in parallel or sequentially.
|
32 | 32 |
|
33 | 33 | Args:
|
34 | 34 | handler: Tool handler processing function.
|
35 | 35 | tool_uses: List of tool uses to execute.
|
36 | 36 | event_loop_metrics: Metrics collection object.
|
37 |
| - request_state: Current request state. |
38 | 37 | invalid_tool_use_ids: List of invalid tool use IDs.
|
39 | 38 | tool_results: List to populate with tool results.
|
40 | 39 | cycle_trace: Parent trace for the current cycle.
|
41 | 40 | parent_span: Parent span for the current cycle.
|
42 | 41 | parallel_tool_executor: Optional executor for parallel processing.
|
43 | 42 |
|
44 |
| - Returns: |
45 |
| - bool: True if any tool failed, False otherwise. |
| 43 | + Yields: |
| 44 | + Events of the tool invocations. Tool results are appended to `tool_results`. |
46 | 45 | """
|
47 | 46 |
|
48 |
| - def _handle_tool_execution(tool: ToolUse) -> Tuple[bool, Optional[ToolResult]]: |
49 |
| - result = None |
50 |
| - tool_succeeded = False |
51 |
| - |
| 47 | + def handle(tool: ToolUse) -> Generator[dict[str, Any], None, ToolResult]: |
52 | 48 | tracer = get_tracer()
|
53 | 49 | tool_call_span = tracer.start_tool_call_span(tool, parent_span)
|
54 | 50 |
|
55 |
| - try: |
56 |
| - if "toolUseId" not in tool or tool["toolUseId"] not in invalid_tool_use_ids: |
57 |
| - tool_name = tool["name"] |
58 |
| - tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) |
59 |
| - tool_start_time = time.time() |
60 |
| - result = handler(tool) |
61 |
| - tool_success = result.get("status") == "success" |
62 |
| - if tool_success: |
63 |
| - tool_succeeded = True |
64 |
| - |
65 |
| - tool_duration = time.time() - tool_start_time |
66 |
| - message = Message(role="user", content=[{"toolResult": result}]) |
67 |
| - event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) |
68 |
| - cycle_trace.add_child(tool_trace) |
69 |
| - |
70 |
| - if tool_call_span: |
71 |
| - tracer.end_tool_call_span(tool_call_span, result) |
72 |
| - except Exception as e: |
73 |
| - if tool_call_span: |
74 |
| - tracer.end_span_with_error(tool_call_span, str(e), e) |
75 |
| - |
76 |
| - return tool_succeeded, result |
77 |
| - |
78 |
| - any_tool_failed = False |
| 51 | + tool_name = tool["name"] |
| 52 | + tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) |
| 53 | + tool_start_time = time.time() |
| 54 | + |
| 55 | + result = handler(tool) |
| 56 | + yield {"result": result} # Placeholder until handler becomes a generator from which we can yield from |
| 57 | + |
| 58 | + tool_success = result.get("status") == "success" |
| 59 | + tool_duration = time.time() - tool_start_time |
| 60 | + message = Message(role="user", content=[{"toolResult": result}]) |
| 61 | + event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) |
| 62 | + cycle_trace.add_child(tool_trace) |
| 63 | + |
| 64 | + if tool_call_span: |
| 65 | + tracer.end_tool_call_span(tool_call_span, result) |
| 66 | + |
| 67 | + return result |
| 68 | + |
| 69 | + def work( |
| 70 | + tool: ToolUse, |
| 71 | + worker_id: int, |
| 72 | + worker_queue: queue.Queue, |
| 73 | + worker_event: threading.Event, |
| 74 | + ) -> ToolResult: |
| 75 | + events = handle(tool) |
| 76 | + |
| 77 | + while True: |
| 78 | + try: |
| 79 | + event = next(events) |
| 80 | + worker_queue.put((worker_id, event)) |
| 81 | + worker_event.wait() |
| 82 | + |
| 83 | + except StopIteration as stop: |
| 84 | + return cast(ToolResult, stop.value) |
| 85 | + |
| 86 | + tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] |
| 87 | + |
79 | 88 | if parallel_tool_executor:
|
80 | 89 | logger.debug(
|
81 | 90 | "tool_count=<%s>, tool_executor=<%s> | executing tools in parallel",
|
82 | 91 | len(tool_uses),
|
83 | 92 | type(parallel_tool_executor).__name__,
|
84 | 93 | )
|
85 |
| - # Submit all tasks with their associated tools |
86 |
| - future_to_tool = { |
87 |
| - parallel_tool_executor.submit(_handle_tool_execution, tool_use): tool_use for tool_use in tool_uses |
88 |
| - } |
| 94 | + |
| 95 | + worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue() |
| 96 | + worker_events = [threading.Event() for _ in range(len(tool_uses))] |
| 97 | + |
| 98 | + workers = [ |
| 99 | + parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id]) |
| 100 | + for worker_id, tool_use in enumerate(tool_uses) |
| 101 | + ] |
89 | 102 | logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))
|
90 | 103 |
|
91 |
| - # Collect results truly in parallel using the provided executor's as_completed method |
92 |
| - completed_results = [] |
93 |
| - try: |
94 |
| - for future in parallel_tool_executor.as_completed(future_to_tool): |
95 |
| - try: |
96 |
| - succeeded, result = future.result() |
97 |
| - if result is not None: |
98 |
| - completed_results.append(result) |
99 |
| - if not succeeded: |
100 |
| - any_tool_failed = True |
101 |
| - except Exception as e: |
102 |
| - tool = future_to_tool[future] |
103 |
| - logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], e) |
104 |
| - any_tool_failed = True |
105 |
| - except TimeoutError: |
106 |
| - logger.error("timeout_seconds=<%s> | parallel tool execution timed out", parallel_tool_executor.timeout) |
107 |
| - # Process any completed tasks |
108 |
| - for future in future_to_tool: |
109 |
| - if future.done(): # type: ignore |
110 |
| - try: |
111 |
| - succeeded, result = future.result(timeout=0) |
112 |
| - if result is not None: |
113 |
| - completed_results.append(result) |
114 |
| - except Exception as tool_e: |
115 |
| - tool = future_to_tool[future] |
116 |
| - logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], tool_e) |
117 |
| - else: |
118 |
| - # This future didn't complete within the timeout |
119 |
| - tool = future_to_tool[future] |
120 |
| - logger.debug("tool_name=<%s> | tool execution timed out", tool["name"]) |
121 |
| - |
122 |
| - any_tool_failed = True |
123 |
| - |
124 |
| - # Add completed results to tool_results |
125 |
| - tool_results.extend(completed_results) |
| 104 | + while not all(worker.done() for worker in workers): |
| 105 | + if not worker_queue.empty(): |
| 106 | + worker_id, event = worker_queue.get() |
| 107 | + yield event |
| 108 | + worker_events[worker_id].set() |
| 109 | + |
| 110 | + tool_results.extend([worker.result() for worker in workers]) |
| 111 | + |
126 | 112 | else:
|
127 | 113 | # Sequential execution fallback
|
128 | 114 | for tool_use in tool_uses:
|
129 |
| - succeeded, result = _handle_tool_execution(tool_use) |
130 |
| - if result is not None: |
131 |
| - tool_results.append(result) |
132 |
| - if not succeeded: |
133 |
| - any_tool_failed = True |
134 |
| - |
135 |
| - return any_tool_failed |
| 115 | + result = yield from handle(tool_use) |
| 116 | + tool_results.append(result) |
136 | 117 |
|
137 | 118 |
|
138 | 119 | def validate_and_prepare_tools(
|
139 | 120 | message: Message,
|
140 |
| - tool_uses: List[ToolUse], |
141 |
| - tool_results: List[ToolResult], |
142 |
| - invalid_tool_use_ids: List[str], |
| 121 | + tool_uses: list[ToolUse], |
| 122 | + tool_results: list[ToolResult], |
| 123 | + invalid_tool_use_ids: list[str], |
143 | 124 | ) -> None:
|
144 | 125 | """Validate tool uses and prepare them for execution.
|
145 | 126 |
|
|
0 commit comments