Skip to content

Commit d49e51a

Browse files
committed
Split out dynamic toolset handling
1 parent be47b36 commit d49e51a

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,12 @@ def __init__(
426426
self._output_toolset.max_retries = self._max_result_retries
427427

428428
self._function_toolset = FunctionToolset(tools, max_retries=retries)
429-
self._user_toolsets = [
430-
toolset if isinstance(toolset, AbstractToolset) else DynamicToolset[AgentDepsT](toolset_func=toolset)
429+
self._dynamic_toolsets = [
430+
DynamicToolset[AgentDepsT](toolset_func=toolset)
431431
for toolset in toolsets or []
432+
if not isinstance(toolset, AbstractToolset)
432433
]
434+
self._user_toolsets = [toolset for toolset in toolsets or [] if isinstance(toolset, AbstractToolset)]
433435

434436
self.history_processors = history_processors or []
435437

@@ -1680,7 +1682,7 @@ async def simple_toolset(ctx: RunContext[str]) -> AbstractToolset[str]:
16801682
"""
16811683

16821684
def toolset_decorator(func_: ToolsetFunc[AgentDepsT]) -> ToolsetFunc[AgentDepsT]:
1683-
self._user_toolsets.append(DynamicToolset(func_, per_run_step=per_run_step))
1685+
self._dynamic_toolsets.append(DynamicToolset(func_, per_run_step=per_run_step))
16841686
return func_
16851687

16861688
return toolset_decorator if func is None else toolset_decorator(func)
@@ -1747,10 +1749,9 @@ def _get_toolset(
17471749
else:
17481750
user_toolsets = self._user_toolsets
17491751

1750-
dynamic_toolsets = [toolset.copy() for toolset in user_toolsets if isinstance(toolset, DynamicToolset)]
1751-
static_toolsets = [toolset for toolset in user_toolsets if not isinstance(toolset, DynamicToolset)]
1752+
dynamic_toolsets = [toolset.copy() for toolset in self._dynamic_toolsets]
17521753

1753-
all_toolsets = [self._function_toolset, *static_toolsets, *dynamic_toolsets]
1754+
all_toolsets = [self._function_toolset, *user_toolsets, *dynamic_toolsets]
17541755

17551756
if self._prepare_tools:
17561757
all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)]

pydantic_ai_slim/pydantic_ai/toolsets/_dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,4 @@ def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
7373
self._toolset.apply(visitor)
7474

7575
def copy(self) -> _DynamicToolset[AgentDepsT]:
76-
return _DynamicToolset(toolset_func=self.toolset_func, per_run_step=self.per_run_step)
76+
return _DynamicToolset(toolset_func=self.toolset_func, per_run_step=self.per_run_step)

tests/test_toolsets.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic_ai.messages import ToolCallPart
1717
from pydantic_ai.models.test import TestModel
1818
from pydantic_ai.tools import ToolDefinition
19-
from pydantic_ai.toolsets._dynamic import _DynamicToolset as DynamicToolset
19+
from pydantic_ai.toolsets._dynamic import _DynamicToolset as DynamicToolset # pyright: ignore[reportPrivateUsage]
2020
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
2121
from pydantic_ai.toolsets.combined import CombinedToolset
2222
from pydantic_ai.toolsets.filtered import FilteredToolset
@@ -500,6 +500,25 @@ async def test_context_manager_failed_initialization():
500500
assert server1.is_running is False
501501

502502

503+
504+
async def test_tool_manager_reuse_self():
505+
"""Test the retry logic with failed_tools and for_run_step method."""
506+
507+
run_context = build_run_context(None, run_step=1)
508+
509+
tool_manager = ToolManager[None](run_context, FunctionToolset[None](), tools={})
510+
511+
same_tool_manager = await tool_manager.for_run_step(ctx=run_context)
512+
513+
assert tool_manager is same_tool_manager
514+
515+
step_2_context = build_run_context(None, run_step=2)
516+
517+
updated_tool_manager = await tool_manager.for_run_step(ctx=step_2_context)
518+
519+
assert tool_manager != updated_tool_manager
520+
521+
503522
async def test_tool_manager_retry_logic():
504523
"""Test the retry logic with failed_tools and for_run_step method."""
505524

@@ -654,7 +673,7 @@ async def get_tools(self, ctx: RunContext[None]) -> dict[str, ToolsetTool[None]]
654673
async def call_tool(
655674
self, name: str, tool_args: dict[str, Any], ctx: RunContext[None], tool: ToolsetTool[None]
656675
) -> Any:
657-
return None
676+
return None # pragma: no cover
658677

659678
def toolset_factory(ctx: RunContext[None]) -> AbstractToolset[None]:
660679
return EnterableToolset()
@@ -686,8 +705,8 @@ def visitor(toolset: AbstractToolset[None]) -> None:
686705

687706
assert tools == {}
688707

689-
async def test_dynamic_toolset_empty():
690708

709+
async def test_dynamic_toolset_empty():
691710
def no_toolset_func(ctx: RunContext[None]) -> None:
692711
return None
693712

@@ -697,4 +716,4 @@ def no_toolset_func(ctx: RunContext[None]) -> None:
697716

698717
tools = await toolset.get_tools(run_context)
699718

700-
assert tools == {}
719+
assert tools == {}

0 commit comments

Comments
 (0)