Skip to content

Commit 6770ee4

Browse files
committed
Add Dynamic Toolset Decorator
1 parent b785a0f commit 6770ee4

File tree

7 files changed

+103
-167
lines changed

7 files changed

+103
-167
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
ToolPrepareFunc,
5252
ToolsPrepareFunc,
5353
)
54-
from .toolsets import AbstractToolset
54+
from .toolsets import AbstractToolset, ToolsetFunc
5555
from .toolsets.combined import CombinedToolset
5656
from .toolsets.function import FunctionToolset
5757
from .toolsets.prepared import PreparedToolset
@@ -165,6 +165,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
165165
_function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False)
166166
_output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False)
167167
_user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
168+
_toolset_functions: Sequence[ToolsetFunc[AgentDepsT]] = dataclasses.field(repr=False)
168169
_prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
169170
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
170171
_max_result_retries: int = dataclasses.field(repr=False)
@@ -192,7 +193,7 @@ def __init__(
192193
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
193194
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
194195
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
195-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
196+
toolsets: Sequence[AbstractToolset[AgentDepsT] | ToolsetFunc[AgentDepsT]] | None = None,
196197
defer_model_check: bool = False,
197198
end_strategy: EndStrategy = 'early',
198199
instrument: InstrumentationSettings | bool | None = None,
@@ -223,7 +224,7 @@ def __init__(
223224
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
224225
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
225226
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
226-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
227+
toolsets: Sequence[AbstractToolset[AgentDepsT] | ToolsetFunc[AgentDepsT]] | None = None,
227228
defer_model_check: bool = False,
228229
end_strategy: EndStrategy = 'early',
229230
instrument: InstrumentationSettings | bool | None = None,
@@ -278,7 +279,7 @@ def __init__(
278279
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
279280
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
280281
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
281-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
282+
toolsets: Sequence[AbstractToolset[AgentDepsT] | ToolsetFunc[AgentDepsT]] | None = None,
282283
defer_model_check: bool = False,
283284
end_strategy: EndStrategy = 'early',
284285
instrument: InstrumentationSettings | bool | None = None,
@@ -421,7 +422,8 @@ def __init__(
421422
self._output_toolset.max_retries = self._max_result_retries
422423

423424
self._function_toolset = FunctionToolset(tools, max_retries=retries)
424-
self._user_toolsets = toolsets or ()
425+
self._user_toolsets = [toolset for toolset in toolsets or [] if isinstance(toolset, AbstractToolset)]
426+
self._toolset_functions = [toolset for toolset in toolsets or [] if not isinstance(toolset, AbstractToolset)]
425427

426428
self.history_processors = history_processors or []
427429

@@ -772,7 +774,11 @@ async def main():
772774
run_step=state.run_step,
773775
)
774776

775-
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
777+
toolset = self._get_toolset(
778+
output_toolset=output_toolset,
779+
additional_toolsets=[*(toolsets or []), *[func(run_context) for func in self._toolset_functions]],
780+
)
781+
776782
# This will raise errors for any name conflicts
777783
async with toolset:
778784
run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
@@ -1624,6 +1630,49 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
16241630

16251631
return tool_decorator if func is None else tool_decorator(func)
16261632

1633+
def toolset(
1634+
self,
1635+
func: ToolsetFunc[AgentDepsT] | None = None,
1636+
/,
1637+
) -> Callable[[ToolsetFunc[AgentDepsT]], ToolsetFunc[AgentDepsT]] | ToolsetFunc[AgentDepsT]:
1638+
"""Decorator to register a toolset function.
1639+
1640+
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
1641+
Can decorate a sync or async functions.
1642+
1643+
The decorator can be used bare (`agent.toolset`).
1644+
1645+
Overloads for every possible signature of `toolset` are included so the decorator doesn't obscure
1646+
the type of the function.
1647+
1648+
Example:
1649+
```python
1650+
from pydantic_ai import Agent, RunContext
1651+
1652+
agent = Agent('test', deps_type=str)
1653+
1654+
@agent.toolset
1655+
def simple_toolset(ctx: RunContext[str]) -> AbstractToolset[str]:
1656+
return FunctionToolset(foobar)
1657+
1658+
@agent.toolset
1659+
async def async_toolset(ctx: RunContext[str]) -> AbstractToolset[str]:
1660+
return FunctionToolset(foobar)
1661+
```
1662+
"""
1663+
if func is None:
1664+
1665+
def decorator(
1666+
func_: ToolsetFunc[AgentDepsT],
1667+
) -> ToolsetFunc[AgentDepsT]:
1668+
self._toolset_functions = [*self._toolset_functions, func_]
1669+
return func_
1670+
1671+
return decorator
1672+
else:
1673+
self._toolset_functions = [*self._toolset_functions, func]
1674+
return func
1675+
16271676
def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model:
16281677
"""Create a model configured for this agent.
16291678
@@ -1672,7 +1721,7 @@ def _get_toolset(
16721721
self,
16731722
output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,
16741723
additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
1675-
) -> AbstractToolset[AgentDepsT]:
1724+
) -> CombinedToolset[AgentDepsT]:
16761725
"""Get the complete toolset.
16771726
16781727
Args:

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ async def turn_on_strict_if_openai(
118118
Usage `ToolsPrepareFunc[AgentDepsT]`.
119119
"""
120120

121-
122121
DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
123122
"""Supported docstring formats.
124123

pydantic_ai_slim/pydantic_ai/toolsets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .abstract import AbstractToolset, ToolsetTool
1+
from .abstract import AbstractToolset, ToolsetFunc, ToolsetTool
22
from .combined import CombinedToolset
33
from .deferred import DeferredToolset
44
from .filtered import FilteredToolset
@@ -10,6 +10,7 @@
1010

1111
__all__ = (
1212
'AbstractToolset',
13+
'ToolsetFunc',
1314
'ToolsetTool',
1415
'CombinedToolset',
1516
'DeferredToolset',

pydantic_ai_slim/pydantic_ai/toolsets/abstract.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol
5+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, TypeAlias
66

77
from pydantic_core import SchemaValidator
88
from typing_extensions import Self
@@ -17,6 +17,10 @@
1717
from .renamed import RenamedToolset
1818

1919

20+
ToolsetFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT]], AbstractToolset[AgentDepsT]]'
21+
"""Definition of a function that returns a toolset based on the run context."""
22+
23+
2024
class SchemaValidatorProt(Protocol):
2125
"""Protocol for a Pydantic Core `SchemaValidator` or `PluggableSchemaValidator` (which is private but API-compatible)."""
2226

pydantic_ai_slim/pydantic_ai/toolsets/dynamic.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

tests/test_agent.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from pydantic_ai.profiles import ModelProfile
4747
from pydantic_ai.result import Usage
4848
from pydantic_ai.tools import ToolDefinition
49+
from pydantic_ai.toolsets.abstract import AbstractToolset
4950
from pydantic_ai.toolsets.combined import CombinedToolset
5051
from pydantic_ai.toolsets.function import FunctionToolset
5152
from pydantic_ai.toolsets.prefixed import PrefixedToolset
@@ -3701,6 +3702,44 @@ def bar() -> str:
37013702
assert result.output == snapshot('{"baz":"Hello from baz"}')
37023703

37033704

3705+
def test_toolset_decorator():
3706+
toolset = FunctionToolset()
3707+
3708+
@toolset.tool
3709+
def foo() -> str:
3710+
return 'Hello from foo'
3711+
3712+
available_tools: list[str] = []
3713+
3714+
async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]:
3715+
nonlocal available_tools
3716+
available_tools = [tool_def.name for tool_def in tool_defs]
3717+
return tool_defs
3718+
3719+
toolset_creation_count = 0
3720+
3721+
def create_function_toolset(ctx: RunContext[None]) -> AbstractToolset[None]:
3722+
nonlocal toolset_creation_count
3723+
toolset_creation_count += 1
3724+
return toolset
3725+
3726+
def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
3727+
if len(messages) == 1:
3728+
return ModelResponse(parts=[ToolCallPart('foo')])
3729+
elif len(messages) == 3:
3730+
return ModelResponse(parts=[ToolCallPart('foo')])
3731+
else:
3732+
return ModelResponse(parts=[TextPart('Done')])
3733+
3734+
agent = Agent(FunctionModel(respond), toolsets=[create_function_toolset], prepare_tools=prepare_tools)
3735+
3736+
run_result = agent.run_sync('Hello')
3737+
3738+
assert run_result._state.run_step == 3 # pyright: ignore[reportPrivateUsage]
3739+
assert len(available_tools) == 1
3740+
assert toolset_creation_count == 1
3741+
3742+
37043743
def test_adding_tools_during_run():
37053744
toolset = FunctionToolset()
37063745

tests/test_toolsets.py

Lines changed: 1 addition & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
from collections import defaultdict
55
from dataclasses import dataclass, replace
6-
from typing import Literal, TypeVar
6+
from typing import TypeVar
77
from unittest.mock import AsyncMock
88

99
import pytest
@@ -15,9 +15,7 @@
1515
from pydantic_ai.messages import ToolCallPart
1616
from pydantic_ai.models.test import TestModel
1717
from pydantic_ai.tools import ToolDefinition
18-
from pydantic_ai.toolsets.abstract import AbstractToolset
1918
from pydantic_ai.toolsets.combined import CombinedToolset
20-
from pydantic_ai.toolsets.dynamic import DynamicToolset
2119
from pydantic_ai.toolsets.filtered import FilteredToolset
2220
from pydantic_ai.toolsets.function import FunctionToolset
2321
from pydantic_ai.toolsets.prefixed import PrefixedToolset
@@ -628,97 +626,3 @@ def tool_c(x: int) -> int:
628626

629627
assert new_tool_manager.ctx.retries == {'tool_a': 1, 'tool_b': 1}
630628
assert new_tool_manager.failed_tools == set() # reset for new run step
631-
632-
633-
async def test_dynamic_toolset():
634-
"""Test that the dynamic toolset correctly handles nested context managers."""
635-
636-
something_context: RunContext[str] = build_run_context(deps='something')
637-
something_else_context: RunContext[str] = build_run_context(deps='something_else')
638-
nothing_context: RunContext[str] = build_run_context(deps='nothing')
639-
640-
something_toolset = FunctionToolset[str]()
641-
642-
something_else_toolset = FunctionToolset[str]()
643-
644-
nothing_toolset = FunctionToolset[str]()
645-
646-
async def prepare_toolset(ctx: RunContext[str]) -> AbstractToolset[str]:
647-
if ctx.deps == 'something':
648-
return something_toolset
649-
elif ctx.deps == 'something_else':
650-
return something_else_toolset
651-
else:
652-
return nothing_toolset
653-
654-
dynamic_toolset: DynamicToolset[str] = DynamicToolset[str](build_toolset_fn=prepare_toolset)
655-
656-
# Enter the first context manager
657-
async with dynamic_toolset:
658-
assert dynamic_toolset.toolset is None
659-
660-
# The toolset is built dynamically on the first call to get_tools within the context
661-
_ = await dynamic_toolset.get_tools(something_context)
662-
assert dynamic_toolset.toolset == something_toolset
663-
assert dynamic_toolset._toolset_stack.get() == [something_toolset] # pyright: ignore[reportPrivateUsage]
664-
665-
# Enter the second context manager
666-
async with dynamic_toolset:
667-
# The toolset appears empty, and is built on the call to get_tools
668-
assert dynamic_toolset.toolset is None
669-
_ = await dynamic_toolset.get_tools(nothing_context)
670-
assert dynamic_toolset.toolset == nothing_toolset
671-
assert dynamic_toolset._toolset_stack.get() == [something_toolset, nothing_toolset] # pyright: ignore[reportPrivateUsage]
672-
673-
# Enter the third context manager
674-
async with dynamic_toolset:
675-
# The toolset appears empty, and is built on the call to get_tools
676-
assert dynamic_toolset.toolset is None
677-
_ = await dynamic_toolset.get_tools(something_else_context)
678-
assert dynamic_toolset.toolset == something_else_toolset
679-
assert dynamic_toolset._toolset_stack.get() == [ # pyright: ignore[reportPrivateUsage]
680-
something_toolset,
681-
nothing_toolset,
682-
something_else_toolset,
683-
]
684-
685-
# Ensure the toolset reverts to the 2nd toolset
686-
_ = await dynamic_toolset.get_tools(nothing_context)
687-
assert dynamic_toolset.toolset == nothing_toolset
688-
assert dynamic_toolset._toolset_stack.get() == [something_toolset, nothing_toolset] # pyright: ignore[reportPrivateUsage]
689-
690-
# Ensure the toolset reverts to the 1st toolset
691-
assert dynamic_toolset.toolset == something_toolset
692-
assert dynamic_toolset._toolset_stack.get() == [something_toolset] # pyright: ignore[reportPrivateUsage]
693-
694-
# Ensure the toolset is empty after exiting all context managers
695-
async with dynamic_toolset:
696-
assert dynamic_toolset.toolset is None
697-
assert dynamic_toolset._toolset_stack.get() == [None] # pyright: ignore[reportPrivateUsage]
698-
699-
700-
async def test_dynamic_toolset_call():
701-
"""Test that the dynamic toolset correctly handles nested context managers."""
702-
703-
something_context: RunContext[str] = build_run_context(deps='something')
704-
705-
def test_something(ctx: RunContext[str]) -> Literal['something']:
706-
return 'something'
707-
708-
async def prepare_toolset(ctx: RunContext[str]) -> AbstractToolset[str]:
709-
toolset = FunctionToolset[str]()
710-
toolset.add_function(test_something)
711-
return toolset
712-
713-
dynamic_toolset: DynamicToolset[str] = DynamicToolset[str](build_toolset_fn=prepare_toolset)
714-
715-
# Enter the first context manager
716-
async with dynamic_toolset:
717-
# The toolset is built dynamically on the first call to get_tools within the context
718-
tools = await dynamic_toolset.get_tools(something_context)
719-
720-
first_tool = tools['test_something']
721-
first_tool_result = await dynamic_toolset.call_tool(
722-
name='test_something', tool_args={}, ctx=something_context, tool=first_tool
723-
)
724-
assert first_tool_result == 'something'

0 commit comments

Comments
 (0)