16
16
from pydantic .json_schema import GenerateJsonSchema
17
17
from typing_extensions import Literal , Never , Self , TypeIs , TypeVar , deprecated
18
18
19
+ from pydantic_ai .toolsets ._dynamic import DynamicToolset
19
20
from pydantic_graph import End , Graph , GraphRun , GraphRunContext
20
21
from pydantic_graph ._utils import get_event_loop
21
22
@@ -164,8 +165,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
164
165
)
165
166
_function_toolset : FunctionToolset [AgentDepsT ] = dataclasses .field (repr = False )
166
167
_output_toolset : OutputToolset [AgentDepsT ] | None = dataclasses .field (repr = False )
167
- _user_toolsets : Sequence [AbstractToolset [AgentDepsT ]] = dataclasses .field (repr = False )
168
- _toolset_functions : Sequence [ToolsetFunc [AgentDepsT ]] = dataclasses .field (repr = False )
168
+ _user_toolsets : list [AbstractToolset [AgentDepsT ]] = dataclasses .field (repr = False )
169
169
_prepare_tools : ToolsPrepareFunc [AgentDepsT ] | None = dataclasses .field (repr = False )
170
170
_prepare_output_tools : ToolsPrepareFunc [AgentDepsT ] | None = dataclasses .field (repr = False )
171
171
_max_result_retries : int = dataclasses .field (repr = False )
@@ -422,8 +422,9 @@ def __init__(
422
422
self ._output_toolset .max_retries = self ._max_result_retries
423
423
424
424
self ._function_toolset = FunctionToolset (tools , max_retries = retries )
425
- self ._user_toolsets = [toolset for toolset in toolsets or [] if isinstance (toolset , AbstractToolset )]
426
- self ._toolset_functions = [toolset for toolset in toolsets or [] if not isinstance (toolset , AbstractToolset )]
425
+ self ._user_toolsets = [
426
+ toolset if isinstance (toolset , AbstractToolset ) else DynamicToolset (toolset ) for toolset in toolsets or []
427
+ ]
427
428
428
429
self .history_processors = history_processors or []
429
430
@@ -774,11 +775,9 @@ async def main():
774
775
run_step = state .run_step ,
775
776
)
776
777
777
- toolsets_from_functions = await self ._materialize_toolset_functions (run_context )
778
-
779
778
toolset = self ._get_toolset (
780
779
output_toolset = output_toolset ,
781
- additional_toolsets = [ * ( toolsets or []), * toolsets_from_functions ] ,
780
+ additional_toolsets = toolsets ,
782
781
)
783
782
784
783
# This will raise errors for any name conflicts
@@ -1632,11 +1631,24 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
1632
1631
1633
1632
return tool_decorator if func is None else tool_decorator (func )
1634
1633
1634
+ @overload
1635
+ def toolset (self , func : ToolsetFunc [AgentDepsT ], / ) -> ToolsetFunc [AgentDepsT ]: ...
1636
+
1637
+ @overload
1635
1638
def toolset (
1636
1639
self ,
1637
- func : ToolsetFunc [AgentDepsT ],
1638
1640
/ ,
1639
- ) -> Callable [[ToolsetFunc [AgentDepsT ]], ToolsetFunc [AgentDepsT ]] | ToolsetFunc [AgentDepsT ]:
1641
+ * ,
1642
+ per_run_step : bool = True ,
1643
+ ) -> Callable [[ToolsetFunc [AgentDepsT ]], ToolsetFunc [AgentDepsT ]]: ...
1644
+
1645
+ def toolset (
1646
+ self ,
1647
+ func : ToolsetFunc [AgentDepsT ] | None = None ,
1648
+ / ,
1649
+ * ,
1650
+ per_run_step : bool = True ,
1651
+ ) -> Any :
1640
1652
"""Decorator to register a toolset function.
1641
1653
1642
1654
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
@@ -1656,9 +1668,17 @@ async def simple_toolset(ctx: RunContext[str]) -> AbstractToolset[str]:
1656
1668
return FunctionToolset()
1657
1669
1658
1670
```
1671
+
1672
+ Args:
1673
+ func: The toolset function to register.
1674
+ per_run_step: Whether to re-evaluate the toolset for each run step. Defaults to True.
1659
1675
"""
1660
- self ._toolset_functions = [* self ._toolset_functions , func ]
1661
- return func
1676
+
1677
+ def toolset_decorator (func_ : ToolsetFunc [AgentDepsT ]) -> ToolsetFunc [AgentDepsT ]:
1678
+ self ._user_toolsets .append (DynamicToolset (func_ , per_run_step = per_run_step ))
1679
+ return func_
1680
+
1681
+ return toolset_decorator if func is None else toolset_decorator (func )
1662
1682
1663
1683
def _get_model (self , model : models .Model | models .KnownModelName | str | None ) -> models .Model :
1664
1684
"""Create a model configured for this agent.
@@ -1780,20 +1800,6 @@ def _prepare_output_schema(
1780
1800
1781
1801
return schema # pyright: ignore[reportReturnType]
1782
1802
1783
- async def _materialize_toolset_functions (
1784
- self , run_context : RunContext [AgentDepsT ]
1785
- ) -> list [AbstractToolset [AgentDepsT ]]:
1786
- materialized_toolsets : list [AbstractToolset [AgentDepsT ]] = []
1787
-
1788
- for toolset_function in self ._toolset_functions :
1789
- toolset = toolset_function (run_context )
1790
- if inspect .isawaitable (toolset ):
1791
- materialized_toolsets .append (await toolset )
1792
- else :
1793
- materialized_toolsets .append (toolset )
1794
-
1795
- return materialized_toolsets
1796
-
1797
1803
@staticmethod
1798
1804
def is_model_request_node (
1799
1805
node : _agent_graph .AgentNode [T , S ] | End [result .FinalResult [S ]],
0 commit comments