Skip to content

Commit d601615

Browse files
authored
executor - run tools - yield (#328)
1 parent dacdf10 commit d601615

File tree

4 files changed

+87
-210
lines changed

4 files changed

+87
-210
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
import uuid
1414
from functools import partial
15-
from typing import Any, Generator, Optional, cast
15+
from typing import Any, Generator, Optional
1616

1717
from opentelemetry import trace
1818

@@ -369,11 +369,10 @@ def _handle_tool_execution(
369369
kwargs=kwargs,
370370
)
371371

372-
run_tools(
372+
yield from run_tools(
373373
handler=tool_handler_process,
374374
tool_uses=tool_uses,
375375
event_loop_metrics=event_loop_metrics,
376-
request_state=cast(Any, kwargs["request_state"]),
377376
invalid_tool_use_ids=invalid_tool_use_ids,
378377
tool_results=tool_results,
379378
cycle_trace=cycle_trace,

src/strands/tools/executor.py

Lines changed: 68 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Tool execution functionality for the event loop."""
22

33
import logging
4+
import queue
5+
import threading
46
import time
5-
from concurrent.futures import TimeoutError
6-
from typing import Any, Callable, List, Optional, Tuple
7+
from typing import Any, Callable, Generator, Optional, cast
78

89
from opentelemetry import trace
910

@@ -19,127 +20,107 @@
1920

2021
def run_tools(
2122
handler: Callable[[ToolUse], ToolResult],
22-
tool_uses: List[ToolUse],
23+
tool_uses: list[ToolUse],
2324
event_loop_metrics: EventLoopMetrics,
24-
request_state: Any,
25-
invalid_tool_use_ids: List[str],
26-
tool_results: List[ToolResult],
25+
invalid_tool_use_ids: list[str],
26+
tool_results: list[ToolResult],
2727
cycle_trace: Trace,
2828
parent_span: Optional[trace.Span] = None,
2929
parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None,
30-
) -> bool:
30+
) -> Generator[dict[str, Any], None, None]:
3131
"""Execute tools either in parallel or sequentially.
3232
3333
Args:
3434
handler: Tool handler processing function.
3535
tool_uses: List of tool uses to execute.
3636
event_loop_metrics: Metrics collection object.
37-
request_state: Current request state.
3837
invalid_tool_use_ids: List of invalid tool use IDs.
3938
tool_results: List to populate with tool results.
4039
cycle_trace: Parent trace for the current cycle.
4140
parent_span: Parent span for the current cycle.
4241
parallel_tool_executor: Optional executor for parallel processing.
4342
44-
Returns:
45-
bool: True if any tool failed, False otherwise.
43+
Yields:
44+
Events of the tool invocations. Tool results are appended to `tool_results`.
4645
"""
4746

48-
def _handle_tool_execution(tool: ToolUse) -> Tuple[bool, Optional[ToolResult]]:
49-
result = None
50-
tool_succeeded = False
51-
47+
def handle(tool: ToolUse) -> Generator[dict[str, Any], None, ToolResult]:
5248
tracer = get_tracer()
5349
tool_call_span = tracer.start_tool_call_span(tool, parent_span)
5450

55-
try:
56-
if "toolUseId" not in tool or tool["toolUseId"] not in invalid_tool_use_ids:
57-
tool_name = tool["name"]
58-
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
59-
tool_start_time = time.time()
60-
result = handler(tool)
61-
tool_success = result.get("status") == "success"
62-
if tool_success:
63-
tool_succeeded = True
64-
65-
tool_duration = time.time() - tool_start_time
66-
message = Message(role="user", content=[{"toolResult": result}])
67-
event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
68-
cycle_trace.add_child(tool_trace)
69-
70-
if tool_call_span:
71-
tracer.end_tool_call_span(tool_call_span, result)
72-
except Exception as e:
73-
if tool_call_span:
74-
tracer.end_span_with_error(tool_call_span, str(e), e)
75-
76-
return tool_succeeded, result
77-
78-
any_tool_failed = False
51+
tool_name = tool["name"]
52+
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
53+
tool_start_time = time.time()
54+
55+
result = handler(tool)
56+
yield {"result": result} # Placeholder until handler becomes a generator from which we can yield from
57+
58+
tool_success = result.get("status") == "success"
59+
tool_duration = time.time() - tool_start_time
60+
message = Message(role="user", content=[{"toolResult": result}])
61+
event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
62+
cycle_trace.add_child(tool_trace)
63+
64+
if tool_call_span:
65+
tracer.end_tool_call_span(tool_call_span, result)
66+
67+
return result
68+
69+
def work(
70+
tool: ToolUse,
71+
worker_id: int,
72+
worker_queue: queue.Queue,
73+
worker_event: threading.Event,
74+
) -> ToolResult:
75+
events = handle(tool)
76+
77+
while True:
78+
try:
79+
event = next(events)
80+
worker_queue.put((worker_id, event))
81+
worker_event.wait()
82+
83+
except StopIteration as stop:
84+
return cast(ToolResult, stop.value)
85+
86+
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
87+
7988
if parallel_tool_executor:
8089
logger.debug(
8190
"tool_count=<%s>, tool_executor=<%s> | executing tools in parallel",
8291
len(tool_uses),
8392
type(parallel_tool_executor).__name__,
8493
)
85-
# Submit all tasks with their associated tools
86-
future_to_tool = {
87-
parallel_tool_executor.submit(_handle_tool_execution, tool_use): tool_use for tool_use in tool_uses
88-
}
94+
95+
worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue()
96+
worker_events = [threading.Event() for _ in range(len(tool_uses))]
97+
98+
workers = [
99+
parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
100+
for worker_id, tool_use in enumerate(tool_uses)
101+
]
89102
logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))
90103

91-
# Collect results truly in parallel using the provided executor's as_completed method
92-
completed_results = []
93-
try:
94-
for future in parallel_tool_executor.as_completed(future_to_tool):
95-
try:
96-
succeeded, result = future.result()
97-
if result is not None:
98-
completed_results.append(result)
99-
if not succeeded:
100-
any_tool_failed = True
101-
except Exception as e:
102-
tool = future_to_tool[future]
103-
logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], e)
104-
any_tool_failed = True
105-
except TimeoutError:
106-
logger.error("timeout_seconds=<%s> | parallel tool execution timed out", parallel_tool_executor.timeout)
107-
# Process any completed tasks
108-
for future in future_to_tool:
109-
if future.done(): # type: ignore
110-
try:
111-
succeeded, result = future.result(timeout=0)
112-
if result is not None:
113-
completed_results.append(result)
114-
except Exception as tool_e:
115-
tool = future_to_tool[future]
116-
logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], tool_e)
117-
else:
118-
# This future didn't complete within the timeout
119-
tool = future_to_tool[future]
120-
logger.debug("tool_name=<%s> | tool execution timed out", tool["name"])
121-
122-
any_tool_failed = True
123-
124-
# Add completed results to tool_results
125-
tool_results.extend(completed_results)
104+
while not all(worker.done() for worker in workers):
105+
if not worker_queue.empty():
106+
worker_id, event = worker_queue.get()
107+
yield event
108+
worker_events[worker_id].set()
109+
110+
tool_results.extend([worker.result() for worker in workers])
111+
126112
else:
127113
# Sequential execution fallback
128114
for tool_use in tool_uses:
129-
succeeded, result = _handle_tool_execution(tool_use)
130-
if result is not None:
131-
tool_results.append(result)
132-
if not succeeded:
133-
any_tool_failed = True
134-
135-
return any_tool_failed
115+
result = yield from handle(tool_use)
116+
tool_results.append(result)
136117

137118

138119
def validate_and_prepare_tools(
139120
message: Message,
140-
tool_uses: List[ToolUse],
141-
tool_results: List[ToolResult],
142-
invalid_tool_use_ids: List[str],
121+
tool_uses: list[ToolUse],
122+
tool_results: list[ToolResult],
123+
invalid_tool_use_ids: list[str],
143124
) -> None:
144125
"""Validate tool uses and prepare them for execution.
145126

src/strands/types/event_loop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def result(self, timeout: Optional[int] = None) -> Any:
6565
Any: The result of the asynchronous operation.
6666
"""
6767

68+
def done(self) -> bool:
69+
"""Returns true if future is done executing."""
70+
6871

6972
@runtime_checkable
7073
class ParallelToolExecutorInterface(Protocol):

0 commit comments

Comments
 (0)