Skip to content

Commit 5709c23

Browse files
strawgateDouweM
andauthored
Support sequential tool calling (#2718)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 5f9d8de commit 5709c23

File tree

9 files changed

+177
-31
lines changed

9 files changed

+177
-31
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import asyncio
44
import dataclasses
5+
import inspect
6+
from asyncio import Task
57
from collections import defaultdict, deque
68
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
79
from contextlib import asynccontextmanager, contextmanager
@@ -740,7 +742,6 @@ async def process_function_tools( # noqa: C901
740742
deferred_tool_results: dict[str, DeferredToolResult] = {}
741743
if build_run_context(ctx).tool_call_approved and ctx.deps.tool_call_results is not None:
742744
deferred_tool_results = ctx.deps.tool_call_results
743-
744745
# Deferred tool calls are "run" as well, by reading their value from the tool call results
745746
calls_to_run.extend(tool_calls_by_kind['external'])
746747
calls_to_run.extend(tool_calls_by_kind['unapproved'])
@@ -819,47 +820,65 @@ async def _call_tools(
819820
for call in tool_calls:
820821
yield _messages.FunctionToolCallEvent(call)
821822

822-
# Run all tool tasks in parallel
823823
with tracer.start_as_current_span(
824824
'running tools',
825825
attributes={
826826
'tools': [call.tool_name for call in tool_calls],
827827
'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
828828
},
829829
):
830-
tasks = [
831-
asyncio.create_task(
832-
_call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
833-
name=call.tool_name,
834-
)
835-
for call in tool_calls
836-
]
837-
838-
pending = tasks
839-
while pending:
840-
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
841-
for task in done:
842-
index = tasks.index(task)
843-
try:
844-
tool_part, tool_user_part = task.result()
845-
except exceptions.CallDeferred:
846-
deferred_calls_by_index[index] = 'external'
847-
except exceptions.ApprovalRequired:
848-
deferred_calls_by_index[index] = 'unapproved'
849-
else:
850-
yield _messages.FunctionToolResultEvent(tool_part)
851830

852-
tool_parts_by_index[index] = tool_part
853-
if tool_user_part:
854-
user_parts_by_index[index] = tool_user_part
831+
async def handle_call_or_result(
832+
coro_or_task: Awaitable[
833+
tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]
834+
]
835+
| Task[tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]],
836+
index: int,
837+
) -> _messages.HandleResponseEvent | None:
838+
try:
839+
tool_part, tool_user_part = (
840+
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
841+
)
842+
except exceptions.CallDeferred:
843+
deferred_calls_by_index[index] = 'external'
844+
except exceptions.ApprovalRequired:
845+
deferred_calls_by_index[index] = 'unapproved'
846+
else:
847+
tool_parts_by_index[index] = tool_part
848+
if tool_user_part:
849+
user_parts_by_index[index] = tool_user_part
850+
851+
return _messages.FunctionToolResultEvent(tool_part)
852+
853+
if tool_manager.should_call_sequentially(tool_calls):
854+
for index, call in enumerate(tool_calls):
855+
if event := await handle_call_or_result(
856+
_call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
857+
index,
858+
):
859+
yield event
860+
861+
else:
862+
tasks = [
863+
asyncio.create_task(
864+
_call_tool(tool_manager, call, deferred_tool_results.get(call.tool_call_id), usage_limits),
865+
name=call.tool_name,
866+
)
867+
for call in tool_calls
868+
]
869+
870+
pending = tasks
871+
while pending:
872+
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
873+
for task in done:
874+
index = tasks.index(task)
875+
if event := await handle_call_or_result(coro_or_task=task, index=index):
876+
yield event
855877

856878
# We append the results at the end, rather than as they are received, to retain a consistent ordering
857879
# This is mostly just to simplify testing
858-
for k in sorted(tool_parts_by_index):
859-
output_parts.append(tool_parts_by_index[k])
860-
861-
for k in sorted(user_parts_by_index):
862-
output_parts.append(user_parts_by_index[k])
880+
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
881+
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
863882

864883
for k in sorted(deferred_calls_by_index):
865884
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def tool_defs(self) -> list[ToolDefinition]:
5656

5757
return [tool.tool_def for tool in self.tools.values()]
5858

59+
def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool:
60+
"""Whether to require sequential tool calls for a list of tool calls."""
61+
return any(tool_def.sequential for call in calls if (tool_def := self.get_tool_def(call.tool_name)))
62+
5963
def get_tool_def(self, name: str) -> ToolDefinition | None:
6064
"""Get the tool definition for a given tool name, or `None` if the tool is unknown."""
6165
if self.tools is None:

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,7 @@ def tool_plain(
11191119
require_parameter_descriptions: bool = False,
11201120
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
11211121
strict: bool | None = None,
1122+
sequential: bool = False,
11221123
requires_approval: bool = False,
11231124
) -> Any:
11241125
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
@@ -1164,6 +1165,7 @@ async def spam(ctx: RunContext[str]) -> float:
11641165
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
11651166
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
11661167
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
1168+
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
11671169
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
11681170
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
11691171
"""
@@ -1180,6 +1182,7 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
11801182
require_parameter_descriptions,
11811183
schema_generator,
11821184
strict,
1185+
sequential,
11831186
requires_approval,
11841187
)
11851188
return func_

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class Tool(Generic[AgentDepsT]):
253253
docstring_format: DocstringFormat
254254
require_parameter_descriptions: bool
255255
strict: bool | None
256+
sequential: bool
256257
requires_approval: bool
257258
function_schema: _function_schema.FunctionSchema
258259
"""
@@ -274,6 +275,7 @@ def __init__(
274275
require_parameter_descriptions: bool = False,
275276
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
276277
strict: bool | None = None,
278+
sequential: bool = False,
277279
requires_approval: bool = False,
278280
function_schema: _function_schema.FunctionSchema | None = None,
279281
):
@@ -327,6 +329,7 @@ async def prep_my_tool(
327329
schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
328330
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
329331
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
332+
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
330333
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
331334
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
332335
function_schema: The function schema to use for the tool. If not provided, it will be generated.
@@ -347,6 +350,7 @@ async def prep_my_tool(
347350
self.docstring_format = docstring_format
348351
self.require_parameter_descriptions = require_parameter_descriptions
349352
self.strict = strict
353+
self.sequential = sequential
350354
self.requires_approval = requires_approval
351355

352356
@classmethod
@@ -357,6 +361,7 @@ def from_schema(
357361
description: str | None,
358362
json_schema: JsonSchemaValue,
359363
takes_ctx: bool = False,
364+
sequential: bool = False,
360365
) -> Self:
361366
"""Creates a Pydantic tool from a function and a JSON schema.
362367
@@ -370,6 +375,7 @@ def from_schema(
370375
json_schema: The schema for the function arguments
371376
takes_ctx: An optional boolean parameter indicating whether the function
372377
accepts the context object as an argument.
378+
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
373379
374380
Returns:
375381
A Pydantic tool that calls the function
@@ -389,6 +395,7 @@ def from_schema(
389395
name=name,
390396
description=description,
391397
function_schema=function_schema,
398+
sequential=sequential,
392399
)
393400

394401
@property
@@ -398,6 +405,7 @@ def tool_def(self):
398405
description=self.description,
399406
parameters_json_schema=self.function_schema.json_schema,
400407
strict=self.strict,
408+
sequential=self.sequential,
401409
)
402410

403411
async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
@@ -466,6 +474,9 @@ class ToolDefinition:
466474
Note: this is currently only supported by OpenAI models.
467475
"""
468476

477+
sequential: bool = False
478+
"""Whether this tool requires a sequential/serial execution environment."""
479+
469480
kind: ToolKind = field(default='function')
470481
"""The kind of tool:
471482

pydantic_ai_slim/pydantic_ai/toolsets/function.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def tool(
9797
require_parameter_descriptions: bool | None = None,
9898
schema_generator: type[GenerateJsonSchema] | None = None,
9999
strict: bool | None = None,
100+
sequential: bool = False,
100101
requires_approval: bool = False,
101102
) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ...
102103

@@ -112,6 +113,7 @@ def tool(
112113
require_parameter_descriptions: bool | None = None,
113114
schema_generator: type[GenerateJsonSchema] | None = None,
114115
strict: bool | None = None,
116+
sequential: bool = False,
115117
requires_approval: bool = False,
116118
) -> Any:
117119
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
@@ -161,6 +163,7 @@ async def spam(ctx: RunContext[str], y: float) -> float:
161163
If `None`, the default value is determined by the toolset.
162164
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
163165
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
166+
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
164167
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
165168
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
166169
"""
@@ -179,6 +182,7 @@ def tool_decorator(
179182
require_parameter_descriptions,
180183
schema_generator,
181184
strict,
185+
sequential,
182186
requires_approval,
183187
)
184188
return func_
@@ -196,6 +200,7 @@ def add_function(
196200
require_parameter_descriptions: bool | None = None,
197201
schema_generator: type[GenerateJsonSchema] | None = None,
198202
strict: bool | None = None,
203+
sequential: bool = False,
199204
requires_approval: bool = False,
200205
) -> None:
201206
"""Add a function as a tool to the toolset.
@@ -222,6 +227,7 @@ def add_function(
222227
If `None`, the default value is determined by the toolset.
223228
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
224229
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
230+
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
225231
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
226232
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
227233
"""
@@ -242,6 +248,7 @@ def add_function(
242248
require_parameter_descriptions=require_parameter_descriptions,
243249
schema_generator=schema_generator,
244250
strict=strict,
251+
sequential=sequential,
245252
requires_approval=requires_approval,
246253
)
247254
self.add_tool(tool)

tests/test_agent.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4293,6 +4293,65 @@ async def call_tools_parallel(messages: list[ModelMessage], info: AgentInfo) ->
42934293
assert result.output == snapshot('finished')
42944294

42954295

4296+
def test_sequential_calls():
4297+
"""Test that tool calls are executed correctly when a `sequential` tool is present in the call."""
4298+
4299+
async def call_tools_sequential(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
4300+
return ModelResponse(
4301+
parts=[
4302+
ToolCallPart(tool_name='call_first'),
4303+
ToolCallPart(tool_name='call_first'),
4304+
ToolCallPart(tool_name='call_first'),
4305+
ToolCallPart(tool_name='call_first'),
4306+
ToolCallPart(tool_name='call_first'),
4307+
ToolCallPart(tool_name='call_first'),
4308+
ToolCallPart(tool_name='increment_integer_holder'),
4309+
ToolCallPart(tool_name='requires_approval'),
4310+
ToolCallPart(tool_name='call_second'),
4311+
ToolCallPart(tool_name='call_second'),
4312+
ToolCallPart(tool_name='call_second'),
4313+
ToolCallPart(tool_name='call_second'),
4314+
ToolCallPart(tool_name='call_second'),
4315+
ToolCallPart(tool_name='call_second'),
4316+
ToolCallPart(tool_name='call_second'),
4317+
]
4318+
)
4319+
4320+
sequential_toolset = FunctionToolset()
4321+
4322+
integer_holder: int = 1
4323+
4324+
@sequential_toolset.tool(sequential=False)
4325+
def call_first():
4326+
nonlocal integer_holder
4327+
assert integer_holder == 1
4328+
4329+
@sequential_toolset.tool(sequential=True)
4330+
def increment_integer_holder():
4331+
nonlocal integer_holder
4332+
integer_holder = 2
4333+
4334+
@sequential_toolset.tool()
4335+
def requires_approval():
4336+
from pydantic_ai.exceptions import ApprovalRequired
4337+
4338+
raise ApprovalRequired()
4339+
4340+
@sequential_toolset.tool(sequential=False)
4341+
def call_second():
4342+
nonlocal integer_holder
4343+
assert integer_holder == 2
4344+
4345+
agent = Agent(
4346+
FunctionModel(call_tools_sequential), toolsets=[sequential_toolset], output_type=[str, DeferredToolRequests]
4347+
)
4348+
result = agent.run_sync()
4349+
assert result.output == snapshot(
4350+
DeferredToolRequests(approvals=[ToolCallPart(tool_name='requires_approval', tool_call_id=IsStr())])
4351+
)
4352+
assert integer_holder == 2
4353+
4354+
42964355
def test_set_mcp_sampling_model():
42974356
try:
42984357
from pydantic_ai.mcp import MCPServerStdio

tests/test_logfire.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ async def my_ret(x: int) -> str:
385385
},
386386
'outer_typed_dict_key': None,
387387
'strict': None,
388+
'sequential': False,
388389
'kind': 'function',
389390
}
390391
],
@@ -777,6 +778,7 @@ class MyOutput:
777778
'description': 'The final response which ends this conversation',
778779
'outer_typed_dict_key': None,
779780
'strict': None,
781+
'sequential': False,
780782
'kind': 'output',
781783
}
782784
],

0 commit comments

Comments
 (0)