11from __future__ import annotations
22
33import asyncio
4+ import contextlib
45import dataclasses
56import inspect
67from collections .abc import Awaitable
@@ -226,6 +227,29 @@ def get_model_tracing_impl(
226227 return ModelTracing .ENABLED_WITHOUT_DATA
227228
228229
230+ # Helpers for cancellable tool execution
231+
232+
233+ async def _await_cancellable (awaitable ):
234+ """Await an awaitable in its own task so CancelledError interrupts promptly."""
235+ task = asyncio .create_task (awaitable )
236+ try :
237+ return await task
238+ except asyncio .CancelledError :
239+ # propagate so run.py can handle terminal cancel
240+ raise
241+
242+
243+ def _maybe_call_cancel_hook (tool_obj ) -> None :
244+ """Best-effort: call a cancel/terminate hook on the tool if present."""
245+ for name in ("cancel" , "terminate" , "stop" ):
246+ cb = getattr (tool_obj , name , None )
247+ if callable (cb ):
248+ with contextlib .suppress (Exception ):
249+ cb ()
250+ break
251+
252+
229253class RunImpl :
230254 @classmethod
231255 async def execute_tools_and_side_effects (
@@ -556,16 +580,24 @@ async def run_single_tool(
556580 if config .trace_include_sensitive_data :
557581 span_fn .span_data .input = tool_call .arguments
558582 try :
559- _ , _ , result = await asyncio .gather (
583+ # run start hooks first (don’t tie them to the cancellable task)
584+ await asyncio .gather (
560585 hooks .on_tool_start (tool_context , agent , func_tool ),
561586 (
562587 agent .hooks .on_tool_start (tool_context , agent , func_tool )
563588 if agent .hooks
564589 else _coro .noop_coroutine ()
565590 ),
566- func_tool .on_invoke_tool (tool_context , tool_call .arguments ),
567591 )
568592
593+ try :
594+ result = await _await_cancellable (
595+ func_tool .on_invoke_tool (tool_context , tool_call .arguments )
596+ )
597+ except asyncio .CancelledError :
598+ _maybe_call_cancel_hook (func_tool )
599+ raise
600+
569601 await asyncio .gather (
570602 hooks .on_tool_end (tool_context , agent , func_tool , result ),
571603 (
@@ -574,6 +606,7 @@ async def run_single_tool(
574606 else _coro .noop_coroutine ()
575607 ),
576608 )
609+
577610 except Exception as e :
578611 _error_tracing .attach_error_to_current_span (
579612 SpanError (
@@ -644,7 +677,6 @@ async def execute_computer_actions(
644677 config : RunConfig ,
645678 ) -> list [RunItem ]:
646679 results : list [RunItem ] = []
647- # Need to run these serially, because each action can affect the computer state
648680 for action in actions :
649681 acknowledged : list [ComputerCallOutputAcknowledgedSafetyCheck ] | None = None
650682 if action .tool_call .pending_safety_checks and action .computer_tool .on_safety_check :
@@ -661,24 +693,28 @@ async def execute_computer_actions(
661693 if ack :
662694 acknowledged .append (
663695 ComputerCallOutputAcknowledgedSafetyCheck (
664- id = check .id ,
665- code = check .code ,
666- message = check .message ,
696+ id = check .id , code = check .code , message = check .message
667697 )
668698 )
669699 else :
670700 raise UserError ("Computer tool safety check was not acknowledged" )
671701
672- results .append (
673- await ComputerAction .execute (
674- agent = agent ,
675- action = action ,
676- hooks = hooks ,
677- context_wrapper = context_wrapper ,
678- config = config ,
679- acknowledged_safety_checks = acknowledged ,
702+ try :
703+ item = await _await_cancellable (
704+ ComputerAction .execute (
705+ agent = agent ,
706+ action = action ,
707+ hooks = hooks ,
708+ context_wrapper = context_wrapper ,
709+ config = config ,
710+ acknowledged_safety_checks = acknowledged ,
711+ )
680712 )
681- )
713+ except asyncio .CancelledError :
714+ _maybe_call_cancel_hook (action .computer_tool )
715+ raise
716+
717+ results .append (item )
682718
683719 return results
684720
@@ -1052,16 +1088,23 @@ async def execute(
10521088 else cls ._get_screenshot_sync (action .computer_tool .computer , action .tool_call )
10531089 )
10541090
1055- _ , _ , output = await asyncio .gather (
1091+ # start hooks first
1092+ await asyncio .gather (
10561093 hooks .on_tool_start (context_wrapper , agent , action .computer_tool ),
10571094 (
10581095 agent .hooks .on_tool_start (context_wrapper , agent , action .computer_tool )
10591096 if agent .hooks
10601097 else _coro .noop_coroutine ()
10611098 ),
1062- output_func ,
10631099 )
1064-
1100+ # run the action (screenshot/etc) in a cancellable task
1101+ try :
1102+ output = await _await_cancellable (output_func )
1103+ except asyncio .CancelledError :
1104+ _maybe_call_cancel_hook (action .computer_tool )
1105+ raise
1106+
1107+ # end hooks
10651108 await asyncio .gather (
10661109 hooks .on_tool_end (context_wrapper , agent , action .computer_tool , output ),
10671110 (
@@ -1169,10 +1212,20 @@ async def execute(
11691212 data = call .tool_call ,
11701213 )
11711214 output = call .local_shell_tool .executor (request )
1172- if inspect .isawaitable (output ):
1173- result = await output
1174- else :
1175- result = output
1215+ try :
1216+ if inspect .isawaitable (output ):
1217+ result = await _await_cancellable (output )
1218+ else :
1219+ # If executor returns a sync result, just use it (can’t cancel mid-call)
1220+ result = output
1221+ except asyncio .CancelledError :
1222+ # Best-effort: if the executor or tool exposes a cancel/terminate, call it
1223+ _maybe_call_cancel_hook (call .local_shell_tool )
1224+ # If your executor returns a proc handle (common pattern), adddress it here if needed:
1225+ # with contextlib.suppress(Exception):
1226+ # proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0)
1227+ # proc.kill()
1228+ raise
11761229
11771230 await asyncio .gather (
11781231 hooks .on_tool_end (context_wrapper , agent , call .local_shell_tool , result ),
@@ -1185,7 +1238,7 @@ async def execute(
11851238
11861239 return ToolCallOutputItem (
11871240 agent = agent ,
1188- output = output ,
1241+ output = result ,
11891242 raw_item = {
11901243 "type" : "local_shell_call_output" ,
11911244 "id" : call .tool_call .call_id ,
0 commit comments