1- import json
2- import subprocess
31from dataclasses import dataclass
42from typing import Any , Dict
53
6- from jinja2 import Template
7-
8- from debug_gym .agents .base_agent import (
9- LLM ,
10- AgentArgs ,
11- BaseAgent ,
12- Environment ,
13- register_agent ,
14- )
4+ from debug_gym .agents .base_agent import LLM , AgentArgs , BaseAgent , register_agent
155from debug_gym .agents .history_tracker import HistoryTracker
166from debug_gym .gym .envs .env import EnvInfo
17- from debug_gym .gym .utils import filter_non_utf8
18- from debug_gym .llms .utils import trim
19-
20-
21- def build_history_prompt (
22- history : HistoryTracker , llm : LLM , reset_prompt_history_after_rewrite : bool = False
23- ):
24- env_observations , llm_responses = history .get ()
25- latest_rewrite_step = 0
26- # Find the latest rewrite step if reset_prompt_history_after_rewrite
27- if reset_prompt_history_after_rewrite :
28- for i , obs in enumerate (env_observations ):
29- if obs .rewrite_counter == env_observations [- 1 ].rewrite_counter :
30- latest_rewrite_step = i
31- break
32-
33- env_observations = env_observations [latest_rewrite_step :]
34- llm_responses = llm_responses [latest_rewrite_step :]
35-
36- messages = []
37- for obs , response in zip (env_observations , llm_responses ):
38- # environment observation
39- messages .extend (
40- llm .convert_observation_to_message (
41- obs .step_observation .observation ,
42- obs .action_tool_call .id if obs .action_tool_call else None ,
43- obs .action_tool_call .name if obs .action_tool_call else None ,
44- )
45- )
46- # llm response
47- messages .extend (llm .convert_response_to_message (response ))
48- return messages
497
508
519@dataclass
@@ -67,7 +25,6 @@ def __init__(
6725 * args ,
6826 ** kwargs ,
6927 ):
70-
7128 agent_args = (
7229 FroggyAgentArgs .from_dict (agent_args )
7330 if isinstance (agent_args , dict )
@@ -80,6 +37,7 @@ def build_history_prompt(self):
8037 self .history ,
8138 self .llm ,
8239 self .args .reset_prompt_history_after_rewrite ,
40+ history_cutoff = self .args .memory_size ,
8341 )
8442 return messages
8543
@@ -159,3 +117,40 @@ def _default_system_prompt(self, info) -> str:
159117 system_prompt_dict ["Shortcut features" ] = shortcut_features
160118
161119 return self .to_pretty_json (system_prompt_dict )
120+
121+
122+ def build_history_prompt (
123+ history : HistoryTracker ,
124+ llm : LLM ,
125+ reset_prompt_history_after_rewrite : bool = False ,
126+ history_cutoff : int = None ,
127+ ):
128+ env_observations , llm_responses = history .get ()
129+ if history_cutoff is not None :
130+ env_observations = env_observations [- history_cutoff :]
131+ llm_responses = llm_responses [- history_cutoff :]
132+
133+ latest_rewrite_step = 0
134+ # Find the latest rewrite step if reset_prompt_history_after_rewrite
135+ if reset_prompt_history_after_rewrite :
136+ for i , obs in enumerate (env_observations ):
137+ if obs .rewrite_counter == env_observations [- 1 ].rewrite_counter :
138+ latest_rewrite_step = i
139+ break
140+
141+ env_observations = env_observations [latest_rewrite_step :]
142+ llm_responses = llm_responses [latest_rewrite_step :]
143+
144+ messages = []
145+ for obs , response in zip (env_observations , llm_responses ):
146+ # llm response
147+ messages .append (llm .convert_response_to_message (response ))
148+ # environment observation
149+ messages .append (
150+ llm .convert_observation_to_message (
151+ obs .step_observation .observation ,
152+ obs .action_tool_call .id if obs .action_tool_call else None ,
153+ obs .action_tool_call .name if obs .action_tool_call else None ,
154+ )
155+ )
156+ return messages
0 commit comments