Skip to content

Add name argument to FunctionToolset, DeferredToolset and MCPServer #2250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ from pydantic_ai.ext.langchain import LangChainToolset


toolkit = SlackToolkit()
toolset = LangChainToolset(toolkit.get_tools())
toolset = LangChainToolset('slack', toolkit.get_tools())

agent = Agent('openai:gpt-4o', toolsets=[toolset])
# ...
Expand Down Expand Up @@ -818,6 +818,7 @@ from pydantic_ai.ext.aci import ACIToolset


toolset = ACIToolset(
'open_weather_map',
[
'OPEN_WEATHER_MAP__CURRENT_WEATHER',
'OPEN_WEATHER_MAP__FORECAST',
Expand Down
19 changes: 10 additions & 9 deletions docs/toolsets.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def override_tool():
return "I override all other tools"


agent_toolset = FunctionToolset(tools=[agent_tool]) # (1)!
extra_toolset = FunctionToolset(tools=[extra_tool])
override_toolset = FunctionToolset(tools=[override_tool])
agent_toolset = FunctionToolset('agent', tools=[agent_tool]) # (1)!
extra_toolset = FunctionToolset('extra', tools=[extra_tool])
override_toolset = FunctionToolset('override', tools=[override_tool])

test_model = TestModel() # (2)!
agent = Agent(test_model, toolsets=[agent_toolset])
Expand Down Expand Up @@ -84,7 +84,7 @@ def temperature_fahrenheit(city: str) -> float:
return 69.8


weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit])
weather_toolset = FunctionToolset('weather', tools=[temperature_celsius, temperature_fahrenheit])


@weather_toolset.tool
Expand All @@ -95,7 +95,7 @@ def conditions(ctx: RunContext, city: str) -> str:
return "It's raining"


datetime_toolset = FunctionToolset()
datetime_toolset = FunctionToolset('datetime')
datetime_toolset.add_function(lambda: datetime.now(), name='now')

test_model = TestModel() # (1)!
Expand Down Expand Up @@ -417,7 +417,7 @@ test_model = TestModel() # (1)!
agent = Agent(
test_model,
deps_type=WrapperToolset, # (2)!
toolsets=[togglable_toolset, FunctionToolset([toggle])]
toolsets=[togglable_toolset, FunctionToolset('toggle', [toggle])]
)
result = agent.run_sync('Toggle the toolset', deps=togglable_toolset)
print([t.name for t in test_model.last_model_request_parameters.function_tools]) # (3)!
Expand Down Expand Up @@ -462,7 +462,7 @@ from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.toolsets.function import FunctionToolset

toolset = FunctionToolset()
toolset = FunctionToolset('user_info')


@toolset.tool
Expand Down Expand Up @@ -502,7 +502,7 @@ from pydantic_ai.messages import ModelMessage
def run_agent(
    messages: list[ModelMessage] = [], frontend_tools: list[ToolDefinition] = []
) -> tuple[Union[PersonalizedGreeting, DeferredToolCalls], list[ModelMessage]]:
deferred_toolset = DeferredToolset(frontend_tools)
deferred_toolset = DeferredToolset('frontend', frontend_tools)
result = agent.run_sync(
toolsets=[deferred_toolset], # (1)!
output_type=[agent.output_type, DeferredToolCalls], # (2)!
Expand Down Expand Up @@ -609,7 +609,7 @@ from pydantic_ai.ext.langchain import LangChainToolset


toolkit = SlackToolkit()
toolset = LangChainToolset(toolkit.get_tools())
toolset = LangChainToolset('slack', toolkit.get_tools())

agent = Agent('openai:gpt-4o', toolsets=[toolset])
# ...
Expand All @@ -629,6 +629,7 @@ from pydantic_ai.ext.aci import ACIToolset


toolset = ACIToolset(
'open_weather_map',
[
'OPEN_WEATHER_MAP__CURRENT_WEATHER',
'OPEN_WEATHER_MAP__FORECAST',
Expand Down
5 changes: 3 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,15 @@ async def run(
# Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any
# conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets.
toolset = DeferredToolset[AgentDepsT](
[
name='AG-UI frontend tools',
tool_defs=[
ToolDefinition(
name=tool.name,
description=tool.description,
parameters_json_schema=tool.parameters,
)
for tool in run_input.tools
]
],
)
toolsets = [*toolsets, toolset] if toolsets else [toolset]

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def __init__(
if self._output_toolset:
self._output_toolset.max_retries = self._max_result_retries

self._function_toolset = FunctionToolset(tools, max_retries=retries)
self._function_toolset = FunctionToolset('Agent tools', tools, max_retries=retries)
self._user_toolsets = toolsets or ()

self.history_processors = history_processors or []
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ext/aci.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ def implementation(*args: Any, **kwargs: Any) -> str:
class ACIToolset(FunctionToolset):
"""A toolset that wraps ACI.dev tools."""

def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str):
super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions])
def __init__(self, name: str, aci_functions: Sequence[str], linked_account_owner_id: str):
super().__init__(name, [tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions])
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ext/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ def proxy(*args: Any, **kwargs: Any) -> str:
class LangChainToolset(FunctionToolset):
"""A toolset that wraps LangChain tools."""

def __init__(self, tools: list[LangChainTool]):
super().__init__([tool_from_langchain(tool) for tool in tools])
def __init__(self, name: str, tools: list[LangChainTool]):
super().__init__(name, [tool_from_langchain(tool) for tool in tools])
143 changes: 132 additions & 11 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ class MCPServer(AbstractToolset[Any], ABC):
timeout: float = 5
process_tool_call: ProcessToolCallback | None = None
allow_sampling: bool = True
max_retries: int = 1
sampling_model: models.Model | None = None
max_retries: int = 1
# } end of "abstract fields"

_name: str

_enter_lock: Lock = field(compare=False)
_running_count: int
_exit_stack: AsyncExitStack | None
Expand All @@ -73,7 +75,26 @@ class MCPServer(AbstractToolset[Any], ABC):
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
_write_stream: MemoryObjectSendStream[SessionMessage]

def __post_init__(self):
def __init__(
self,
tool_prefix: str | None = None,
log_level: mcp_types.LoggingLevel | None = None,
log_handler: LoggingFnT | None = None,
timeout: float = 5,
process_tool_call: ProcessToolCallback | None = None,
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
):
self.tool_prefix = tool_prefix
self.log_level = log_level
self.log_handler = log_handler
self.timeout = timeout
self.process_tool_call = process_tool_call
self.allow_sampling = allow_sampling
self.sampling_model = sampling_model
self.max_retries = max_retries

self._enter_lock = Lock()
self._running_count = 0
self._exit_stack = None
Expand All @@ -94,7 +115,7 @@ async def client_streams(

@property
def name(self) -> str:
return repr(self)
return self._name

@property
def tool_name_conflict_hint(self) -> str:
Expand Down Expand Up @@ -294,7 +315,7 @@ def _map_tool_result_part(
assert_never(part)


@dataclass
@dataclass(init=False)
class MCPServerStdio(MCPServer):
"""Runs an MCP server in a subprocess and communicates with it over stdin/stdout.

Expand Down Expand Up @@ -378,11 +399,62 @@ async def main():
allow_sampling: bool = True
"""Whether to allow MCP sampling through this client."""

sampling_model: models.Model | None = None
"""The model to use for sampling."""

max_retries: int = 1
"""The maximum number of times to retry a tool call."""

sampling_model: models.Model | None = None
"""The model to use for sampling."""
def __init__(
self,
command: str,
args: Sequence[str],
env: dict[str, str] | None = None,
cwd: str | Path | None = None,
name: str | None = None,
tool_prefix: str | None = None,
log_level: mcp_types.LoggingLevel | None = None,
log_handler: LoggingFnT | None = None,
timeout: float = 5,
process_tool_call: ProcessToolCallback | None = None,
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
):
"""Build a new MCP server.

Args:
command: The command to run.
args: The arguments to pass to the command.
env: The environment variables to set in the subprocess.
cwd: The working directory to use when spawning the process.
name: The unique name of the MCP server.
tool_prefix: A prefix to add to all tools that are registered with the server.
log_level: The log level to set when connecting to the server, if any.
log_handler: A handler for logging messages from the server.
timeout: The timeout in seconds to wait for the client to initialize.
process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
allow_sampling: Whether to allow MCP sampling through this client.
sampling_model: The model to use for sampling.
max_retries: The maximum number of times to retry a tool call.
"""
self.command = command
self.args = args
self.env = env
self.cwd = cwd

self._name = name or tool_prefix or ' '.join([command, *args])

super().__init__(
tool_prefix,
log_level,
log_handler,
timeout,
process_tool_call,
allow_sampling,
sampling_model,
max_retries,
)

@asynccontextmanager
async def client_streams(
Expand All @@ -401,7 +473,7 @@ def __repr__(self) -> str:
return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'


@dataclass
@dataclass(init=False)
class _MCPServerHTTP(MCPServer):
url: str
"""The URL of the endpoint on the MCP server."""
Expand Down Expand Up @@ -479,11 +551,62 @@ class _MCPServerHTTP(MCPServer):
allow_sampling: bool = True
"""Whether to allow MCP sampling through this client."""

sampling_model: models.Model | None = None
"""The model to use for sampling."""

max_retries: int = 1
"""The maximum number of times to retry a tool call."""

sampling_model: models.Model | None = None
"""The model to use for sampling."""
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
http_client: httpx.AsyncClient | None = None,
sse_read_timeout: float = 5 * 60,
name: str | None = None,
tool_prefix: str | None = None,
log_level: mcp_types.LoggingLevel | None = None,
log_handler: LoggingFnT | None = None,
timeout: float = 5,
process_tool_call: ProcessToolCallback | None = None,
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
):
"""Build a new MCP server.

Args:
url: The URL of the endpoint on the MCP server.
headers: Optional HTTP headers to be sent with each request to the endpoint.
http_client: An `httpx.AsyncClient` to use with the endpoint.
sse_read_timeout: Maximum time in seconds to wait for new SSE messages before timing out.
name: The unique name of the MCP server.
tool_prefix: A prefix to add to all tools that are registered with the server.
log_level: The log level to set when connecting to the server, if any.
log_handler: A handler for logging messages from the server.
timeout: The timeout in seconds to wait for the client to initialize.
process_tool_call: Hook to customize tool calling and optionally pass extra metadata.
allow_sampling: Whether to allow MCP sampling through this client.
sampling_model: The model to use for sampling.
max_retries: The maximum number of times to retry a tool call.
"""
self.url = url
self.headers = headers
self.http_client = http_client
self.sse_read_timeout = sse_read_timeout

self._name = name or tool_prefix or url

super().__init__(
tool_prefix,
log_level,
log_handler,
timeout,
process_tool_call,
allow_sampling,
sampling_model,
max_retries,
)

@property
@abstractmethod
Expand Down Expand Up @@ -583,7 +706,6 @@ def _transport_client(self):


@deprecated('The `MCPServerHTTP` class is deprecated, use `MCPServerSSE` instead.')
@dataclass
class MCPServerHTTP(MCPServerSSE):
"""An MCP server that connects over HTTP using the old SSE transport.

Expand Down Expand Up @@ -612,7 +734,6 @@ async def main():
"""


@dataclass
class MCPServerStreamableHTTP(_MCPServerHTTP):
"""An MCP server that connects over HTTP using the Streamable HTTP transport.

Expand Down
7 changes: 5 additions & 2 deletions pydantic_ai_slim/pydantic_ai/toolsets/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ class AbstractToolset(ABC, Generic[AgentDepsT]):

@property
def name(self) -> str:
"""The name of the toolset for use in error messages."""
return self.__class__.__name__.replace('Toolset', ' toolset')
"""A unique name for the toolset.

If you're defining a subclass that can be instantiated by a user, you should let them pass a custom name to the constructor and return that here.
"""
raise NotImplementedError()

@property
def tool_name_conflict_hint(self) -> str:
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/toolsets/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def __post_init__(self):
self._entered_count = 0
self._exit_stack = None

@property
def name(self) -> str:
return f'{self.__class__.__name__}({", ".join(toolset.name for toolset in self.toolsets)})'

async def __aenter__(self) -> Self:
async with self._enter_lock:
if self._entered_count == 0:
Expand Down
Loading
Loading