diff --git a/src/deepagents/graph.py b/src/deepagents/graph.py index 636d4f7..ea5c815 100644 --- a/src/deepagents/graph.py +++ b/src/deepagents/graph.py @@ -13,7 +13,9 @@ from deepagents.interrupt import create_interrupt_hook, ToolInterruptConfig from langgraph.types import Checkpointer from langgraph.prebuilt import create_react_agent +from pydantic import BaseModel +StructuredResponseSchema = Union[dict, type[BaseModel]] StateSchema = TypeVar("StateSchema", bound=DeepAgentState) StateSchemaType = Type[StateSchema] @@ -41,6 +43,9 @@ def _agent_builder( config_schema: Optional[Type[Any]] = None, checkpointer: Optional[Checkpointer] = None, post_model_hook: Optional[Callable] = None, + response_format: Optional[ + Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] + ] = None, is_async: bool = False, ): prompt = instructions + base_prompt @@ -103,6 +108,7 @@ def _agent_builder( post_model_hook=selected_post_model_hook, config_schema=config_schema, checkpointer=checkpointer, + response_format=response_format, ) @@ -117,6 +123,9 @@ def create_deep_agent( config_schema: Optional[Type[Any]] = None, checkpointer: Optional[Checkpointer] = None, post_model_hook: Optional[Callable] = None, + response_format: Optional[ + Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] + ] = None, ): """Create a deep agent. @@ -155,6 +164,7 @@ def create_deep_agent( checkpointer=checkpointer, post_model_hook=post_model_hook, is_async=False, + response_format=response_format, ) @@ -169,6 +179,9 @@ def async_create_deep_agent( config_schema: Optional[Type[Any]] = None, checkpointer: Optional[Checkpointer] = None, post_model_hook: Optional[Callable] = None, + response_format: Optional[ + Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]] + ] = None, ): """Create a deep agent. @@ -206,5 +219,6 @@ def async_create_deep_agent( config_schema=config_schema, checkpointer=checkpointer, post_model_hook=post_model_hook, + response_format=response_format, is_async=True, ) diff --git a/src/deepagents/state.py b/src/deepagents/state.py index 844c899..8067454 100644 --- a/src/deepagents/state.py +++ b/src/deepagents/state.py @@ -1,6 +1,5 @@ from langgraph.prebuilt.chat_agent_executor import AgentState -from typing import NotRequired, Annotated -from typing import Literal +from typing import Any, NotRequired, Annotated, Literal from typing_extensions import TypedDict @@ -23,3 +22,4 @@ def file_reducer(l, r): class DeepAgentState(AgentState): todos: NotRequired[list[Todo]] files: Annotated[NotRequired[dict[str, str]], file_reducer] + structured_response: NotRequired[Any]