Skip to content

Commit d691f95

Browse files
committed
Don't call toolset.get_tools again when tool manager is built for same run step
1 parent c9852dc commit d691f95

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[Agent
4141

4242
async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
4343
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
44+
if ctx.run_step == self.ctx.run_step:
45+
return self
46+
4447
retries = {
4548
failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools
4649
}

tests/test_toolsets.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@
2727
T = TypeVar('T')
2828

2929

30-
def build_run_context(deps: T) -> RunContext[T]:
30+
def build_run_context(deps: T, run_step: int = 0) -> RunContext[T]:
3131
return RunContext(
3232
deps=deps,
3333
model=TestModel(),
3434
usage=Usage(),
3535
prompt=None,
3636
messages=[],
37-
run_step=0,
37+
run_step=run_step,
3838
)
3939

4040

@@ -542,7 +542,7 @@ def other_tool(x: int) -> int:
542542
assert call_count['other_tool'] == 1
543543

544544
# Test for_run_step - should create new tool manager with updated retry counts
545-
new_context = build_run_context(TestDeps())
545+
new_context = build_run_context(TestDeps(), run_step=1)
546546
new_tool_manager = await tool_manager.for_run_step(new_context)
547547

548548
# The new tool manager should have retry count for the failed tool
@@ -565,7 +565,7 @@ def other_tool(x: int) -> int:
565565
assert call_count['failing_tool'] == 4
566566

567567
# Create another run step
568-
another_context = build_run_context(TestDeps())
568+
another_context = build_run_context(TestDeps(), run_step=2)
569569
another_tool_manager = await new_tool_manager.for_run_step(another_context)
570570

571571
# Should now have retry count of 2 for failing_tool
@@ -621,7 +621,7 @@ def tool_c(x: int) -> int:
621621
assert tool_manager.failed_tools == {'tool_a', 'tool_b'} # unchanged
622622

623623
# Create next run step - should have retry counts for both failed tools
624-
new_context = build_run_context(TestDeps())
624+
new_context = build_run_context(TestDeps(), run_step=1)
625625
new_tool_manager = await tool_manager.for_run_step(new_context)
626626

627627
assert new_tool_manager.ctx.retries == {'tool_a': 1, 'tool_b': 1}

0 commit comments

Comments
 (0)