Skip to content

Commit cee4c16

Browse files
committed
refactor: simplify MCP elicitation by removing tool name mapping complexity
1 parent c05487f commit cee4c16

File tree

4 files changed

+435
-163
lines changed

4 files changed

+435
-163
lines changed

mcp-run-python/src/tool_injection.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,14 @@ def _create_tool_function(
6464
"""Create a tool function that can be called from Python."""
6565

6666
def tool_function(*args: Any, **kwargs: Any) -> Any:
67-
"""Synchronous tool function that handles the async callback properly."""
68-
69-
# Get the actual MCP tool name from the stored mapping
70-
tool_mapping = globals_dict.get('__tool_name_mapping__', {})
71-
actual_tool_name = tool_mapping.get(tool_name, tool_name)
72-
73-
elicitation_request = _create_elicitation_request(actual_tool_name, args, kwargs)
67+
"""Tool function that calls the MCP elicitation callback."""
68+
elicitation_request = _create_elicitation_request(tool_name=tool_name, args=args, kwargs=kwargs)
7469

7570
try:
7671
result = tool_callback(elicitation_request)
77-
return _handle_tool_callback_result(result, actual_tool_name)
72+
return _handle_tool_callback_result(result, tool_name)
7873
except Exception as e:
79-
raise Exception(f'Tool {actual_tool_name} failed: {str(e)}')
74+
raise Exception(f'Tool {tool_name} failed: {str(e)}')
8075

8176
return tool_function
8277

@@ -85,25 +80,18 @@ def inject_tool_functions(
8580
globals_dict: dict[str, Any],
8681
available_tools: list[str],
8782
tool_callback: Callable[[Any], Any] | None = None,
88-
tool_name_mapping: dict[str, str] | None = None,
8983
) -> None:
9084
"""Inject tool functions into the global namespace.
9185
9286
Args:
9387
globals_dict: Global namespace to inject tools into
94-
available_tools: List of available tool names (should be Python-valid identifiers)
88+
available_tools: List of available tool names
9589
tool_callback: Optional callback for tool execution
96-
tool_name_mapping: Optional mapping of python_name -> original_mcp_name
9790
"""
9891
if not available_tools:
9992
return
10093

101-
# Store the tool name mapping globally for elicitation callback to use
102-
if tool_name_mapping:
103-
globals_dict['__tool_name_mapping__'] = tool_name_mapping
104-
105-
# Inject tool functions into globals using Python-valid names
10694
for tool_name in available_tools:
10795
if tool_callback is not None:
108-
# tool_name should already be a valid Python identifier from agent.py
109-
globals_dict[tool_name] = _create_tool_function(tool_name, tool_callback, globals_dict)
96+
python_name = tool_name.replace('-', '_')
97+
globals_dict[python_name] = _create_tool_function(tool_name, tool_callback, globals_dict)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 44 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from types import FrameType
1313
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
1414

15-
from mcp import types as mcp_types
1615
from opentelemetry.trace import NoOpTracer, use_span
1716
from pydantic.json_schema import GenerateJsonSchema
1817
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
@@ -34,7 +33,7 @@
3433
from ._agent_graph import HistoryProcessor
3534
from ._output import OutputToolset
3635
from ._tool_manager import ToolManager
37-
from .mcp import MCPServer
36+
from .mcp import MCPServer, create_auto_tool_injection_callback, create_tool_elicitation_callback
3837
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
3938
from .output import OutputDataT, OutputSpec
4039
from .profiles import ModelProfile
@@ -1716,7 +1715,6 @@ def _get_toolset(
17161715
if self._prepare_output_tools:
17171716
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
17181717
all_toolsets = [output_toolset, *all_toolsets]
1719-
17201718
return CombinedToolset(all_toolsets)
17211719

17221720
def _infer_name(self, function_frame: FrameType | None) -> None:
@@ -1815,19 +1813,6 @@ async def __aenter__(self) -> Self:
18151813
if self._entered_count == 0:
18161814
self._exit_stack = AsyncExitStack()
18171815

1818-
for toolset in self._user_toolsets:
1819-
if isinstance(toolset, MCPServer):
1820-
if (
1821-
hasattr(toolset, 'allow_elicitation')
1822-
and toolset.allow_elicitation
1823-
and toolset.elicitation_callback is None
1824-
):
1825-
toolset.elicitation_callback = self._create_elicitation_callback()
1826-
1827-
# Also setup auto-tool-injection for run_python_code if not already set
1828-
if toolset.process_tool_call is None:
1829-
toolset.process_tool_call = self._create_auto_tool_injection_callback()
1830-
18311816
toolset = self._get_toolset()
18321817
await self._exit_stack.enter_async_context(toolset)
18331818
self._entered_count += 1
@@ -1856,113 +1841,50 @@ def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None:
18561841

18571842
self._get_toolset().apply(_set_sampling_model)
18581843

1859-
def _create_elicitation_callback(self):
1860-
"""Create an elicitation callback that routes to this agent's tools."""
1844+
def set_mcp_elicitation_toolset(self, toolset_for_elicitation: AbstractToolset[Any] | None = None) -> None:
1845+
"""Set the toolset to use for MCP elicitation callbacks.
18611846
1862-
async def elicitation_callback(context: Any, params: Any) -> Any:
1863-
"""Handle elicitation requests by delegating to the agent's tools."""
1864-
try:
1865-
tool_execution_data = json.loads(params.message)
1866-
tool_name = tool_execution_data.get('tool_name')
1867-
tool_arguments = tool_execution_data.get('arguments', {})
1868-
1869-
# Try function tools first
1870-
function_tools = self._function_toolset.tools
1871-
if tool_name in function_tools:
1872-
tool_func = function_tools[tool_name].function_schema.function
1873-
1874-
# Handle both sync and async functions
1875-
1876-
if inspect.iscoroutinefunction(tool_func):
1877-
result = await tool_func(**tool_arguments)
1878-
else:
1879-
result = tool_func(**tool_arguments)
1880-
1881-
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1882-
1883-
# Find the MCP server that has this tool
1884-
target_server = None
1885-
for toolset in self._user_toolsets:
1886-
if not isinstance(toolset, MCPServer):
1887-
continue
1888-
if 'mcp-run-python' in str(toolset):
1889-
continue
1890-
1891-
# Check if this server has the tool
1892-
try:
1893-
server_tools = await toolset.list_tools()
1894-
for tool_def in server_tools:
1895-
if tool_def.name == tool_name:
1896-
target_server = toolset
1897-
break
1898-
if target_server:
1899-
break
1900-
except Exception:
1901-
continue
1902-
1903-
if target_server:
1904-
try:
1905-
result = await target_server.direct_call_tool(tool_name, tool_arguments)
1906-
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1907-
except Exception as e:
1908-
return mcp_types.ErrorData(
1909-
code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}'
1910-
)
1911-
else:
1912-
return mcp_types.ErrorData(code=mcp_types.INVALID_PARAMS, message=f'Tool {tool_name} not found')
1913-
1914-
except Exception as e:
1915-
return mcp_types.ErrorData(code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}')
1916-
1917-
return elicitation_callback
1918-
1919-
def _create_auto_tool_injection_callback(self):
1920-
"""Create a callback that auto-injects available tools into run_python_code calls."""
1921-
1922-
async def auto_inject_tools_callback(
1923-
ctx: RunContext[Any],
1924-
call_tool_func: Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[Any]],
1925-
tool_name: str,
1926-
arguments: dict[str, Any],
1927-
) -> Any:
1928-
"""Auto-inject available tools into run_python_code calls."""
1929-
if tool_name == 'run_python_code':
1930-
# Always auto-inject all available tools for Python code execution
1931-
available_tools: list[str] = []
1932-
tool_name_mapping: dict[str, str] = {}
1933-
1934-
# Add function tools
1935-
function_tools = list(self._function_toolset.tools.keys())
1936-
available_tools.extend(function_tools)
1937-
for func_tool_name in function_tools:
1938-
tool_name_mapping[func_tool_name] = func_tool_name
1939-
1940-
# Add MCP server tools with proper name conversion
1941-
for toolset in self._user_toolsets:
1942-
if not isinstance(toolset, MCPServer):
1943-
continue
1944-
if 'mcp-run-python' in str(toolset):
1945-
continue
1946-
1947-
try:
1948-
server_tools = await toolset.list_tools()
1949-
for tool_def in server_tools:
1950-
original_name = tool_def.name
1951-
python_name = original_name.replace('-', '_')
1952-
available_tools.append(python_name)
1953-
tool_name_mapping[python_name] = original_name
1954-
except Exception:
1955-
# Silently continue if we can't get tools from a server
1956-
pass
1957-
1958-
# Always provide all available tools and mapping
1959-
arguments['tools'] = available_tools
1960-
arguments['tool_name_mapping'] = tool_name_mapping
1961-
1962-
# Continue with normal processing
1963-
return await call_tool_func(tool_name, arguments, None)
1964-
1965-
return auto_inject_tools_callback
1847+
This method configures all MCP servers in the agent's toolsets to use the provided
1848+
toolset for handling elicitation requests (tool injection). This enables Python code
1849+
executed via mcp-run-python to call back to the agent's tools.
1850+
1851+
Args:
1852+
toolset_for_elicitation: Toolset to use for tool injection via elicitation.
1853+
If None, uses the agent's complete toolset.
1854+
1855+
Example:
1856+
```python
1857+
agent = Agent('openai:gpt-4o')
1858+
agent.tool(web_search)
1859+
agent.tool(send_email)
1860+
1861+
mcp_server = MCPServerStdio(command='deno', args=[...], allow_elicitation=True)
1862+
agent.add_toolset(mcp_server)
1863+
1864+
# Enable tool injection with all agent tools
1865+
agent.set_mcp_elicitation_toolset()
1866+
1867+
# Or use specific toolset
1868+
custom_toolset = FunctionToolset(web_search)
1869+
agent.set_mcp_elicitation_toolset(custom_toolset)
1870+
```
1871+
"""
1872+
if toolset_for_elicitation is None:
1873+
# Use complete toolset for both elicitation and injection
1874+
toolset_for_elicitation = self._get_toolset()
1875+
1876+
# Set up callbacks for all MCP servers
1877+
def _set_elicitation_toolset(toolset: AbstractToolset[Any]) -> None:
1878+
if isinstance(toolset, MCPServer) and toolset.allow_elicitation:
1879+
# Set up elicitation callback
1880+
if toolset.elicitation_callback is None:
1881+
toolset.elicitation_callback = create_tool_elicitation_callback(toolset=toolset_for_elicitation)
1882+
1883+
# Set up tool injection callback
1884+
if toolset.process_tool_call is None:
1885+
toolset.process_tool_call = create_auto_tool_injection_callback(toolset=toolset_for_elicitation)
1886+
1887+
self._get_toolset().apply(_set_elicitation_toolset)
19661888

19671889
@asynccontextmanager
19681890
@deprecated(
@@ -1984,21 +1906,6 @@ async def run_mcp_servers(
19841906
if model is not None:
19851907
raise
19861908

1987-
# Auto-setup elicitation callback if allow_elicitation is True and no callback is set
1988-
1989-
for toolset in self._user_toolsets:
1990-
if isinstance(toolset, MCPServer):
1991-
if (
1992-
hasattr(toolset, 'allow_elicitation')
1993-
and toolset.allow_elicitation
1994-
and toolset.elicitation_callback is None
1995-
):
1996-
toolset.elicitation_callback = self._create_elicitation_callback()
1997-
1998-
# Also setup auto-tool-injection for run_python_code if not already set
1999-
if toolset.process_tool_call is None:
2000-
toolset.process_tool_call = self._create_auto_tool_injection_callback()
2001-
20021909
async with self:
20031910
yield
20041911

0 commit comments

Comments
 (0)