Skip to content

Let Agent be run in a Temporal workflow by moving model requests, tool calls, and MCP to Temporal activities #2225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3aef43d
Let Agent be run in a Temporal workflow by moving model requests, too…
DouweM Aug 6, 2025
d7bffad
Update for Agent.toolsets
DouweM Aug 8, 2025
07d21d1
Fix docstring examples
DouweM Aug 8, 2025
ba8f37a
Merge branch 'main' into temporal-agent
DouweM Aug 8, 2025
58691ab
Address feedback
DouweM Aug 8, 2025
f19ee91
Add simple Temporal test
DouweM Aug 8, 2025
9642a15
Fix temporal tests
DouweM Aug 8, 2025
ae07963
Add pydantic_ai.ext.temporal to API docs
DouweM Aug 8, 2025
22c780c
Skip testing flaky example
DouweM Aug 8, 2025
19eca40
Add a bunch of tests
DouweM Aug 8, 2025
e4f7f33
Only include temporal modules when temporalio is available
DouweM Aug 8, 2025
46c1f79
Run all Temporal tests in the same xdist group (process)
DouweM Aug 8, 2025
39c4298
Uninstrument Pydantic AI after Temporal tests
DouweM Aug 11, 2025
e47d4a0
Merge branch 'main' into temporal-agent
DouweM Aug 11, 2025
b2039e2
Unskip testing flaky example
DouweM Aug 11, 2025
10afbd0
Add some more tests
DouweM Aug 11, 2025
5f8ceda
Merge branch 'main' into temporal-agent
DouweM Aug 11, 2025
d1fd4ba
Add all the tests
DouweM Aug 11, 2025
9916c82
Only retry temporal activities once in tests, so CI doesn't time out
DouweM Aug 11, 2025
13f2224
Fix logfire f-string logging with '{}' inside str repr in Python 3.10
DouweM Aug 11, 2025
342ab17
More test coverage
DouweM Aug 11, 2025
e348517
Fix typecheck
DouweM Aug 11, 2025
141984e
More test coverage
DouweM Aug 11, 2025
9f59c4d
Remove unnecessary 'pragma: no cover'
DouweM Aug 11, 2025
96f24ec
Move temporal module to pydantic_ai.durable_exec
DouweM Aug 11, 2025
7e60b95
Add docs
DouweM Aug 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
- RunOutputDataT
- capture_run_messages
- InstrumentationSettings
- EventStreamHandler
3 changes: 3 additions & 0 deletions docs/api/durable_exec.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `pydantic_ai.durable_exec`

::: pydantic_ai.durable_exec.temporal
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Pydantic AI is still pre-version 1, so breaking changes will occur, however:
!!! note
Here's a filtered list of the breaking changes for each version to help you upgrade Pydantic AI.

### v0.7.0 (2025-08-08)
### v0.7.0 (2025-08-11)

See [#2458](https://github.com/pydantic/pydantic-ai/pull/2458) - `pydantic_ai.models.StreamedResponse` now yields a `FinalResultEvent` along with the existing `PartStartEvent` and `PartDeltaEvent`. If you're using `pydantic_ai.direct.model_request_stream` or `pydantic_ai.direct.model_request_stream_sync`, you may need to update your code to account for this.

Expand Down
229 changes: 229 additions & 0 deletions docs/temporal.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ nav:
- builtin-tools.md
- common-tools.md
- retries.md
- temporal.md
- MCP:
- mcp/index.md
- mcp/client.md
Expand Down Expand Up @@ -75,6 +76,7 @@ nav:
- api/toolsets.md
- api/builtin_tools.md
- api/common_tools.md
- api/durable_exec.md
- api/output.md
- api/result.md
- api/messages.md
Expand Down
72 changes: 42 additions & 30 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
'InstrumentationSettings',
'WrapperAgent',
'AbstractAgent',
'EventStreamHandler',
)


Expand Down Expand Up @@ -593,12 +594,7 @@ async def main():
run_step=state.run_step,
)

toolset = self._get_toolset(additional=toolsets)

if output_toolset is not None:
if self._prepare_output_tools:
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
toolset = CombinedToolset([output_toolset, toolset])
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)

async with toolset:
# This will raise errors for any name conflicts
Expand Down Expand Up @@ -1240,48 +1236,64 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
return deps

def _get_toolset(
self, additional: Sequence[AbstractToolset[AgentDepsT]] | None = None
self,
output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,
additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
) -> AbstractToolset[AgentDepsT]:
"""Get the combined toolset containing function tools registered directly to the agent and user-provided toolsets including MCP servers.
"""Get the complete toolset.

Args:
additional: Additional toolsets to add.
output_toolset: The output toolset to use instead of the one built at agent construction time.
additional_toolsets: Additional toolsets to add, unless toolsets have been overridden.
"""
if some_tools := self._override_tools.get():
function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries)
else:
function_toolset = self._function_toolset
toolsets = self.toolsets
# Don't add additional toolsets if the toolsets have been overridden
if additional_toolsets and self._override_toolsets.get() is None:
toolsets = [*toolsets, *additional_toolsets]

if some_user_toolsets := self._override_toolsets.get():
user_toolsets = some_user_toolsets.value
else:
# Copy the dynamic toolsets to ensure each run has its own instances
dynamic_toolsets = [dataclasses.replace(toolset) for toolset in self._dynamic_toolsets]
user_toolsets = [*self._user_toolsets, *dynamic_toolsets, *(additional or [])]
toolset = CombinedToolset(toolsets)

if user_toolsets:
toolset = CombinedToolset([function_toolset, *user_toolsets])
else:
toolset = function_toolset
# Copy the dynamic toolsets to ensure each run has its own instances
def copy_dynamic_toolsets(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
if isinstance(toolset, DynamicToolset):
return dataclasses.replace(toolset)
else:
return toolset

toolset = toolset.visit_and_replace(copy_dynamic_toolsets)

if self._prepare_tools:
toolset = PreparedToolset(toolset, self._prepare_tools)

output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset
if output_toolset is not None:
if self._prepare_output_tools:
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
toolset = CombinedToolset([output_toolset, toolset])

return toolset

@property
def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
"""All toolsets registered on the agent, including a function toolset holding tools that were registered on the agent directly.

If a `prepare_tools` function was configured on the agent, this will contain just a `PreparedToolset` wrapping the original toolsets.

Output tools are not included.
"""
toolset = self._get_toolset()
if isinstance(toolset, CombinedToolset):
return toolset.toolsets
toolsets: list[AbstractToolset[AgentDepsT]] = []

if some_tools := self._override_tools.get():
function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries)
else:
return [toolset]
function_toolset = self._function_toolset
toolsets.append(function_toolset)

if some_user_toolsets := self._override_toolsets.get():
user_toolsets = some_user_toolsets.value
else:
user_toolsets = [*self._user_toolsets, *self._dynamic_toolsets]
toolsets.extend(user_toolsets)

return toolsets

def _prepare_output_schema(
self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
Expand Down Expand Up @@ -1369,7 +1381,7 @@ async def run_mcp_servers(
class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
@property
def id(self) -> str:
return '<agent>' # pragma: no cover
return '<agent>'

@property
def label(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/agent/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ async def main():
usage=usage,
infer_name=infer_name,
toolsets=toolsets,
) as result:
yield result
) as run:
yield run

@contextmanager
def override(
Expand Down
Empty file.
80 changes: 80 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

import warnings
from collections.abc import Sequence
from dataclasses import replace
from typing import Any, Callable

from temporalio.client import ClientConfig, Plugin as ClientPlugin
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
from temporalio.converter import DefaultPayloadConverter
from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner

from ...exceptions import UserError
from ._agent import TemporalAgent
from ._logfire import LogfirePlugin
from ._run_context import TemporalRunContext, TemporalRunContextWithDeps
from ._toolset import TemporalWrapperToolset

__all__ = [
'TemporalAgent',
'PydanticAIPlugin',
'LogfirePlugin',
'AgentPlugin',
'TemporalRunContext',
'TemporalRunContextWithDeps',
'TemporalWrapperToolset',
]


class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
"""Temporal client and worker plugin for Pydantic AI."""

def configure_client(self, config: ClientConfig) -> ClientConfig:
if (data_converter := config.get('data_converter')) and data_converter.payload_converter_class not in (
DefaultPayloadConverter,
PydanticPayloadConverter,
):
warnings.warn( # pragma: no cover
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
)

config['data_converter'] = pydantic_data_converter
return super().configure_client(config)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType]
if isinstance(runner, SandboxedWorkflowRunner): # pragma: no branch
config['workflow_runner'] = replace(
runner,
restrictions=runner.restrictions.with_passthrough_modules(
'pydantic_ai',
'logfire',
'rich',
'httpx',
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
'attrs',
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
'numpy',
'pandas',
),
)

# pydantic_ai.exceptions.UserError is not retryable
config['workflow_failure_exception_types'] = [*config.get('workflow_failure_exception_types', []), UserError] # pyright: ignore[reportUnknownMemberType]

return super().configure_worker(config)


class AgentPlugin(WorkerPlugin):
"""Temporal worker plugin for a specific Pydantic AI agent."""

def __init__(self, agent: TemporalAgent[Any, Any]):
self.agent = agent

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
# Activities are checked for name conflicts by Temporal.
config['activities'] = [*activities, *self.agent.temporal_activities]
return super().configure_worker(config)
Loading
Loading