69
69
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
70
70
71
71
72
- async def execute_output_function_with_span (
72
+ async def execute_traced_output_function (
73
73
function_schema : _function_schema .FunctionSchema ,
74
74
run_context : RunContext [AgentDepsT ],
75
75
args : dict [str , Any ] | Any ,
76
+ wrap_validation_errors : bool = True ,
76
77
) -> Any :
77
- """Execute a function call within a traced span, automatically recording the response."""
78
+ """Execute an output function within a traced span with error handling.
79
+
80
+ This function executes the output function within an OpenTelemetry span for observability,
81
+ automatically records the function response, and handles ModelRetry exceptions by converting
82
+ them to ToolRetryError when wrap_validation_errors is True.
83
+
84
+ Args:
85
+ function_schema: The function schema containing the function to execute
86
+ run_context: The current run context containing tracing and tool information
87
+ args: Arguments to pass to the function
88
+ wrap_validation_errors: If True, wrap ModelRetry exceptions in ToolRetryError
89
+
90
+ Returns:
91
+ The result of the function execution
92
+
93
+ Raises:
94
+ ToolRetryError: When wrap_validation_errors is True and a ModelRetry is caught
95
+ ModelRetry: When wrap_validation_errors is False and a ModelRetry occurs
96
+ """
78
97
# Set up span attributes
79
98
tool_name = run_context .tool_name or getattr (function_schema .function , '__name__' , 'output_function' )
80
99
attributes = {
@@ -96,7 +115,19 @@ async def execute_output_function_with_span(
96
115
)
97
116
98
117
with run_context .tracer .start_as_current_span ('running output function' , attributes = attributes ) as span :
99
- output = await function_schema .call (args , run_context )
118
+ try :
119
+ output = await function_schema .call (args , run_context )
120
+ except ModelRetry as r :
121
+ if wrap_validation_errors :
122
+ m = _messages .RetryPromptPart (
123
+ content = r .message ,
124
+ tool_name = run_context .tool_name ,
125
+ )
126
+ if run_context .tool_call_id :
127
+ m .tool_call_id = run_context .tool_call_id # pragma: no cover
128
+ raise ToolRetryError (m ) from r
129
+ else :
130
+ raise
100
131
101
132
# Record response if content inclusion is enabled
102
133
if run_context .trace_include_content and span .is_recording ():
@@ -663,16 +694,7 @@ async def process(
663
694
else :
664
695
raise
665
696
666
- try :
667
- output = await self .call (output , run_context )
668
- except ModelRetry as r :
669
- if wrap_validation_errors :
670
- m = _messages .RetryPromptPart (
671
- content = r .message ,
672
- )
673
- raise ToolRetryError (m ) from r
674
- else :
675
- raise # pragma: no cover
697
+ output = await self .call (output , run_context , wrap_validation_errors )
676
698
677
699
return output
678
700
@@ -691,12 +713,15 @@ async def call(
691
713
self ,
692
714
output : Any ,
693
715
run_context : RunContext [AgentDepsT ],
716
+ wrap_validation_errors : bool = True ,
694
717
):
695
718
if k := self .outer_typed_dict_key :
696
719
output = output [k ]
697
720
698
721
if self ._function_schema :
699
- output = await execute_output_function_with_span (self ._function_schema , run_context , output )
722
+ output = await execute_traced_output_function (
723
+ self ._function_schema , run_context , output , wrap_validation_errors
724
+ )
700
725
701
726
return output
702
727
@@ -856,16 +881,7 @@ async def process(
856
881
wrap_validation_errors : bool = True ,
857
882
) -> OutputDataT :
858
883
args = {self ._str_argument_name : data }
859
- try :
860
- output = await execute_output_function_with_span (self ._function_schema , run_context , args )
861
- except ModelRetry as r :
862
- if wrap_validation_errors :
863
- m = _messages .RetryPromptPart (
864
- content = r .message ,
865
- )
866
- raise ToolRetryError (m ) from r
867
- else :
868
- raise # pragma: no cover
884
+ output = await execute_traced_output_function (self ._function_schema , run_context , args , wrap_validation_errors )
869
885
870
886
return cast (OutputDataT , output )
871
887
@@ -975,7 +991,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
975
991
async def call_tool (
976
992
self , name : str , tool_args : dict [str , Any ], ctx : RunContext [AgentDepsT ], tool : ToolsetTool [AgentDepsT ]
977
993
) -> Any :
978
- output = await self .processors [name ].call (tool_args , ctx )
994
+ output = await self .processors [name ].call (tool_args , ctx , wrap_validation_errors = False )
979
995
for validator in self .output_validators :
980
996
output = await validator .validate (output , ctx , wrap_validation_errors = False )
981
997
return output
0 commit comments