Skip to content

Commit 26b8364

Browse files
committed
Fix tests
1 parent 0782194 commit 26b8364

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,12 +1644,10 @@ def toolset(
16441644
16451645
The decorator can be used bare (`agent.toolset`).
16461646
1647-
Overloads for every possible signature of `toolset` are included so the decorator doesn't obscure
1648-
the type of the function.
1649-
16501647
Example:
16511648
```python
16521649
from pydantic_ai import Agent, RunContext
1650+
from pydantic_ai.toolsets import AbstractToolset, FunctionToolset
16531651
16541652
agent = Agent('test', deps_type=str)
16551653
@@ -1792,7 +1790,9 @@ def _prepare_output_schema(
17921790

17931791
return schema # pyright: ignore[reportReturnType]
17941792

1795-
async def _materialize_toolset_functions(self, run_context: RunContext[AgentDepsT]) -> list[AbstractToolset[AgentDepsT]]:
1793+
async def _materialize_toolset_functions(
1794+
self, run_context: RunContext[AgentDepsT]
1795+
) -> list[AbstractToolset[AgentDepsT]]:
17961796
materialized_toolsets: list[AbstractToolset[AgentDepsT]] = []
17971797

17981798
for toolset_function in self._toolset_functions:
@@ -2356,7 +2356,6 @@ def _traceparent(self, *, required: bool = True) -> str | None:
23562356
def data(self) -> OutputDataT:
23572357
return self.output
23582358

2359-
23602359
def _set_output_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
23612360
"""Set return content for the output tool.
23622361

pydantic_ai_slim/pydantic_ai/toolsets/abstract.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4+
from collections.abc import Awaitable
45
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generic, Literal, Protocol
6+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol
67

78
from pydantic_core import SchemaValidator
89
from typing_extensions import Self, TypeAlias
@@ -16,7 +17,9 @@
1617
from .prepared import PreparedToolset
1718
from .renamed import RenamedToolset
1819

19-
ToolsetFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT]], AbstractToolset[AgentDepsT] | Awaitable[AbstractToolset[AgentDepsT]]]'
20+
ToolsetFunc: TypeAlias = (
21+
'Callable[[RunContext[AgentDepsT]], AbstractToolset[AgentDepsT] | Awaitable[AbstractToolset[AgentDepsT]]]'
22+
)
2023
"""An sync/async function which takes a run context and returns a toolset."""
2124

2225

tests/test_agent.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from datetime import timezone
66
from typing import Any, Callable, Union
77

8+
from pydantic_ai._agent_graph import build_run_context
9+
810
import httpx
911
import pytest
1012
from dirty_equals import IsJson
@@ -3739,24 +3741,43 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
37393741
assert len(available_tools) == 1
37403742
assert toolset_creation_count == 1
37413743

3742-
def test_toolset_decorator():
3744+
3745+
async def test_toolset_decorator():
37433746
toolset = FunctionToolset()
37443747

37453748
@toolset.tool
37463749
def foo() -> str:
37473750
return 'Hello from foo'
37483751

3749-
37503752
agent = Agent('test')
37513753

37523754
@agent.toolset
37533755
def create_function_toolset(ctx: RunContext[None]) -> AbstractToolset[None]:
37543756
return toolset
37553757

3758+
def create_function_toolset_bare(ctx: RunContext[None]) -> AbstractToolset[None]:
3759+
return toolset
3760+
3761+
agent.toolset(create_function_toolset_bare)
3762+
37563763
agent_toolset_functions = agent._toolset_functions # pyright: ignore[reportPrivateUsage]
37573764

3758-
assert len(agent_toolset_functions) == 1
3765+
assert len(agent_toolset_functions) == 2
37593766
assert agent_toolset_functions[0] is create_function_toolset
3767+
assert agent_toolset_functions[1] is create_function_toolset_bare
3768+
3769+
fake_run_context = RunContext(
3770+
deps=None,
3771+
model=TestModel(),
3772+
usage=Usage(),
3773+
prompt=None,
3774+
messages=[],
3775+
run_step=0,
3776+
)
3777+
3778+
toolsets = await agent._materialize_toolset_functions(run_context=fake_run_context) # pyright: ignore[reportPrivateUsage]
3779+
assert len(toolsets) == 2
3780+
37603781

37613782
def test_adding_tools_during_run():
37623783
toolset = FunctionToolset()

0 commit comments

Comments
 (0)