diff --git a/src/deepagents/sub_agent.py b/src/deepagents/sub_agent.py index 71c038bb..e36bd5d3 100644 --- a/src/deepagents/sub_agent.py +++ b/src/deepagents/sub_agent.py @@ -9,7 +9,7 @@ from langchain.chat_models import init_chat_model from typing import Annotated, NotRequired, Any, Union, Optional, Callable from langgraph.types import Command -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from langgraph.prebuilt import InjectedState @@ -106,6 +106,7 @@ def _create_task_tool( async def task( description: str, subagent_type: str, + special_config_param: RunnableConfig, state: Annotated[DeepAgentState, InjectedState], tool_call_id: Annotated[str, InjectedToolCallId], ): @@ -113,7 +114,12 @@ async def task( return f"Error: invoked agent of type {subagent_type}, the only allowed types are {[f'`{k}`' for k in agents]}" sub_agent = agents[subagent_type] state["messages"] = [{"role": "user", "content": description}] - result = await sub_agent.ainvoke(state) + if special_config_param.get('configurable').get("stream_mode") == "stream": + result = None + async for chunk in sub_agent.astream(input=state, config=special_config_param): + result = chunk + else: + result = await sub_agent.ainvoke(input=state, config=special_config_param) return Command( update={ "files": result.get("files", {}), @@ -148,6 +154,7 @@ def _create_sync_task_tool( def task( description: str, subagent_type: str, + special_config_param: RunnableConfig, state: Annotated[DeepAgentState, InjectedState], tool_call_id: Annotated[str, InjectedToolCallId], ): @@ -155,7 +162,7 @@ def task( return f"Error: invoked agent of type {subagent_type}, the only allowed types are {[f'`{k}`' for k in agents]}" sub_agent = agents[subagent_type] state["messages"] = [{"role": "user", "content": description}] - result = sub_agent.invoke(state) + result = sub_agent.invoke(input=state, config=special_config_param) return Command( update={ "files": result.get("files", {}),