Skip to content

Commit 226830b

Browse files
committed
Add optional id field to toolsets
1 parent 01c550c commit 226830b

File tree

17 files changed

+243
-32
lines changed

17 files changed

+243
-32
lines changed

docs/tools.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ from pydantic_ai.ext.langchain import LangChainToolset
770770

771771

772772
toolkit = SlackToolkit()
773-
toolset = LangChainToolset(toolkit.get_tools())
773+
toolset = LangChainToolset(toolkit.get_tools(), id='slack')
774774

775775
agent = Agent('openai:gpt-4o', toolsets=[toolset])
776776
# ...
@@ -823,6 +823,7 @@ toolset = ACIToolset(
823823
'OPEN_WEATHER_MAP__FORECAST',
824824
],
825825
linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'),
826+
id='open_weather_map',
826827
)
827828

828829
agent = Agent('openai:gpt-4o', toolsets=[toolset])

docs/toolsets.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,10 @@ def temperature_fahrenheit(city: str) -> float:
8484
return 69.8
8585

8686

87-
weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit])
87+
weather_toolset = FunctionToolset(
88+
tools=[temperature_celsius, temperature_fahrenheit],
89+
id='weather', # (1)!
90+
)
8891

8992

9093
@weather_toolset.tool
@@ -95,10 +98,10 @@ def conditions(ctx: RunContext, city: str) -> str:
9598
return "It's raining"
9699

97100

98-
datetime_toolset = FunctionToolset()
101+
datetime_toolset = FunctionToolset(id='datetime')
99102
datetime_toolset.add_function(lambda: datetime.now(), name='now')
100103

101-
test_model = TestModel() # (1)!
104+
test_model = TestModel() # (2)!
102105
agent = Agent(test_model)
103106

104107
result = agent.run_sync('What tools are available?', toolsets=[weather_toolset])
@@ -110,7 +113,8 @@ print([t.name for t in test_model.last_model_request_parameters.function_tools])
110113
#> ['now']
111114
```
112115

113-
1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run.
116+
1. `FunctionToolset` supports an optional `id` argument that can help to identify the toolset in error messages. A toolset also needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow.
117+
2. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run.
114118

115119
_(This example is complete, it can be run "as is")_
116120

@@ -609,7 +613,7 @@ from pydantic_ai.ext.langchain import LangChainToolset
609613

610614

611615
toolkit = SlackToolkit()
612-
toolset = LangChainToolset(toolkit.get_tools())
616+
toolset = LangChainToolset(toolkit.get_tools(), id='slack')
613617

614618
agent = Agent('openai:gpt-4o', toolsets=[toolset])
615619
# ...
@@ -634,6 +638,7 @@ toolset = ACIToolset(
634638
'OPEN_WEATHER_MAP__FORECAST',
635639
],
636640
linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'),
641+
id='open_weather_map',
637642
)
638643

639644
agent = Agent('openai:gpt-4o', toolsets=[toolset])

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,10 @@ def __init__(
961961
self.max_retries = max_retries
962962
self.output_validators = output_validators or []
963963

964+
@property
965+
def id(self) -> str | None:
966+
return 'output'
967+
964968
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
965969
return {
966970
tool_def.name: ToolsetTool(

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ async def run(
273273
parameters_json_schema=tool.parameters,
274274
)
275275
for tool in run_input.tools
276-
]
276+
],
277+
id='ag_ui_frontend',
277278
)
278279
toolsets = [*toolsets, toolset] if toolsets else [toolset]
279280

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def __init__(
420420
if self._output_toolset:
421421
self._output_toolset.max_retries = self._max_result_retries
422422

423-
self._function_toolset = FunctionToolset(tools, max_retries=retries)
423+
self._function_toolset = FunctionToolset(tools, max_retries=retries, id='agent')
424424
self._user_toolsets = toolsets or ()
425425

426426
self.history_processors = history_processors or []

pydantic_ai_slim/pydantic_ai/ext/aci.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,7 @@ def implementation(*args: Any, **kwargs: Any) -> str:
7171
class ACIToolset(FunctionToolset):
7272
"""A toolset that wraps ACI.dev tools."""
7373

74-
def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str):
75-
super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions])
74+
def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str, id: str | None = None):
75+
super().__init__(
76+
[tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions], id=id
77+
)

pydantic_ai_slim/pydantic_ai/ext/langchain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,5 @@ def proxy(*args: Any, **kwargs: Any) -> str:
6565
class LangChainToolset(FunctionToolset):
6666
"""A toolset that wraps LangChain tools."""
6767

68-
def __init__(self, tools: list[LangChainTool]):
69-
super().__init__([tool_from_langchain(tool) for tool in tools])
68+
def __init__(self, tools: list[LangChainTool], id: str | None = None):
69+
super().__init__([tool_from_langchain(tool) for tool in tools], id=id)

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 144 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,12 @@ class MCPServer(AbstractToolset[Any], ABC):
6161
timeout: float = 5
6262
process_tool_call: ProcessToolCallback | None = None
6363
allow_sampling: bool = True
64-
max_retries: int = 1
6564
sampling_model: models.Model | None = None
65+
max_retries: int = 1
6666
# } end of "abstract fields"
6767

68+
_id: str | None = field(init=False, default=None)
69+
6870
_enter_lock: Lock = field(compare=False)
6971
_running_count: int
7072
_exit_stack: AsyncExitStack | None
@@ -73,7 +75,29 @@ class MCPServer(AbstractToolset[Any], ABC):
7375
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
7476
_write_stream: MemoryObjectSendStream[SessionMessage]
7577

76-
def __post_init__(self):
78+
def __init__(
79+
self,
80+
tool_prefix: str | None = None,
81+
log_level: mcp_types.LoggingLevel | None = None,
82+
log_handler: LoggingFnT | None = None,
83+
timeout: float = 5,
84+
process_tool_call: ProcessToolCallback | None = None,
85+
allow_sampling: bool = True,
86+
sampling_model: models.Model | None = None,
87+
max_retries: int = 1,
88+
id: str | None = None,
89+
):
90+
self.tool_prefix = tool_prefix
91+
self.log_level = log_level
92+
self.log_handler = log_handler
93+
self.timeout = timeout
94+
self.process_tool_call = process_tool_call
95+
self.allow_sampling = allow_sampling
96+
self.sampling_model = sampling_model
97+
self.max_retries = max_retries
98+
99+
self._id = id or tool_prefix
100+
77101
self._enter_lock = Lock()
78102
self._running_count = 0
79103
self._exit_stack = None
@@ -93,7 +117,11 @@ async def client_streams(
93117
yield
94118

95119
@property
96-
def name(self) -> str:
120+
def id(self) -> str | None:
121+
return self._id
122+
123+
@property
124+
def label(self) -> str:
97125
return repr(self)
98126

99127
@property
@@ -294,7 +322,7 @@ def _map_tool_result_part(
294322
assert_never(part)
295323

296324

297-
@dataclass
325+
@dataclass(init=False)
298326
class MCPServerStdio(MCPServer):
299327
"""Runs an MCP server in a subprocess and communicates with it over stdin/stdout.
300328
@@ -378,11 +406,61 @@ async def main():
378406
allow_sampling: bool = True
379407
"""Whether to allow MCP sampling through this client."""
380408

409+
sampling_model: models.Model | None = None
410+
"""The model to use for sampling."""
411+
381412
max_retries: int = 1
382413
"""The maximum number of times to retry a tool call."""
383414

384-
sampling_model: models.Model | None = None
385-
"""The model to use for sampling."""
415+
def __init__(
416+
self,
417+
command: str,
418+
args: Sequence[str],
419+
env: dict[str, str] | None = None,
420+
cwd: str | Path | None = None,
421+
id: str | None = None,
422+
tool_prefix: str | None = None,
423+
log_level: mcp_types.LoggingLevel | None = None,
424+
log_handler: LoggingFnT | None = None,
425+
timeout: float = 5,
426+
process_tool_call: ProcessToolCallback | None = None,
427+
allow_sampling: bool = True,
428+
sampling_model: models.Model | None = None,
429+
max_retries: int = 1,
430+
):
431+
"""Build a new MCP server.
432+
433+
Args:
434+
command: The command to run.
435+
args: The arguments to pass to the command.
436+
env: The environment variables to set in the subprocess.
437+
cwd: The working directory to use when spawning the process.
438+
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
439+
tool_prefix: A prefix to add to all tools that are registered with the server.
440+
log_level: The log level to set when connecting to the server, if any.
441+
log_handler: A handler for logging messages from the server.
442+
timeout: The timeout in seconds to wait for the client to initialize.
443+
process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
444+
allow_sampling: Whether to allow MCP sampling through this client.
445+
sampling_model: The model to use for sampling.
446+
max_retries: The maximum number of times to retry a tool call.
447+
"""
448+
self.command = command
449+
self.args = args
450+
self.env = env
451+
self.cwd = cwd
452+
453+
super().__init__(
454+
tool_prefix,
455+
log_level,
456+
log_handler,
457+
timeout,
458+
process_tool_call,
459+
allow_sampling,
460+
sampling_model,
461+
max_retries,
462+
id,
463+
)
386464

387465
@asynccontextmanager
388466
async def client_streams(
@@ -398,7 +476,10 @@ async def client_streams(
398476
yield read_stream, write_stream
399477

400478
def __repr__(self) -> str:
401-
return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
479+
if self.id:
480+
return f'{self.__class__.__name__} {self.id!r}'
481+
else:
482+
return f'{self.__class__.__name__}(command={self.command!r}, args={self.args!r})'
402483

403484

404485
@dataclass
@@ -479,11 +560,61 @@ class _MCPServerHTTP(MCPServer):
479560
allow_sampling: bool = True
480561
"""Whether to allow MCP sampling through this client."""
481562

563+
sampling_model: models.Model | None = None
564+
"""The model to use for sampling."""
565+
482566
max_retries: int = 1
483567
"""The maximum number of times to retry a tool call."""
484568

485-
sampling_model: models.Model | None = None
486-
"""The model to use for sampling."""
569+
def __init__(
570+
self,
571+
url: str,
572+
headers: dict[str, Any] | None = None,
573+
http_client: httpx.AsyncClient | None = None,
574+
sse_read_timeout: float = 5 * 60,
575+
id: str | None = None,
576+
tool_prefix: str | None = None,
577+
log_level: mcp_types.LoggingLevel | None = None,
578+
log_handler: LoggingFnT | None = None,
579+
timeout: float = 5,
580+
process_tool_call: ProcessToolCallback | None = None,
581+
allow_sampling: bool = True,
582+
sampling_model: models.Model | None = None,
583+
max_retries: int = 1,
584+
):
585+
"""Build a new MCP server.
586+
587+
Args:
588+
url: The URL of the endpoint on the MCP server.
589+
headers: Optional HTTP headers to be sent with each request to the endpoint.
590+
http_client: An `httpx.AsyncClient` to use with the endpoint.
591+
sse_read_timeout: Maximum time in seconds to wait for new SSE messages before timing out.
592+
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
593+
tool_prefix: A prefix to add to all tools that are registered with the server.
594+
log_level: The log level to set when connecting to the server, if any.
595+
log_handler: A handler for logging messages from the server.
596+
timeout: The timeout in seconds to wait for the client to initialize.
597+
process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
598+
allow_sampling: Whether to allow MCP sampling through this client.
599+
sampling_model: The model to use for sampling.
600+
max_retries: The maximum number of times to retry a tool call.
601+
"""
602+
self.url = url
603+
self.headers = headers
604+
self.http_client = http_client
605+
self.sse_read_timeout = sse_read_timeout
606+
607+
super().__init__(
608+
tool_prefix,
609+
log_level,
610+
log_handler,
611+
timeout,
612+
process_tool_call,
613+
allow_sampling,
614+
sampling_model,
615+
max_retries,
616+
id,
617+
)
487618

488619
@property
489620
@abstractmethod
@@ -546,7 +677,10 @@ def httpx_client_factory(
546677
yield read_stream, write_stream
547678

548679
def __repr__(self) -> str: # pragma: no cover
549-
return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
680+
if self.id:
681+
return f'{self.__class__.__name__} {self.id!r}'
682+
else:
683+
return f'{self.__class__.__name__}(url={self.url!r})'
550684

551685

552686
@dataclass

pydantic_ai_slim/pydantic_ai/toolsets/abstract.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,23 @@ class AbstractToolset(ABC, Generic[AgentDepsT]):
7070
"""
7171

7272
@property
73-
def name(self) -> str:
73+
@abstractmethod
74+
def id(self) -> str | None:
75+
"""An ID for the toolset that is unique among all toolsets registered with the same agent.
76+
77+
If you're implementing a concrete implementation that users can instantiate more than once, you should let them optionally pass a custom ID to the constructor and return that here.
78+
79+
A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow.
80+
"""
81+
raise NotImplementedError()
82+
83+
@property
84+
def label(self) -> str:
7485
"""The name of the toolset for use in error messages."""
75-
return self.__class__.__name__.replace('Toolset', ' toolset')
86+
label = self.__class__.__name__
87+
if self.id:
88+
label += f' {self.id!r}'
89+
return label
7690

7791
@property
7892
def tool_name_conflict_hint(self) -> str:

pydantic_ai_slim/pydantic_ai/toolsets/combined.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ def __post_init__(self):
4040
self._entered_count = 0
4141
self._exit_stack = None
4242

43+
@property
44+
def id(self) -> str | None:
45+
return None
46+
47+
@property
48+
def label(self) -> str:
49+
return f'{self.__class__.__name__}({", ".join(toolset.label for toolset in self.toolsets)})'
50+
4351
async def __aenter__(self) -> Self:
4452
async with self._enter_lock:
4553
if self._entered_count == 0:
@@ -64,7 +72,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
6472
for name, tool in tools.items():
6573
if existing_tools := all_tools.get(name):
6674
raise UserError(
67-
f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}'
75+
f'{toolset.label} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.label}: {name!r}. {toolset.tool_name_conflict_hint}'
6876
)
6977

7078
all_tools[name] = _CombinedToolsetTool(

0 commit comments

Comments
 (0)