Skip to content

Commit c05487f

Browse files
committed
feat: replace global _mcp_servers with toolsets
1 parent 6e0267a commit c05487f

File tree

5 files changed

+128
-62
lines changed

5 files changed

+128
-62
lines changed

mcp-run-python/src/main.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,19 @@ The tools are injected into the global namespace automatically - no discovery fu
9696
.array(z.string())
9797
.optional()
9898
.describe('List of available tools for injection (enables tool injection when provided)'),
99+
tool_name_mapping: z
100+
.record(z.string())
101+
.optional()
102+
.describe('Mapping of python_name -> original_mcp_name for tool name conversion'),
99103
},
100104
async ({
101105
python_code,
102106
tools = [],
107+
tool_name_mapping = {},
103108
}: {
104109
python_code: string
105110
tools?: string[]
111+
tool_name_mapping?: Record<string, string>
106112
}) => {
107113
const logPromises: Promise<void>[] = []
108114

@@ -169,6 +175,7 @@ The tools are injected into the global namespace automatically - no discovery fu
169175
{
170176
enableToolInjection: true,
171177
availableTools: tools,
178+
toolNameMapping: tool_name_mapping,
172179
timeoutSeconds: 30,
173180
elicitationCallback,
174181
} as ToolInjectionConfig,

mcp-run-python/src/runCode.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export interface CodeFile {
1313
export interface ToolInjectionConfig {
1414
enableToolInjection: boolean
1515
availableTools: string[]
16+
toolNameMapping?: Record<string, string> // python_name -> original_mcp_name
1617
timeoutSeconds: number
1718
// deno-lint-ignore no-explicit-any
1819
elicitationCallback?: (request: any) => Promise<any>
@@ -157,6 +158,7 @@ function injectToolFunctions(
157158
globals,
158159
config.availableTools,
159160
tool_callback,
161+
config.toolNameMapping,
160162
)
161163

162164
log('info', `Tool injection complete. Available tools: ${config.availableTools.join(', ')}`)

mcp-run-python/src/tool_injection.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,38 +58,52 @@ def _handle_tool_callback_result(result: Any, tool_name: str) -> Any:
5858
return result
5959

6060

61-
def _create_tool_function(tool_name: str, tool_callback: Callable[[Any], Any]) -> Callable[..., Any]:
61+
def _create_tool_function(
62+
tool_name: str, tool_callback: Callable[[Any], Any], globals_dict: dict[str, Any]
63+
) -> Callable[..., Any]:
6264
"""Create a tool function that can be called from Python."""
6365

6466
def tool_function(*args: Any, **kwargs: Any) -> Any:
6567
"""Synchronous tool function that handles the async callback properly."""
66-
# Note: tool_callback is guaranteed to be not None due to check in inject_tool_functions
6768

68-
elicitation_request = _create_elicitation_request(tool_name, args, kwargs)
69+
# Get the actual MCP tool name from the stored mapping
70+
tool_mapping = globals_dict.get('__tool_name_mapping__', {})
71+
actual_tool_name = tool_mapping.get(tool_name, tool_name)
72+
73+
elicitation_request = _create_elicitation_request(actual_tool_name, args, kwargs)
6974

7075
try:
7176
result = tool_callback(elicitation_request)
72-
return _handle_tool_callback_result(result, tool_name)
77+
return _handle_tool_callback_result(result, actual_tool_name)
7378
except Exception as e:
74-
raise Exception(f'Tool {tool_name} failed: {str(e)}')
79+
raise Exception(f'Tool {actual_tool_name} failed: {str(e)}')
7580

7681
return tool_function
7782

7883

7984
def inject_tool_functions(
80-
globals_dict: dict[str, Any], available_tools: list[str], tool_callback: Callable[[Any], Any] | None = None
85+
globals_dict: dict[str, Any],
86+
available_tools: list[str],
87+
tool_callback: Callable[[Any], Any] | None = None,
88+
tool_name_mapping: dict[str, str] | None = None,
8189
) -> None:
8290
"""Inject tool functions into the global namespace.
8391
8492
Args:
8593
globals_dict: Global namespace to inject tools into
86-
available_tools: List of available tool names
94+
available_tools: List of available tool names (should be Python-valid identifiers)
8795
tool_callback: Optional callback for tool execution
96+
tool_name_mapping: Optional mapping of python_name -> original_mcp_name
8897
"""
8998
if not available_tools:
9099
return
91100

92-
# Inject tool functions into globals
101+
# Store the tool name mapping globally for elicitation callback to use
102+
if tool_name_mapping:
103+
globals_dict['__tool_name_mapping__'] = tool_name_mapping
104+
105+
# Inject tool functions into globals using Python-valid names
93106
for tool_name in available_tools:
94107
if tool_callback is not None:
95-
globals_dict[tool_name] = _create_tool_function(tool_name, tool_callback)
108+
# tool_name should already be a valid Python identifier from agent.py
109+
globals_dict[tool_name] = _create_tool_function(tool_name, tool_callback, globals_dict)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 73 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from types import FrameType
1313
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
1414

15+
from mcp import types as mcp_types
1516
from opentelemetry.trace import NoOpTracer, use_span
1617
from pydantic.json_schema import GenerateJsonSchema
1718
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
@@ -70,7 +71,6 @@
7071
from fasta2a.broker import Broker
7172
from fasta2a.schema import AgentProvider, Skill
7273
from fasta2a.storage import Storage
73-
from mcp import types as mcp_types
7474
from starlette.middleware import Middleware
7575
from starlette.routing import BaseRoute, Route
7676
from starlette.types import ExceptionHandler, Lifespan
@@ -1814,6 +1814,20 @@ async def __aenter__(self) -> Self:
18141814
async with self._enter_lock:
18151815
if self._entered_count == 0:
18161816
self._exit_stack = AsyncExitStack()
1817+
1818+
for toolset in self._user_toolsets:
1819+
if isinstance(toolset, MCPServer):
1820+
if (
1821+
hasattr(toolset, 'allow_elicitation')
1822+
and toolset.allow_elicitation
1823+
and toolset.elicitation_callback is None
1824+
):
1825+
toolset.elicitation_callback = self._create_elicitation_callback()
1826+
1827+
# Also setup auto-tool-injection for run_python_code if not already set
1828+
if toolset.process_tool_call is None:
1829+
toolset.process_tool_call = self._create_auto_tool_injection_callback()
1830+
18171831
toolset = self._get_toolset()
18181832
await self._exit_stack.enter_async_context(toolset)
18191833
self._entered_count += 1
@@ -1866,23 +1880,36 @@ async def elicitation_callback(context: Any, params: Any) -> Any:
18661880

18671881
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
18681882

1869-
# Try MCP tools with name mapping
1870-
actual_tool_name = tool_name.replace('_', '-')
1871-
1883+
# Find the MCP server that has this tool
1884+
target_server = None
18721885
for toolset in self._user_toolsets:
18731886
if not isinstance(toolset, MCPServer):
18741887
continue
1875-
mcp_server = toolset
1876-
if 'mcp-run-python' in str(mcp_server):
1888+
if 'mcp-run-python' in str(toolset):
18771889
continue
18781890

1891+
# Check if this server has the tool
18791892
try:
1880-
result = await mcp_server.direct_call_tool(actual_tool_name, tool_arguments)
1881-
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1893+
server_tools = await toolset.list_tools()
1894+
for tool_def in server_tools:
1895+
if tool_def.name == tool_name:
1896+
target_server = toolset
1897+
break
1898+
if target_server:
1899+
break
18821900
except Exception:
18831901
continue
18841902

1885-
return mcp_types.ErrorData(code=mcp_types.INVALID_PARAMS, message=f'Tool {tool_name} not found')
1903+
if target_server:
1904+
try:
1905+
result = await target_server.direct_call_tool(tool_name, tool_arguments)
1906+
return mcp_types.ElicitResult(action='accept', content={'result': str(result)})
1907+
except Exception as e:
1908+
return mcp_types.ErrorData(
1909+
code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}'
1910+
)
1911+
else:
1912+
return mcp_types.ErrorData(code=mcp_types.INVALID_PARAMS, message=f'Tool {tool_name} not found')
18861913

18871914
except Exception as e:
18881915
return mcp_types.ErrorData(code=mcp_types.INTERNAL_ERROR, message=f'Tool execution failed: {str(e)}')
@@ -1900,30 +1927,37 @@ async def auto_inject_tools_callback(
19001927
) -> Any:
19011928
"""Auto-inject available tools into run_python_code calls."""
19021929
if tool_name == 'run_python_code':
1903-
# Auto-inject available tools if not already provided
1904-
if 'tools' not in arguments or not arguments['tools']:
1905-
available_tools: list[str] = []
1906-
1907-
# Add function tools
1908-
available_tools.extend(list(self._function_toolset.tools.keys()))
1909-
1910-
for toolset in self._user_toolsets:
1911-
if not isinstance(toolset, MCPServer):
1912-
continue
1913-
mcp_server = toolset
1914-
if 'mcp-run-python' in str(mcp_server):
1915-
continue
1916-
1917-
try:
1918-
server_tools = await mcp_server.list_tools()
1919-
for tool_def in server_tools:
1920-
python_name = tool_def.name.replace('-', '_')
1921-
available_tools.append(python_name)
1922-
except Exception:
1923-
# Silently continue if we can't get tools from a server
1924-
pass
1925-
1926-
arguments['tools'] = available_tools
1930+
# Always auto-inject all available tools for Python code execution
1931+
available_tools: list[str] = []
1932+
tool_name_mapping: dict[str, str] = {}
1933+
1934+
# Add function tools
1935+
function_tools = list(self._function_toolset.tools.keys())
1936+
available_tools.extend(function_tools)
1937+
for func_tool_name in function_tools:
1938+
tool_name_mapping[func_tool_name] = func_tool_name
1939+
1940+
# Add MCP server tools with proper name conversion
1941+
for toolset in self._user_toolsets:
1942+
if not isinstance(toolset, MCPServer):
1943+
continue
1944+
if 'mcp-run-python' in str(toolset):
1945+
continue
1946+
1947+
try:
1948+
server_tools = await toolset.list_tools()
1949+
for tool_def in server_tools:
1950+
original_name = tool_def.name
1951+
python_name = original_name.replace('-', '_')
1952+
available_tools.append(python_name)
1953+
tool_name_mapping[python_name] = original_name
1954+
except Exception:
1955+
# Silently continue if we can't get tools from a server
1956+
pass
1957+
1958+
# Always provide all available tools and mapping
1959+
arguments['tools'] = available_tools
1960+
arguments['tool_name_mapping'] = tool_name_mapping
19271961

19281962
# Continue with normal processing
19291963
return await call_tool_func(tool_name, arguments, None)
@@ -1954,17 +1988,16 @@ async def run_mcp_servers(
19541988

19551989
for toolset in self._user_toolsets:
19561990
if isinstance(toolset, MCPServer):
1957-
mcp_server = toolset
19581991
if (
1959-
hasattr(mcp_server, 'allow_elicitation')
1960-
and mcp_server.allow_elicitation
1961-
and mcp_server.elicitation_callback is None
1992+
hasattr(toolset, 'allow_elicitation')
1993+
and toolset.allow_elicitation
1994+
and toolset.elicitation_callback is None
19621995
):
1963-
mcp_server.elicitation_callback = self._create_elicitation_callback()
1996+
toolset.elicitation_callback = self._create_elicitation_callback()
19641997

19651998
# Also setup auto-tool-injection for run_python_code if not already set
1966-
if mcp_server.process_tool_call is None:
1967-
mcp_server.process_tool_call = self._create_auto_tool_injection_callback()
1999+
if toolset.process_tool_call is None:
2000+
toolset.process_tool_call = self._create_auto_tool_injection_callback()
19682001

19692002
async with self:
19702003
yield

tests/test_mcp_elicitation.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,15 @@ async def mock_elicitation(
8282
)
8383

8484
model = TestModel(custom_output_text='Test response')
85-
agent = Agent(model, mcp_servers=[server])
85+
agent = Agent(model, toolsets=[server])
8686

8787
# Verify the server is properly configured
88-
assert len(agent._mcp_servers) == 1 # type: ignore
89-
assert agent._mcp_servers[0].elicitation_callback is mock_elicitation # type: ignore
88+
toolsets = getattr(agent, '_user_toolsets', [])
89+
mcp_servers = [ts for ts in toolsets if hasattr(ts, '__class__') and 'MCPServer' in ts.__class__.__name__]
90+
assert len(mcp_servers) == 1
91+
# Use getattr to safely check elicitation_callback
92+
callback = getattr(mcp_servers[0], 'elicitation_callback', None)
93+
assert callback is mock_elicitation
9094

9195
async def test_elicitation_callback_error_handling(self):
9296
"""Test error handling in elicitation callback."""
@@ -483,14 +487,19 @@ async def agent_tool_callback(
483487

484488
# Create agent with the MCP server
485489
model = TestModel(custom_output_text='Tool injection test completed')
486-
agent = Agent(model, mcp_servers=[mcp_server])
490+
agent = Agent(model, toolsets=[mcp_server])
487491

488492
# Verify the agent has the MCP server with elicitation callback
489-
assert len(agent._mcp_servers) == 1 # type: ignore
490-
assert agent._mcp_servers[0].elicitation_callback is agent_tool_callback # type: ignore
493+
# Note: Using getattr to safely access toolsets for testing
494+
toolsets = getattr(agent, '_user_toolsets', [])
495+
mcp_servers = [ts for ts in toolsets if hasattr(ts, '__class__') and 'MCPServer' in ts.__class__.__name__]
496+
assert len(mcp_servers) == 1
497+
# Use getattr to safely check elicitation_callback
498+
callback = getattr(mcp_servers[0], 'elicitation_callback', None)
499+
assert callback is agent_tool_callback
491500

492501
# Test running agent with MCP servers
493-
async with agent.run_mcp_servers():
502+
async with agent:
494503
# Verify the MCP server is properly integrated
495504
tools = await mcp_server.list_tools()
496505
assert len(tools) == 1
@@ -780,7 +789,7 @@ async def test_mcp_run_python_code_execution(self):
780789

781790
async with server:
782791
# Test basic Python execution
783-
result = await server.call_tool(
792+
result = await server.direct_call_tool(
784793
'run_python_code', {'python_code': 'print("Hello, World!")\n"Hello from Python"'}
785794
)
786795

@@ -827,7 +836,7 @@ async def python_code_callback(
827836
async with server:
828837
# Test Python code execution with tool injection
829838
# This should trigger the elicitation callback when tools are called
830-
result = await server.call_tool(
839+
result = await server.direct_call_tool(
831840
'run_python_code',
832841
{'python_code': 'print("Testing tool injection")', 'tools': ['web_search', 'calculate']},
833842
)
@@ -854,7 +863,7 @@ async def test_mcp_run_python_error_handling(self):
854863

855864
async with server:
856865
# Test Python code with syntax error
857-
result = await server.call_tool('run_python_code', {'python_code': 'print("Missing closing quote)'})
866+
result = await server.direct_call_tool('run_python_code', {'python_code': 'print("Missing closing quote)'})
858867

859868
# Should return error status
860869
assert isinstance(result, str)
@@ -1009,7 +1018,7 @@ async def test_mcp_run_python_with_dependencies(self):
10091018

10101019
async with server:
10111020
# Test code with dependencies
1012-
result = await server.call_tool(
1021+
result = await server.direct_call_tool(
10131022
'run_python_code',
10141023
{
10151024
'python_code': """
@@ -1049,7 +1058,8 @@ async def test_mcp_run_python_with_tool_prefix(self):
10491058
async with server:
10501059
tools = await server.list_tools()
10511060
assert len(tools) == 1
1052-
assert tools[0].name == 'python_run_python_code'
1061+
# list_tools() returns original tool names without prefix
1062+
assert tools[0].name == 'run_python_code'
10531063

10541064
async def test_mcp_run_python_timeout_setting(self):
10551065
"""Test mcp-run-python server with timeout setting."""
@@ -1071,7 +1081,7 @@ async def test_mcp_run_python_timeout_setting(self):
10711081

10721082
async with server:
10731083
# Test basic execution still works
1074-
result = await server.call_tool('run_python_code', {'python_code': 'print("Timeout test")'})
1084+
result = await server.direct_call_tool('run_python_code', {'python_code': 'print("Timeout test")'})
10751085

10761086
assert isinstance(result, str)
10771087
assert '<status>success</status>' in result

0 commit comments

Comments
 (0)