Skip to content

Commit 9a563d0

Browse files
committed
Split out dynamic toolset handling
1 parent be47b36 commit 9a563d0

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,12 @@ def __init__(
426426
self._output_toolset.max_retries = self._max_result_retries
427427

428428
self._function_toolset = FunctionToolset(tools, max_retries=retries)
429-
self._user_toolsets = [
430-
toolset if isinstance(toolset, AbstractToolset) else DynamicToolset[AgentDepsT](toolset_func=toolset)
429+
self._dynamic_toolsets = [
430+
DynamicToolset[AgentDepsT](toolset_func=toolset)
431431
for toolset in toolsets or []
432+
if not isinstance(toolset, AbstractToolset)
432433
]
434+
self._user_toolsets = [toolset for toolset in toolsets or [] if isinstance(toolset, AbstractToolset)]
433435

434436
self.history_processors = history_processors or []
435437

@@ -1680,7 +1682,7 @@ async def simple_toolset(ctx: RunContext[str]) -> AbstractToolset[str]:
16801682
"""
16811683

16821684
def toolset_decorator(func_: ToolsetFunc[AgentDepsT]) -> ToolsetFunc[AgentDepsT]:
1683-
self._user_toolsets.append(DynamicToolset(func_, per_run_step=per_run_step))
1685+
self._dynamic_toolsets.append(DynamicToolset(func_, per_run_step=per_run_step))
16841686
return func_
16851687

16861688
return toolset_decorator if func is None else toolset_decorator(func)
@@ -1747,10 +1749,9 @@ def _get_toolset(
17471749
else:
17481750
user_toolsets = self._user_toolsets
17491751

1750-
dynamic_toolsets = [toolset.copy() for toolset in user_toolsets if isinstance(toolset, DynamicToolset)]
1751-
static_toolsets = [toolset for toolset in user_toolsets if not isinstance(toolset, DynamicToolset)]
1752+
dynamic_toolsets = [toolset.copy() for toolset in self._dynamic_toolsets]
17521753

1753-
all_toolsets = [self._function_toolset, *static_toolsets, *dynamic_toolsets]
1754+
all_toolsets = [self._function_toolset, *user_toolsets, *dynamic_toolsets]
17541755

17551756
if self._prepare_tools:
17561757
all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)]

tests/test_toolsets.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic_ai.messages import ToolCallPart
1717
from pydantic_ai.models.test import TestModel
1818
from pydantic_ai.tools import ToolDefinition
19-
from pydantic_ai.toolsets._dynamic import _DynamicToolset as DynamicToolset
19+
from pydantic_ai.toolsets._dynamic import _DynamicToolset as DynamicToolset # pyright: ignore[reportPrivateUsage]
2020
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
2121
from pydantic_ai.toolsets.combined import CombinedToolset
2222
from pydantic_ai.toolsets.filtered import FilteredToolset
@@ -500,6 +500,24 @@ async def test_context_manager_failed_initialization():
500500
assert server1.is_running is False
501501

502502

503+
async def test_tool_manager_reuse_self():
504+
"""Test the retry logic with failed_tools and for_run_step method."""
505+
506+
run_context = build_run_context(None, run_step=1)
507+
508+
tool_manager = ToolManager[None](run_context, FunctionToolset[None](), tools={})
509+
510+
same_tool_manager = await tool_manager.for_run_step(ctx=run_context)
511+
512+
assert tool_manager is same_tool_manager
513+
514+
step_2_context = build_run_context(None, run_step=2)
515+
516+
updated_tool_manager = await tool_manager.for_run_step(ctx=step_2_context)
517+
518+
assert tool_manager != updated_tool_manager
519+
520+
503521
async def test_tool_manager_retry_logic():
504522
"""Test the retry logic with failed_tools and for_run_step method."""
505523

@@ -654,7 +672,7 @@ async def get_tools(self, ctx: RunContext[None]) -> dict[str, ToolsetTool[None]]
654672
async def call_tool(
655673
self, name: str, tool_args: dict[str, Any], ctx: RunContext[None], tool: ToolsetTool[None]
656674
) -> Any:
657-
return None
675+
return None # pragma: no cover
658676

659677
def toolset_factory(ctx: RunContext[None]) -> AbstractToolset[None]:
660678
return EnterableToolset()
@@ -686,8 +704,8 @@ def visitor(toolset: AbstractToolset[None]) -> None:
686704

687705
assert tools == {}
688706

689-
async def test_dynamic_toolset_empty():
690707

708+
async def test_dynamic_toolset_empty():
691709
def no_toolset_func(ctx: RunContext[None]) -> None:
692710
return None
693711

@@ -697,4 +715,4 @@ def no_toolset_func(ctx: RunContext[None]) -> None:
697715

698716
tools = await toolset.get_tools(run_context)
699717

700-
assert tools == {}
718+
assert tools == {}

0 commit comments

Comments
 (0)