Skip to content

Commit 6e8c76a

Browse files
committed
WIP: temporalize_agent
1 parent 226830b commit 6e8c76a

File tree

11 files changed

+653
-5
lines changed

11 files changed

+653
-5
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,14 @@ def _get_toolset(
17171717

17181718
return CombinedToolset(all_toolsets)
17191719

1720+
@property
1721+
def toolset(self) -> AbstractToolset[AgentDepsT]:
1722+
"""The complete toolset that will be available to the model during an agent run.
1723+
1724+
This will include function tools registered directly to the agent, output tools, and user-provided toolsets including MCP servers.
1725+
"""
1726+
return self._get_toolset()
1727+
17201728
def _infer_name(self, function_frame: FrameType | None) -> None:
17211729
"""Infer the agent name from the call frame.
17221730

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def label(self) -> str:
126126

127127
@property
128128
def tool_name_conflict_hint(self) -> str:
129-
return 'Consider setting `tool_prefix` to avoid name conflicts.'
129+
return 'Set the `tool_prefix` attribute to avoid name conflicts.'
130130

131131
async def list_tools(self) -> list[mcp_types.Tool]:
132132
"""Retrieve tools that are currently active on the server.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from datetime import timedelta
5+
from typing import Any, Callable
6+
7+
from temporalio.common import Priority, RetryPolicy
8+
from temporalio.workflow import ActivityCancellationType, VersioningIntent
9+
10+
from pydantic_ai._run_context import AgentDepsT, RunContext
11+
12+
13+
class _TemporalRunContext(RunContext[AgentDepsT]):
14+
_data: dict[str, Any]
15+
16+
def __init__(self, **kwargs: Any):
17+
self._data = kwargs
18+
setattr(
19+
self,
20+
'__dataclass_fields__',
21+
{name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs},
22+
)
23+
24+
def __getattribute__(self, name: str) -> Any:
25+
try:
26+
return super().__getattribute__(name)
27+
except AttributeError as e:
28+
data = super().__getattribute__('_data')
29+
if name in data:
30+
return data[name]
31+
raise e # TODO: Explain how to make a new run context attribute available
32+
33+
@classmethod
34+
def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]:
35+
return {
36+
'deps': ctx.deps,
37+
'retries': ctx.retries,
38+
'tool_call_id': ctx.tool_call_id,
39+
'tool_name': ctx.tool_name,
40+
'retry': ctx.retry,
41+
'run_step': ctx.run_step,
42+
}
43+
44+
@classmethod
45+
def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]:
46+
return cls(**ctx)
47+
48+
49+
@dataclass
50+
class TemporalSettings:
51+
"""Settings for Temporal `execute_activity` and Pydantic AI-specific Temporal activity behavior."""
52+
53+
# Temporal settings
54+
task_queue: str | None = None
55+
schedule_to_close_timeout: timedelta | None = None
56+
schedule_to_start_timeout: timedelta | None = None
57+
start_to_close_timeout: timedelta | None = None
58+
heartbeat_timeout: timedelta | None = None
59+
retry_policy: RetryPolicy | None = None
60+
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL
61+
activity_id: str | None = None
62+
versioning_intent: VersioningIntent | None = None
63+
summary: str | None = None
64+
priority: Priority = Priority.default
65+
66+
# Pydantic AI specific
67+
tool_settings: dict[str, dict[str, TemporalSettings]] | None = None
68+
69+
def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings:
70+
if self.tool_settings is None:
71+
return self
72+
return self.tool_settings.get(toolset_id, {}).get(tool_id, self)
73+
74+
serialize_run_context: Callable[[RunContext], Any] = _TemporalRunContext.serialize_run_context
75+
deserialize_run_context: Callable[[dict[str, Any]], RunContext] = _TemporalRunContext.deserialize_run_context
76+
77+
@property
78+
def execute_activity_kwargs(self) -> dict[str, Any]:
79+
return {
80+
'task_queue': self.task_queue,
81+
'schedule_to_close_timeout': self.schedule_to_close_timeout,
82+
'schedule_to_start_timeout': self.schedule_to_start_timeout,
83+
'start_to_close_timeout': self.start_to_close_timeout,
84+
'heartbeat_timeout': self.heartbeat_timeout,
85+
'retry_policy': self.retry_policy,
86+
'cancellation_type': self.cancellation_type,
87+
'activity_id': self.activity_id,
88+
'versioning_intent': self.versioning_intent,
89+
'summary': self.summary,
90+
'priority': self.priority,
91+
}
92+
93+
94+
def initialize_temporal():
95+
"""Explicitly import types without which Temporal will not be able to serialize/deserialize `ModelMessage`s."""
96+
from pydantic_ai.messages import ( # noqa F401
97+
ModelResponse, # pyright: ignore[reportUnusedImport]
98+
ImageUrl, # pyright: ignore[reportUnusedImport]
99+
AudioUrl, # pyright: ignore[reportUnusedImport]
100+
DocumentUrl, # pyright: ignore[reportUnusedImport]
101+
VideoUrl, # pyright: ignore[reportUnusedImport]
102+
BinaryContent, # pyright: ignore[reportUnusedImport]
103+
UserContent, # pyright: ignore[reportUnusedImport]
104+
)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Callable
4+
5+
from pydantic_ai.agent import Agent
6+
from pydantic_ai.mcp import MCPServer
7+
from pydantic_ai.toolsets.abstract import AbstractToolset
8+
from pydantic_ai.toolsets.function import FunctionToolset
9+
10+
from ..models import Model
11+
from . import TemporalSettings
12+
from .function_toolset import temporalize_function_toolset
13+
from .mcp_server import temporalize_mcp_server
14+
from .model import temporalize_model
15+
16+
17+
def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]:
18+
"""Temporalize a toolset.
19+
20+
Args:
21+
toolset: The toolset to temporalize.
22+
settings: The temporal settings to use.
23+
"""
24+
if isinstance(toolset, FunctionToolset):
25+
return temporalize_function_toolset(toolset, settings)
26+
elif isinstance(toolset, MCPServer):
27+
return temporalize_mcp_server(toolset, settings)
28+
else:
29+
return []
30+
31+
32+
def temporalize_agent(
33+
agent: Agent,
34+
settings: TemporalSettings | None = None,
35+
temporalize_toolset_func: Callable[
36+
[AbstractToolset, TemporalSettings | None], list[Callable[..., Any]]
37+
] = temporalize_toolset,
38+
) -> list[Callable[..., Any]]:
39+
"""Temporalize an agent.
40+
41+
Args:
42+
agent: The agent to temporalize.
43+
settings: The temporal settings to use.
44+
temporalize_toolset_func: The function to use to temporalize the toolsets.
45+
"""
46+
if existing_activities := getattr(agent, '__temporal_activities', None):
47+
return existing_activities
48+
49+
settings = settings or TemporalSettings()
50+
51+
# TODO: Doesn't consider model/toolsets passed at iter time.
52+
53+
activities: list[Callable[..., Any]] = []
54+
if isinstance(agent.model, Model):
55+
activities.extend(temporalize_model(agent.model, settings))
56+
57+
def temporalize_toolset(toolset: AbstractToolset) -> None:
58+
activities.extend(temporalize_toolset_func(toolset, settings))
59+
60+
agent.toolset.apply(temporalize_toolset)
61+
62+
setattr(agent, '__temporal_activities', activities)
63+
return activities
64+
65+
66+
# TODO: untemporalize_agent
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Callable
5+
6+
from pydantic import ConfigDict, with_config
7+
from temporalio import activity, workflow
8+
9+
from pydantic_ai.toolsets.function import FunctionToolset
10+
11+
from .._run_context import RunContext
12+
from ..toolsets import ToolsetTool
13+
from . import TemporalSettings
14+
15+
16+
@dataclass
17+
@with_config(ConfigDict(arbitrary_types_allowed=True))
18+
class _CallToolParams:
19+
name: str
20+
tool_args: dict[str, Any]
21+
serialized_run_context: Any
22+
23+
24+
def temporalize_function_toolset(
25+
toolset: FunctionToolset,
26+
settings: TemporalSettings | None = None,
27+
) -> list[Callable[..., Any]]:
28+
"""Temporalize a function toolset.
29+
30+
Args:
31+
toolset: The function toolset to temporalize.
32+
settings: The temporal settings to use.
33+
"""
34+
if activities := getattr(toolset, '__temporal_activities', None):
35+
return activities
36+
37+
id = toolset.id
38+
if not id:
39+
raise ValueError(
40+
"A function toolset needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the toolset's activities within the workflow."
41+
)
42+
43+
settings = settings or TemporalSettings()
44+
45+
original_call_tool = toolset.call_tool
46+
47+
@activity.defn(name=f'function_toolset__{id}__call_tool')
48+
async def call_tool_activity(params: _CallToolParams) -> Any:
49+
name = params.name
50+
ctx = settings.for_tool(id, name).deserialize_run_context(params.serialized_run_context)
51+
tool = (await toolset.get_tools(ctx))[name]
52+
return await original_call_tool(name, params.tool_args, ctx, tool)
53+
54+
async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any:
55+
tool_settings = settings.for_tool(id, name)
56+
serialized_run_context = tool_settings.serialize_run_context(ctx)
57+
return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
58+
activity=call_tool_activity,
59+
arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context),
60+
**tool_settings.execute_activity_kwargs,
61+
)
62+
63+
toolset.call_tool = call_tool
64+
65+
activities = [call_tool_activity]
66+
setattr(toolset, '__temporal_activities', activities)
67+
return activities
68+
69+
70+
# class TemporalFunctionToolset(FunctionToolset[AgentDepsT]):
71+
# def __init__(
72+
# self,
73+
# tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
74+
# max_retries: int = 1,
75+
# temporal_settings: TemporalSettings | None = None,
76+
# serialize_run_context: Callable[[RunContext[AgentDepsT]], Any] | None = None,
77+
# deserialize_run_context: Callable[[Any], RunContext[AgentDepsT]] | None = None,
78+
# ):
79+
# super().__init__(tools, max_retries)
80+
# self.temporal_settings = temporal_settings or TemporalSettings()
81+
# self.serialize_run_context = serialize_run_context or TemporalRunContext[AgentDepsT].serialize_run_context
82+
# self.deserialize_run_context = deserialize_run_context or TemporalRunContext[AgentDepsT].deserialize_run_context
83+
84+
# @activity.defn(name='function_toolset_call_tool')
85+
# async def call_tool_activity(params: FunctionCallToolParams) -> Any:
86+
# ctx = self.deserialize_run_context(params.serialized_run_context)
87+
# tool = (await self.get_tools(ctx))[params.name]
88+
# return await FunctionToolset[AgentDepsT].call_tool(self, params.name, params.tool_args, ctx, tool)
89+
90+
# self.call_tool_activity = call_tool_activity
91+
92+
# async def call_tool(
93+
# self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
94+
# ) -> Any:
95+
# serialized_run_context = self.serialize_run_context(ctx)
96+
# return await workflow.execute_activity(
97+
# activity=self.call_tool_activity,
98+
# arg=FunctionCallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context),
99+
# **self.temporal_settings.__dict__,
100+
# )

0 commit comments

Comments
 (0)