diff --git a/protollm/agents/builder.py b/protollm/agents/builder.py index ef134c4..8c9aea4 100644 --- a/protollm/agents/builder.py +++ b/protollm/agents/builder.py @@ -4,6 +4,12 @@ from protollm.agents.universal_agents import (chat_node, plan_node, replan_node, summary_node, supervisor_node, web_search_node) +import copy +import time +from typing import Annotated + +from langgraph.types import Command +from langgraph.prebuilt import create_react_agent class GraphBuilder: @@ -101,6 +107,61 @@ def _routing_function_planner(self, state): if state.get("response"): return END return "supervisor" + + + def create_scenario_node(self, node_name: str): + """This function creates a scenario node for the agent.""" + + def scenario_node(state: dict, config: dict) -> Command: + print(f"--------------------------------") + print(f"{node_name} agent called") + print("Current task:") + print(state["task"]) + print("--------------------------------") + + task = state["task"] + plan = state["plan"] + system_prompt = config["configurable"]["scenario_agents_info"][node_name]["system_prompt"] + tools = config["configurable"]["scenario_agents_info"][node_name]["tools"] + + llm = config["configurable"]["llm"] + + agent = create_react_agent( + llm, tools, prompt=system_prompt + ) + + task_formatted = f"""For the following plan: + {str(plan)}\n\nYou are tasked with executing: {task}.""" + + max_retries = 3 + for attempt in range(max_retries): + try: + config["configurable"]["state"] = state + agent_response = agent.invoke({"messages": [("user", task_formatted)]}) + + return Command( + update={ + "past_steps": Annotated[set, "or_"]( + {(task, agent_response["messages"][-1].content)} + ), + "nodes_calls": Annotated[set, "or_"]( + { + ( + node_name, + tuple( + (m.type, m.content) + for m in agent_response["messages"] + ), + ) + } + ), + }, + ) + + except Exception as e: + print(f"{node_name} failed: {str(e)}. Retrying ({attempt+1}/{max_retries})") + time.sleep(1.2**attempt) + return scenario_node def _build(self): """Build graph based on a non-dynamic agent skeleton""" @@ -115,11 +176,14 @@ def _build(self): workflow.add_node("web_search", web_search_node) workflow.add_edge("web_search", "replan_node") - for agent_name, node in self.conf["configurable"][ - "scenario_agent_funcs" - ].items(): + for agent_name in self.conf["configurable"][ + "scenario_agents" + ]: + node = copy.deepcopy(self.create_scenario_node(agent_name)) workflow.add_node(agent_name, node) workflow.add_edge(agent_name, "replan_node") + self.conf["configurable"]["scenario_agent_funcs"]={} + self.conf["configurable"]["scenario_agent_funcs"][agent_name] = node workflow.add_edge(START, "chat") diff --git a/protollm/agents/scenario_agent_example.py b/protollm/agents/scenario_agent_example.py new file mode 100644 index 0000000..6e575d3 --- /dev/null +++ b/protollm/agents/scenario_agent_example.py @@ -0,0 +1,56 @@ +import time +from typing import Annotated + +from langgraph.prebuilt import create_react_agent +from langgraph.types import Command + + +def playground_scenario_node(state, config: dict) -> Command: + print("--------------------------------") + print("Playground agent called") + print("Current task:") + print(state["task"]) + print("--------------------------------") + + system_prompt = config["configurable"]["additional_agents_info"]["playground_scenario_node"]["system_prompt"] + tools = config["configurable"]["additional_agents_info"]["playground_scenario_node"]["tools"] + + task = state["task"] + plan = state["plan"] + + llm = config["configurable"]["llm"] + chem_agent = create_react_agent( + llm, tools, state_modifier=system_prompt + ) + + task_formatted = f"""For the following plan: + {str(plan)}\n\nYou are tasked with executing: {task}.""" + + max_retries = 3 + for attempt in range(max_retries): + try: + config["configurable"]["state"] = state + agent_response = chem_agent.invoke({"messages": [("user", task_formatted)]}) + + return Command( + update={ + "past_steps": Annotated[set, "or_"]( + {(task, agent_response["messages"][-1].content)} + ), + "nodes_calls": Annotated[set, "or_"]( + { + ( + "playground_scenario_node", + tuple( + (m.type, m.content) + for m in agent_response["messages"] + ), + ) + } + ), + }, + ) + + except Exception as e: + print(f"Playground scenario node failed: {str(e)}. Retrying ({attempt+1}/{max_retries})") + time.sleep(1.2**attempt) \ No newline at end of file