Skip to content

Commit 622ff39

Browse files
bitnahianDouweM
authored andcommitted
chore: simplify output function call with model retry (pydantic#2273)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 798b3cd commit 622ff39

File tree

1 file changed

+41
-25
lines changed

1 file changed

+41
-25
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,31 @@
6969
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
7070

7171

72-
async def execute_output_function_with_span(
72+
async def execute_traced_output_function(
7373
function_schema: _function_schema.FunctionSchema,
7474
run_context: RunContext[AgentDepsT],
7575
args: dict[str, Any] | Any,
76+
wrap_validation_errors: bool = True,
7677
) -> Any:
77-
"""Execute a function call within a traced span, automatically recording the response."""
78+
"""Execute an output function within a traced span with error handling.
79+
80+
This function executes the output function within an OpenTelemetry span for observability,
81+
automatically records the function response, and handles ModelRetry exceptions by converting
82+
them to ToolRetryError when wrap_validation_errors is True.
83+
84+
Args:
85+
function_schema: The function schema containing the function to execute
86+
run_context: The current run context containing tracing and tool information
87+
args: Arguments to pass to the function
88+
wrap_validation_errors: If True, wrap ModelRetry exceptions in ToolRetryError
89+
90+
Returns:
91+
The result of the function execution
92+
93+
Raises:
94+
ToolRetryError: When wrap_validation_errors is True and a ModelRetry is caught
95+
ModelRetry: When wrap_validation_errors is False and a ModelRetry occurs
96+
"""
7897
# Set up span attributes
7998
tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function')
8099
attributes = {
@@ -96,7 +115,19 @@ async def execute_output_function_with_span(
96115
)
97116

98117
with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span:
99-
output = await function_schema.call(args, run_context)
118+
try:
119+
output = await function_schema.call(args, run_context)
120+
except ModelRetry as r:
121+
if wrap_validation_errors:
122+
m = _messages.RetryPromptPart(
123+
content=r.message,
124+
tool_name=run_context.tool_name,
125+
)
126+
if run_context.tool_call_id:
127+
m.tool_call_id = run_context.tool_call_id # pragma: no cover
128+
raise ToolRetryError(m) from r
129+
else:
130+
raise
100131

101132
# Record response if content inclusion is enabled
102133
if run_context.trace_include_content and span.is_recording():
@@ -663,16 +694,7 @@ async def process(
663694
else:
664695
raise
665696

666-
try:
667-
output = await self.call(output, run_context)
668-
except ModelRetry as r:
669-
if wrap_validation_errors:
670-
m = _messages.RetryPromptPart(
671-
content=r.message,
672-
)
673-
raise ToolRetryError(m) from r
674-
else:
675-
raise # pragma: no cover
697+
output = await self.call(output, run_context, wrap_validation_errors)
676698

677699
return output
678700

@@ -691,12 +713,15 @@ async def call(
691713
self,
692714
output: Any,
693715
run_context: RunContext[AgentDepsT],
716+
wrap_validation_errors: bool = True,
694717
):
695718
if k := self.outer_typed_dict_key:
696719
output = output[k]
697720

698721
if self._function_schema:
699-
output = await execute_output_function_with_span(self._function_schema, run_context, output)
722+
output = await execute_traced_output_function(
723+
self._function_schema, run_context, output, wrap_validation_errors
724+
)
700725

701726
return output
702727

@@ -856,16 +881,7 @@ async def process(
856881
wrap_validation_errors: bool = True,
857882
) -> OutputDataT:
858883
args = {self._str_argument_name: data}
859-
try:
860-
output = await execute_output_function_with_span(self._function_schema, run_context, args)
861-
except ModelRetry as r:
862-
if wrap_validation_errors:
863-
m = _messages.RetryPromptPart(
864-
content=r.message,
865-
)
866-
raise ToolRetryError(m) from r
867-
else:
868-
raise # pragma: no cover
884+
output = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
869885

870886
return cast(OutputDataT, output)
871887

@@ -975,7 +991,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
975991
async def call_tool(
976992
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
977993
) -> Any:
978-
output = await self.processors[name].call(tool_args, ctx)
994+
output = await self.processors[name].call(tool_args, ctx, wrap_validation_errors=False)
979995
for validator in self.output_validators:
980996
output = await validator.validate(output, ctx, wrap_validation_errors=False)
981997
return output

0 commit comments

Comments
 (0)