Always enter Toolset context when running agent #2361


Merged (2 commits) on Jul 30, 2025
Changes from all commits
161 changes: 81 additions & 80 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -774,90 +774,91 @@ async def main():

toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
# This will raise errors for any name conflicts
run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)

# Merge model settings in order of precedence: run > agent > model
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
model_settings = merge_model_settings(merged_settings, model_settings)
usage_limits = usage_limits or _usage.UsageLimits()
agent_name = self.name or 'agent'
run_span = tracer.start_span(
'agent run',
attributes={
'model_name': model_used.model_name if model_used else 'no-model',
'agent_name': agent_name,
'logfire.msg': f'{agent_name} run',
},
)

async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
parts = [
self._instructions,
*[await func.run(run_context) for func in self._instructions_functions],
]

model_profile = model_used.profile
if isinstance(output_schema, _output.PromptedOutputSchema):
instructions = output_schema.instructions(model_profile.prompted_output_template)
parts.append(instructions)
async with toolset:
run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)

# Merge model settings in order of precedence: run > agent > model
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
model_settings = merge_model_settings(merged_settings, model_settings)
usage_limits = usage_limits or _usage.UsageLimits()
agent_name = self.name or 'agent'
run_span = tracer.start_span(
'agent run',
attributes={
'model_name': model_used.model_name if model_used else 'no-model',
'agent_name': agent_name,
'logfire.msg': f'{agent_name} run',
},
)

parts = [p for p in parts if p]
if not parts:
return None
return '\n\n'.join(parts).strip()
async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
parts = [
self._instructions,
*[await func.run(run_context) for func in self._instructions_functions],
]

graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
user_deps=deps,
prompt=user_prompt,
new_message_index=new_message_index,
model=model_used,
model_settings=model_settings,
usage_limits=usage_limits,
max_result_retries=self._max_result_retries,
end_strategy=self.end_strategy,
output_schema=output_schema,
output_validators=output_validators,
history_processors=self.history_processors,
tool_manager=run_toolset,
tracer=tracer,
get_instructions=get_instructions,
instrumentation_settings=instrumentation_settings,
)
start_node = _agent_graph.UserPromptNode[AgentDepsT](
user_prompt=user_prompt,
instructions=self._instructions,
instructions_functions=self._instructions_functions,
system_prompts=self._system_prompts,
system_prompt_functions=self._system_prompt_functions,
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
)
model_profile = model_used.profile
if isinstance(output_schema, _output.PromptedOutputSchema):
instructions = output_schema.instructions(model_profile.prompted_output_template)
parts.append(instructions)

parts = [p for p in parts if p]
if not parts:
return None
return '\n\n'.join(parts).strip()

graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
user_deps=deps,
prompt=user_prompt,
new_message_index=new_message_index,
model=model_used,
model_settings=model_settings,
usage_limits=usage_limits,
max_result_retries=self._max_result_retries,
end_strategy=self.end_strategy,
output_schema=output_schema,
output_validators=output_validators,
history_processors=self.history_processors,
tool_manager=run_toolset,
tracer=tracer,
get_instructions=get_instructions,
instrumentation_settings=instrumentation_settings,
)
start_node = _agent_graph.UserPromptNode[AgentDepsT](
user_prompt=user_prompt,
instructions=self._instructions,
instructions_functions=self._instructions_functions,
system_prompts=self._system_prompts,
system_prompt_functions=self._system_prompt_functions,
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
)

try:
async with graph.iter(
start_node,
state=state,
deps=graph_deps,
span=use_span(run_span) if run_span.is_recording() else None,
infer_name=False,
) as graph_run:
agent_run = AgentRun(graph_run)
yield agent_run
if (final_result := agent_run.result) is not None and run_span.is_recording():
if instrumentation_settings and instrumentation_settings.include_content:
run_span.set_attribute(
'final_result',
(
final_result.output
if isinstance(final_result.output, str)
else json.dumps(InstrumentedModel.serialize_any(final_result.output))
),
)
finally:
try:
if instrumentation_settings and run_span.is_recording():
run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
async with graph.iter(
start_node,
state=state,
deps=graph_deps,
span=use_span(run_span) if run_span.is_recording() else None,
infer_name=False,
) as graph_run:
agent_run = AgentRun(graph_run)
yield agent_run
if (final_result := agent_run.result) is not None and run_span.is_recording():
if instrumentation_settings and instrumentation_settings.include_content:
run_span.set_attribute(
'final_result',
(
final_result.output
if isinstance(final_result.output, str)
else json.dumps(InstrumentedModel.serialize_any(final_result.output))
),
)
finally:
run_span.end()
try:
if instrumentation_settings and run_span.is_recording():
run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
finally:
run_span.end()

def _run_span_end_attributes(
self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings
@@ -2173,7 +2174,7 @@ async def __anext__(
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
"""Advance to the next node automatically based on the last returned node."""
next_node = await self._graph_run.__anext__()
if _agent_graph.is_agent_node(next_node):
if _agent_graph.is_agent_node(node=next_node):
return next_node
assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
return next_node
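Taken together, the agent.py changes above move the tool-manager construction and the whole graph run inside `async with toolset:`, so toolsets that own external resources (such as MCP servers) are entered automatically for the duration of the run instead of requiring an explicit context manager around the agent. A minimal sketch of that pattern, using simplified stand-in names (`Toolset`, `run_agent`) rather than the real pydantic-ai internals:

import asyncio


class Toolset:
    """Stand-in for a toolset that owns an external resource (e.g. an MCP server)."""

    def __init__(self) -> None:
        self.is_running = False

    async def __aenter__(self) -> 'Toolset':
        self.is_running = True  # e.g. start the server process
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        self.is_running = False  # e.g. stop the server process


async def run_agent(toolset: Toolset) -> str:
    # The key change in this PR: build the tool manager and execute the run
    # while the toolset context is open, rather than before entering it.
    async with toolset:
        assert toolset.is_running  # tools can be listed and called safely here
        return 'finished'


print(asyncio.run(run_agent(Toolset())))  # -> finished

This is only an illustration of the control flow; the real implementation keeps the tracing span, usage limits, and graph iteration shown in the diff inside the same `async with toolset:` block.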
43 changes: 42 additions & 1 deletion tests/test_agent.py
@@ -3866,7 +3866,7 @@ async def only_if_plan_presented(
)


async def test_context_manager():
async def test_explicit_context_manager():
try:
from pydantic_ai.mcp import MCPServerStdio
except ImportError: # pragma: lax no cover
@@ -3886,6 +3886,47 @@ async def test_context_manager():
assert server2.is_running


async def test_implicit_context_manager():
try:
from pydantic_ai.mcp import MCPServerStdio
except ImportError: # pragma: lax no cover
pytest.skip('mcp is not installed')

server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')])
agent = Agent('test', toolsets=[toolset])

async with agent.iter(
user_prompt='Hello',
):
assert server1.is_running
assert server2.is_running


def test_parallel_mcp_calls():
try:
from pydantic_ai.mcp import MCPServerStdio
except ImportError: # pragma: lax no cover
pytest.skip('mcp is not installed')

async def call_tools_parallel(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
if len(messages) == 1:
return ModelResponse(
parts=[
ToolCallPart(tool_name='get_none'),
ToolCallPart(tool_name='get_multiple_items'),
]
)
else:
return ModelResponse(parts=[TextPart('finished')])

server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
agent = Agent(FunctionModel(call_tools_parallel), toolsets=[server])
result = agent.run_sync()
assert result.output == snapshot('finished')


def test_set_mcp_sampling_model():
try:
from pydantic_ai.mcp import MCPServerStdio