Skip to content

Commit c9852dc

Browse files
committed
Let toolset factory be registered per run step or for entire run
1 parent 62a149e commit c9852dc

File tree

4 files changed

+61
-76
lines changed

4 files changed

+61
-76
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pydantic.json_schema import GenerateJsonSchema
1717
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
1818

19+
from pydantic_ai.toolsets._dynamic import DynamicToolset
1920
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
2021
from pydantic_graph._utils import get_event_loop
2122

@@ -164,8 +165,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
164165
)
165166
_function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False)
166167
_output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False)
167-
_user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
168-
_toolset_functions: Sequence[ToolsetFunc[AgentDepsT]] = dataclasses.field(repr=False)
168+
_user_toolsets: list[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
169169
_prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
170170
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
171171
_max_result_retries: int = dataclasses.field(repr=False)
@@ -422,8 +422,9 @@ def __init__(
422422
self._output_toolset.max_retries = self._max_result_retries
423423

424424
self._function_toolset = FunctionToolset(tools, max_retries=retries)
425-
self._user_toolsets = [toolset for toolset in toolsets or [] if isinstance(toolset, AbstractToolset)]
426-
self._toolset_functions = [toolset for toolset in toolsets or [] if not isinstance(toolset, AbstractToolset)]
425+
self._user_toolsets = [
426+
toolset if isinstance(toolset, AbstractToolset) else DynamicToolset(toolset) for toolset in toolsets or []
427+
]
427428

428429
self.history_processors = history_processors or []
429430

@@ -774,11 +775,9 @@ async def main():
774775
run_step=state.run_step,
775776
)
776777

777-
toolsets_from_functions = await self._materialize_toolset_functions(run_context)
778-
779778
toolset = self._get_toolset(
780779
output_toolset=output_toolset,
781-
additional_toolsets=[*(toolsets or []), *toolsets_from_functions],
780+
additional_toolsets=toolsets,
782781
)
783782

784783
# This will raise errors for any name conflicts
@@ -1632,11 +1631,24 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
16321631

16331632
return tool_decorator if func is None else tool_decorator(func)
16341633

1634+
@overload
1635+
def toolset(self, func: ToolsetFunc[AgentDepsT], /) -> ToolsetFunc[AgentDepsT]: ...
1636+
1637+
@overload
16351638
def toolset(
16361639
self,
1637-
func: ToolsetFunc[AgentDepsT],
16381640
/,
1639-
) -> Callable[[ToolsetFunc[AgentDepsT]], ToolsetFunc[AgentDepsT]] | ToolsetFunc[AgentDepsT]:
1641+
*,
1642+
per_run_step: bool = True,
1643+
) -> Callable[[ToolsetFunc[AgentDepsT]], ToolsetFunc[AgentDepsT]]: ...
1644+
1645+
def toolset(
1646+
self,
1647+
func: ToolsetFunc[AgentDepsT] | None = None,
1648+
/,
1649+
*,
1650+
per_run_step: bool = True,
1651+
) -> Any:
16401652
"""Decorator to register a toolset function.
16411653
16421654
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
@@ -1656,9 +1668,17 @@ async def simple_toolset(ctx: RunContext[str]) -> AbstractToolset[str]:
16561668
return FunctionToolset()
16571669
16581670
```
1671+
1672+
Args:
1673+
func: The toolset function to register.
1674+
per_run_step: Whether to re-evaluate the toolset for each run step. Defaults to True.
16591675
"""
1660-
self._toolset_functions = [*self._toolset_functions, func]
1661-
return func
1676+
1677+
def toolset_decorator(func_: ToolsetFunc[AgentDepsT]) -> ToolsetFunc[AgentDepsT]:
1678+
self._user_toolsets.append(DynamicToolset(func_, per_run_step=per_run_step))
1679+
return func_
1680+
1681+
return toolset_decorator if func is None else toolset_decorator(func)
16621682

16631683
def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model:
16641684
"""Create a model configured for this agent.
@@ -1780,20 +1800,6 @@ def _prepare_output_schema(
17801800

17811801
return schema # pyright: ignore[reportReturnType]
17821802

1783-
async def _materialize_toolset_functions(
1784-
self, run_context: RunContext[AgentDepsT]
1785-
) -> list[AbstractToolset[AgentDepsT]]:
1786-
materialized_toolsets: list[AbstractToolset[AgentDepsT]] = []
1787-
1788-
for toolset_function in self._toolset_functions:
1789-
toolset = toolset_function(run_context)
1790-
if inspect.isawaitable(toolset):
1791-
materialized_toolsets.append(await toolset)
1792-
else:
1793-
materialized_toolsets.append(toolset)
1794-
1795-
return materialized_toolsets
1796-
17971803
@staticmethod
17981804
def is_model_request_node(
17991805
node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],

pydantic_ai_slim/pydantic_ai/toolsets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from .abstract import AbstractToolset, ToolsetFunc, ToolsetTool
1+
from ._dynamic import ToolsetFunc
2+
from .abstract import AbstractToolset, ToolsetTool
23
from .combined import CombinedToolset
34
from .deferred import DeferredToolset
45
from .filtered import FilteredToolset

pydantic_ai_slim/pydantic_ai/toolsets/abstract.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from collections.abc import Awaitable
54
from dataclasses import dataclass
65
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol
76

87
from pydantic_core import SchemaValidator
9-
from typing_extensions import Self, TypeAlias
8+
from typing_extensions import Self
109

1110
from .._run_context import AgentDepsT, RunContext
1211
from ..tools import ToolDefinition, ToolsPrepareFunc
@@ -17,11 +16,6 @@
1716
from .prepared import PreparedToolset
1817
from .renamed import RenamedToolset
1918

20-
ToolsetFunc: TypeAlias = (
21-
'Callable[[RunContext[AgentDepsT]], AbstractToolset[AgentDepsT] | Awaitable[AbstractToolset[AgentDepsT]]]'
22-
)
23-
"""An sync/async function which takes a run context and returns a toolset."""
24-
2519

2620
class SchemaValidatorProt(Protocol):
2721
"""Protocol for a Pydantic Core `SchemaValidator` or `PluggableSchemaValidator` (which is private but API-compatible)."""

tests/test_agent.py

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import re
33
import sys
4+
from collections import defaultdict
45
from dataclasses import dataclass
56
from datetime import timezone
67
from typing import Any, Callable, Union
@@ -3716,64 +3717,47 @@ async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition])
37163717
available_tools = [tool_def.name for tool_def in tool_defs]
37173718
return tool_defs
37183719

3719-
toolset_creation_count = 0
3720+
toolset_creation_counts: dict[str, int] = defaultdict(int)
37203721

3721-
def create_function_toolset(ctx: RunContext[None]) -> AbstractToolset[None]:
3722-
nonlocal toolset_creation_count
3723-
toolset_creation_count += 1
3724-
return toolset
3722+
def via_toolsets_arg(ctx: RunContext[None]) -> AbstractToolset[None]:
3723+
nonlocal toolset_creation_counts
3724+
toolset_creation_counts['via_toolsets_arg'] += 1
3725+
return toolset.prefixed('via_toolsets_arg')
37253726

37263727
def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
37273728
if len(messages) == 1:
3728-
return ModelResponse(parts=[ToolCallPart('foo')])
3729+
return ModelResponse(parts=[ToolCallPart('via_toolsets_arg_foo')])
37293730
elif len(messages) == 3:
3730-
return ModelResponse(parts=[ToolCallPart('foo')])
3731+
return ModelResponse(parts=[ToolCallPart('via_toolset_decorator_foo')])
37313732
else:
37323733
return ModelResponse(parts=[TextPart('Done')])
37333734

3734-
agent = Agent(FunctionModel(respond), toolsets=[create_function_toolset], prepare_tools=prepare_tools)
3735-
3736-
run_result = agent.run_sync('Hello')
3737-
3738-
assert run_result._state.run_step == 3 # pyright: ignore[reportPrivateUsage]
3739-
assert len(available_tools) == 1
3740-
assert toolset_creation_count == 1
3741-
3742-
3743-
async def test_toolset_decorator():
3744-
toolset = FunctionToolset()
3745-
3746-
agent: Agent[None, str] = Agent('test')
3735+
agent = Agent(FunctionModel(respond), toolsets=[via_toolsets_arg], prepare_tools=prepare_tools)
37473736

37483737
@agent.toolset
3749-
def create_function_toolset(ctx: RunContext[None]) -> AbstractToolset[None]:
3750-
return toolset
3751-
3752-
async def create_function_toolset_async(ctx: RunContext[None]) -> AbstractToolset[None]:
3753-
return toolset
3738+
def via_toolset_decorator(ctx: RunContext[None]) -> AbstractToolset[None]:
3739+
nonlocal toolset_creation_counts
3740+
toolset_creation_counts['via_toolset_decorator'] += 1
3741+
return toolset.prefixed('via_toolset_decorator')
37543742

3755-
agent.toolset(create_function_toolset_async)
3743+
@agent.toolset(per_run_step=False)
3744+
async def via_toolset_decorator_for_entire_run(ctx: RunContext[None]) -> AbstractToolset[None]:
3745+
nonlocal toolset_creation_counts
3746+
toolset_creation_counts['via_toolset_decorator_for_entire_run'] += 1
3747+
return toolset.prefixed('via_toolset_decorator_for_entire_run')
37563748

3757-
agent_toolset_functions = agent._toolset_functions # pyright: ignore[reportPrivateUsage]
3758-
3759-
assert len(agent_toolset_functions) == 2
3760-
assert agent_toolset_functions[0] is create_function_toolset
3761-
assert agent_toolset_functions[1] is create_function_toolset_async
3749+
run_result = agent.run_sync('Hello')
37623750

3763-
fake_run_context = RunContext(
3764-
deps=None,
3765-
model=TestModel(),
3766-
usage=Usage(),
3767-
prompt=None,
3768-
messages=[],
3769-
run_step=0,
3751+
assert run_result._state.run_step == 3 # pyright: ignore[reportPrivateUsage]
3752+
assert len(available_tools) == 3
3753+
assert toolset_creation_counts == snapshot(
3754+
{
3755+
'via_toolsets_arg': 4,
3756+
'via_toolset_decorator': 4,
3757+
'via_toolset_decorator_for_entire_run': 1,
3758+
}
37703759
)
37713760

3772-
toolsets = await agent._materialize_toolset_functions(run_context=fake_run_context) # pyright: ignore[reportPrivateUsage]
3773-
assert len(toolsets) == 2
3774-
assert isinstance(toolsets[0], AbstractToolset)
3775-
assert isinstance(toolsets[1], AbstractToolset)
3776-
37773761

37783762
def test_adding_tools_during_run():
37793763
toolset = FunctionToolset()

0 commit comments

Comments
 (0)