Skip to content

Commit eb5e118

Browse files
committed
add dynamic toolset example
1 parent 41dd069 commit eb5e118

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC
4+
from collections.abc import Awaitable, Callable
5+
from contextvars import ContextVar, Token
6+
from typing import TYPE_CHECKING, Any, Self
7+
8+
from pydantic_ai.tools import AgentDepsT, RunContext
9+
from pydantic_ai.toolsets import AbstractToolset
10+
from pydantic_ai.toolsets.abstract import ToolsetTool
11+
from pydantic_ai.toolsets.combined import CombinedToolset
12+
13+
if TYPE_CHECKING:
14+
pass
15+
16+
17+
BuildToolsetFunc = Callable[[RunContext[AgentDepsT]], Awaitable[AbstractToolset[AgentDepsT]]]
18+
19+
20+
class DynamicToolset(AbstractToolset[AgentDepsT], ABC):
21+
"""A Toolset that is dynamically built during an Agent run based on the first available run context."""
22+
23+
_build_toolset_fn: BuildToolsetFunc[AgentDepsT]
24+
25+
_dynamic_toolset: ContextVar[CombinedToolset[AgentDepsT]] = ContextVar('_toolset', default=CombinedToolset[AgentDepsT](toolsets=[]))
26+
_token: Token[CombinedToolset[AgentDepsT]] | None = None
27+
#_toolset_deps: ContextVar[AgentDepsT | None] = ContextVar('_toolset_deps', default=None)
28+
29+
def __init__(self, build_toolset_fn: BuildToolsetFunc[AgentDepsT]):
30+
self._build_toolset_fn = build_toolset_fn
31+
32+
async def __aenter__(self) -> Self:
33+
# Store the current toolset in a token, so that it can be reset when the context is exited
34+
self._token = self._dynamic_toolset.set(CombinedToolset[AgentDepsT](toolsets=[]))
35+
return self
36+
37+
async def __aexit__(self, *args: Any) -> bool | None:
38+
# Reset the toolset to the previous toolset, so that it can be used again
39+
if self._token:
40+
self._dynamic_toolset.reset(self._token)
41+
self._token = None
42+
return None
43+
44+
@property
45+
def _toolset(self) -> CombinedToolset[AgentDepsT]:
46+
if not (toolset := self._dynamic_toolset.get()):
47+
msg = 'Toolset not initialized. Use the `async with` context manager to initialize the toolset.'
48+
raise RuntimeError(msg)
49+
50+
return toolset
51+
52+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
53+
if len(self._toolset.toolsets) == 0 or ctx.run_step == 0:
54+
toolset = await self._build_toolset_fn(ctx)
55+
self._toolset.toolsets = [toolset]
56+
57+
return await self._toolset.get_tools(ctx=ctx)
58+
59+
async def call_tool(
60+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
61+
) -> Any:
62+
return await self._toolset.call_tool(name=name, tool_args=tool_args, ctx=ctx, tool=tool)

tests/test_toolsets.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import re
44
from dataclasses import dataclass, replace
5-
from typing import TypeVar
5+
from pathlib import Path
6+
from typing import Literal, TypeVar
67

8+
from pydantic_ai.agent import Agent
79
import pytest
810
from inline_snapshot import snapshot
911

@@ -13,7 +15,9 @@
1315
from pydantic_ai.messages import ToolCallPart
1416
from pydantic_ai.models.test import TestModel
1517
from pydantic_ai.tools import ToolDefinition
18+
from pydantic_ai.toolsets.abstract import AbstractToolset
1619
from pydantic_ai.toolsets.combined import CombinedToolset
20+
from pydantic_ai.toolsets.dynamic import DynamicToolset
1721
from pydantic_ai.toolsets.filtered import FilteredToolset
1822
from pydantic_ai.toolsets.function import FunctionToolset
1923
from pydantic_ai.toolsets.prefixed import PrefixedToolset
@@ -469,3 +473,71 @@ async def test_context_manager():
469473
async with toolset:
470474
assert server1.is_running
471475
assert server2.is_running
476+
477+
478+
async def test_dynamic_toolset():
479+
run_context = build_run_context(Path())
480+
481+
def test_function(ctx: RunContext[Path]) -> Literal['nothing']:
482+
return 'nothing'
483+
484+
function_toolset = FunctionToolset[Path]()
485+
function_toolset.add_function(test_function)
486+
487+
async def prepare_toolset(ctx: RunContext[Path]) -> AbstractToolset[Path]:
488+
return function_toolset
489+
490+
dynamic_toolset: DynamicToolset[Path] = DynamicToolset[Path](build_toolset_fn=prepare_toolset)
491+
492+
# The toolset is unique per context manager
493+
async with dynamic_toolset:
494+
495+
# The toolset starts empty
496+
assert dynamic_toolset._dynamic_toolset.get().toolsets == []
497+
498+
# The toolset is built dynamically on the first call to get_tools within the context
499+
_ = await dynamic_toolset.get_tools(run_context)
500+
assert dynamic_toolset._dynamic_toolset.get().toolsets == [function_toolset]
501+
502+
# Any time the context is entered again, the toolsets are reset, to be generated again
503+
async with dynamic_toolset:
504+
assert dynamic_toolset._dynamic_toolset.get().toolsets == []
505+
506+
assert dynamic_toolset._dynamic_toolset.get().toolsets == [function_toolset]
507+
508+
async with dynamic_toolset:
509+
assert dynamic_toolset._dynamic_toolset.get().toolsets == []
510+
511+
async def test_dynamic_toolset_with_agent():
512+
run_context = build_run_context(Path())
513+
514+
def test_function(ctx: RunContext[Path]) -> Literal['nothing']:
515+
return 'nothing'
516+
517+
518+
def test_function_two(ctx: RunContext[Path]) -> Literal['nothing']:
519+
return 'nothing'
520+
521+
522+
function_toolset = FunctionToolset[Path]()
523+
function_toolset.add_function(test_function)
524+
function_toolset.add_function(test_function_two)
525+
526+
async def prepare_toolset(ctx: RunContext[Path]) -> AbstractToolset[Path]:
527+
return function_toolset
528+
529+
dynamic_toolset: DynamicToolset[Path] = DynamicToolset[Path](build_toolset_fn=prepare_toolset)
530+
531+
agent = Agent[Path, str](
532+
model=TestModel(),
533+
toolsets=[dynamic_toolset],
534+
deps_type=Path,
535+
output_type=str,
536+
)
537+
538+
async with agent:
539+
result = await agent.run(deps=Path("."), user_prompt="Please call each tool you have access to and tell me what it returns")
540+
print(result.output)
541+
542+
result = await agent.run(deps=Path("./tomato"), user_prompt="Please call each tool you have access to and tell me what it returns.")
543+
print(result.output)

0 commit comments

Comments
 (0)