Skip to content

Commit 5f6cfa7

Browse files
committed
Streaming with Temporal
1 parent 2f04894 commit 5f6cfa7

File tree

6 files changed

+169
-151
lines changed

6 files changed

+169
-151
lines changed

pydantic_ai_slim/pydantic_ai/temporal/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111

1212

1313
class _TemporalRunContext(RunContext[AgentDepsT]):
14-
_data: dict[str, Any]
15-
1614
def __init__(self, **kwargs: Any):
17-
self._data = kwargs
15+
self.__dict__ = kwargs
1816
setattr(
1917
self,
2018
'__dataclass_fields__',
@@ -25,10 +23,12 @@ def __getattribute__(self, name: str) -> Any:
2523
try:
2624
return super().__getattribute__(name)
2725
except AttributeError as e:
28-
data = super().__getattribute__('_data')
29-
if name in data:
30-
return data[name]
31-
raise e # TODO: Explain how to make a new run context attribute available
26+
if name in RunContext.__dataclass_fields__:
27+
raise AttributeError(
28+
f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.'
29+
)
30+
else:
31+
raise e
3232

3333
@classmethod
3434
def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]:
@@ -75,7 +75,7 @@ def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings:
7575
deserialize_run_context: Callable[[dict[str, Any]], RunContext] = _TemporalRunContext.deserialize_run_context
7676

7777
@property
78-
def execute_activity_kwargs(self) -> dict[str, Any]:
78+
def execute_activity_options(self) -> dict[str, Any]:
7979
return {
8080
'task_queue': self.task_queue,
8181
'schedule_to_close_timeout': self.schedule_to_close_timeout,

pydantic_ai_slim/pydantic_ai/temporal/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | N
3030

3131

3232
def temporalize_agent(
33-
agent: Agent,
33+
agent: Agent[Any, Any],
3434
settings: TemporalSettings | None = None,
3535
temporalize_toolset_func: Callable[
3636
[AbstractToolset, TemporalSettings | None], list[Callable[..., Any]]
@@ -52,7 +52,7 @@ def temporalize_agent(
5252

5353
activities: list[Callable[..., Any]] = []
5454
if isinstance(agent.model, Model):
55-
activities.extend(temporalize_model(agent.model, settings))
55+
activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage]
5656

5757
def temporalize_toolset(toolset: AbstractToolset) -> None:
5858
activities.extend(temporalize_toolset_func(toolset, settings))

pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -57,44 +57,11 @@ async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool:
5757
return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
5858
activity=call_tool_activity,
5959
arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context),
60-
**tool_settings.execute_activity_kwargs,
60+
**tool_settings.execute_activity_options,
6161
)
6262

6363
toolset.call_tool = call_tool
6464

6565
activities = [call_tool_activity]
6666
setattr(toolset, '__temporal_activities', activities)
6767
return activities
68-
69-
70-
# class TemporalFunctionToolset(FunctionToolset[AgentDepsT]):
71-
# def __init__(
72-
# self,
73-
# tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
74-
# max_retries: int = 1,
75-
# temporal_settings: TemporalSettings | None = None,
76-
# serialize_run_context: Callable[[RunContext[AgentDepsT]], Any] | None = None,
77-
# deserialize_run_context: Callable[[Any], RunContext[AgentDepsT]] | None = None,
78-
# ):
79-
# super().__init__(tools, max_retries)
80-
# self.temporal_settings = temporal_settings or TemporalSettings()
81-
# self.serialize_run_context = serialize_run_context or TemporalRunContext[AgentDepsT].serialize_run_context
82-
# self.deserialize_run_context = deserialize_run_context or TemporalRunContext[AgentDepsT].deserialize_run_context
83-
84-
# @activity.defn(name='function_toolset_call_tool')
85-
# async def call_tool_activity(params: FunctionCallToolParams) -> Any:
86-
# ctx = self.deserialize_run_context(params.serialized_run_context)
87-
# tool = (await self.get_tools(ctx))[params.name]
88-
# return await FunctionToolset[AgentDepsT].call_tool(self, params.name, params.tool_args, ctx, tool)
89-
90-
# self.call_tool_activity = call_tool_activity
91-
92-
# async def call_tool(
93-
# self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
94-
# ) -> Any:
95-
# serialized_run_context = self.serialize_run_context(ctx)
96-
# return await workflow.execute_activity(
97-
# activity=self.call_tool_activity,
98-
# arg=FunctionCallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context),
99-
# **self.temporal_settings.__dict__,
100-
# )

pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ async def call_tool_activity(params: _CallToolParams) -> ToolResult:
5555
async def list_tools() -> list[mcp_types.Tool]:
5656
return await workflow.execute_activity( # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
5757
activity=list_tools_activity,
58-
**settings.execute_activity_kwargs,
58+
**settings.execute_activity_options,
5959
)
6060

6161
async def direct_call_tool(
@@ -66,7 +66,7 @@ async def direct_call_tool(
6666
return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
6767
activity=call_tool_activity,
6868
arg=_CallToolParams(name=name, tool_args=args, metadata=metadata),
69-
**settings.for_tool(id, name).execute_activity_kwargs,
69+
**settings.for_tool(id, name).execute_activity_options,
7070
)
7171

7272
server.list_tools = list_tools
@@ -75,47 +75,3 @@ async def direct_call_tool(
7575
activities = [list_tools_activity, call_tool_activity]
7676
setattr(server, '__temporal_activities', activities)
7777
return activities
78-
79-
80-
# class TemporalMCPServer(WrapperToolset[Any]):
81-
# temporal_settings: TemporalSettings
82-
83-
# @property
84-
# def wrapped_server(self) -> MCPServer:
85-
# assert isinstance(self.wrapped, MCPServer)
86-
# return self.wrapped
87-
88-
# def __init__(self, wrapped: MCPServer, temporal_settings: TemporalSettings | None = None):
89-
# assert isinstance(self.wrapped, MCPServer)
90-
# super().__init__(wrapped)
91-
# self.temporal_settings = temporal_settings or TemporalSettings()
92-
93-
# @activity.defn(name='mcp_server_list_tools')
94-
# async def list_tools_activity() -> list[mcp_types.Tool]:
95-
# return await self.wrapped_server.list_tools()
96-
97-
# self.list_tools_activity = list_tools_activity
98-
99-
# @activity.defn(name='mcp_server_call_tool')
100-
# async def call_tool_activity(params: MCPCallToolParams) -> ToolResult:
101-
# return await self.wrapped_server.direct_call_tool(params.name, params.tool_args, params.metadata)
102-
103-
# self.call_tool_activity = call_tool_activity
104-
105-
# async def list_tools(self) -> list[mcp_types.Tool]:
106-
# return await workflow.execute_activity(
107-
# activity=self.list_tools_activity,
108-
# **self.temporal_settings.__dict__,
109-
# )
110-
111-
# async def direct_call_tool(
112-
# self,
113-
# name: str,
114-
# args: dict[str, Any],
115-
# metadata: dict[str, Any] | None = None,
116-
# ) -> ToolResult:
117-
# return await workflow.execute_activity(
118-
# activity=self.call_tool_activity,
119-
# arg=MCPCallToolParams(name=name, tool_args=args, metadata=metadata),
120-
# **self.temporal_settings.__dict__,
121-
# )

0 commit comments

Comments
 (0)