
Commit 38b755e

Fixing tests
1 parent 353ceb3 commit 38b755e


10 files changed: +593 -562 lines changed


debug_gym/agents/base_agent.py

Lines changed: 12 additions & 16 deletions
@@ -1,16 +1,13 @@
 import json
 import os
-import subprocess
 import uuid
-from collections import namedtuple
-from copy import copy
 from dataclasses import MISSING, asdict, dataclass, field, fields
 from typing import Any, Dict

 import numpy as np
 from jinja2 import Environment, Template

-from debug_gym.agents.history_tracker import HistoryTracker, build_history_prompt
+from debug_gym.agents.history_tracker import HistoryTracker
 from debug_gym.gym.envs.env import EnvInfo, RepoEnv
 from debug_gym.gym.utils import filter_non_utf8
 from debug_gym.llms.base import LLM
@@ -142,15 +139,13 @@ def _load_system_prompt_template(self) -> Template | None:
         """Load system prompt template from config if specified and register custom filters.
         If no template is specified, return None.
         """
-        system_prompt_template_file = self.args.system_prompt_template_file
-        if system_prompt_template_file:
-            if not os.path.isfile(system_prompt_template):
-                error_msg = (
-                    f"System prompt template file `{system_prompt_template}` not found."
-                )
+        system_prompt_template = None
+        if self.args.system_prompt_template_file is not None:
+            if not os.path.isfile(self.args.system_prompt_template_file):
+                error_msg = f"System prompt template file `{self.args.system_prompt_template_file}` not found."
                 self.logger.error(error_msg)
                 raise FileNotFoundError(error_msg)
-            with open(system_prompt_template, "r") as f:
+            with open(self.args.system_prompt_template_file, "r") as f:
                 system_prompt_template = f.read()

             system_prompt_template = (
@@ -162,19 +157,20 @@ def _load_system_prompt_template(self) -> Template | None:
             env.filters["to_pretty_json"] = self.to_pretty_json
             env.filters["trim_message"] = self.trim_message
             return env.from_string(system_prompt_template)
+
         return None

     def _load_instance_prompt_template(self) -> Template | None:
         """Load instance prompt template from config if specified and register custom filters.
         If no template is specified, return None.
         """
-        instance_prompt_template_file = self.args.instance_prompt_template_file
-        if instance_prompt_template_file:
-            if not os.path.isfile(instance_prompt_template_file):
-                error_msg = f"Instance prompt template file `{instance_prompt_template_file}` not found."
+        instance_prompt_template = None
+        if self.args.instance_prompt_template_file is not None:
+            if not os.path.isfile(self.args.instance_prompt_template_file):
+                error_msg = f"Instance prompt template file `{self.args.instance_prompt_template_file}` not found."
                 self.logger.error(error_msg)
                 raise FileNotFoundError(error_msg)
-            with open(instance_prompt_template_file, "r") as f:
+            with open(self.args.instance_prompt_template_file, "r") as f:
                 instance_prompt_template = f.read()

             instance_prompt_template = (
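
Note on this change: the original `_load_system_prompt_template` read the configured path into `system_prompt_template_file` but then checked and opened an undefined `system_prompt_template` name; both loaders now validate and open `self.args.*_template_file` directly. A minimal standalone sketch of the resulting pattern (the `load_template` helper name is illustrative, not part of the repo):

import os

from jinja2 import Environment, Template


def load_template(path: str | None) -> Template | None:
    """Read a Jinja2 template from `path`, or return None when no path is configured."""
    template_source = None
    if path is not None:
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Template file `{path}` not found.")
        with open(path, "r") as f:
            template_source = f.read()
        return Environment().from_string(template_source)
    return None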

debug_gym/agents/froggy_agent.py

Lines changed: 39 additions & 44 deletions
@@ -1,51 +1,9 @@
-import json
-import subprocess
 from dataclasses import dataclass
 from typing import Any, Dict

-from jinja2 import Template
-
-from debug_gym.agents.base_agent import (
-    LLM,
-    AgentArgs,
-    BaseAgent,
-    Environment,
-    register_agent,
-)
+from debug_gym.agents.base_agent import LLM, AgentArgs, BaseAgent, register_agent
 from debug_gym.agents.history_tracker import HistoryTracker
 from debug_gym.gym.envs.env import EnvInfo
-from debug_gym.gym.utils import filter_non_utf8
-from debug_gym.llms.utils import trim
-
-
-def build_history_prompt(
-    history: HistoryTracker, llm: LLM, reset_prompt_history_after_rewrite: bool = False
-):
-    env_observations, llm_responses = history.get()
-    latest_rewrite_step = 0
-    # Find the latest rewrite step if reset_prompt_history_after_rewrite
-    if reset_prompt_history_after_rewrite:
-        for i, obs in enumerate(env_observations):
-            if obs.rewrite_counter == env_observations[-1].rewrite_counter:
-                latest_rewrite_step = i
-                break
-
-    env_observations = env_observations[latest_rewrite_step:]
-    llm_responses = llm_responses[latest_rewrite_step:]
-
-    messages = []
-    for obs, response in zip(env_observations, llm_responses):
-        # environment observation
-        messages.extend(
-            llm.convert_observation_to_message(
-                obs.step_observation.observation,
-                obs.action_tool_call.id if obs.action_tool_call else None,
-                obs.action_tool_call.name if obs.action_tool_call else None,
-            )
-        )
-        # llm response
-        messages.extend(llm.convert_response_to_message(response))
-    return messages


 @dataclass
@@ -67,7 +25,6 @@ def __init__(
         *args,
         **kwargs,
     ):
-
         agent_args = (
             FroggyAgentArgs.from_dict(agent_args)
             if isinstance(agent_args, dict)
@@ -80,6 +37,7 @@ def build_history_prompt(self):
             self.history,
             self.llm,
             self.args.reset_prompt_history_after_rewrite,
+            history_cutoff=self.args.memory_size,
         )
         return messages

@@ -159,3 +117,40 @@ def _default_system_prompt(self, info) -> str:
         system_prompt_dict["Shortcut features"] = shortcut_features

         return self.to_pretty_json(system_prompt_dict)
+
+
+def build_history_prompt(
+    history: HistoryTracker,
+    llm: LLM,
+    reset_prompt_history_after_rewrite: bool = False,
+    history_cutoff: int = None,
+):
+    env_observations, llm_responses = history.get()
+    if history_cutoff is not None:
+        env_observations = env_observations[-history_cutoff:]
+        llm_responses = llm_responses[-history_cutoff:]
+
+    latest_rewrite_step = 0
+    # Find the latest rewrite step if reset_prompt_history_after_rewrite
+    if reset_prompt_history_after_rewrite:
+        for i, obs in enumerate(env_observations):
+            if obs.rewrite_counter == env_observations[-1].rewrite_counter:
+                latest_rewrite_step = i
+                break
+
+    env_observations = env_observations[latest_rewrite_step:]
+    llm_responses = llm_responses[latest_rewrite_step:]
+
+    messages = []
+    for obs, response in zip(env_observations, llm_responses):
+        # llm response
+        messages.append(llm.convert_response_to_message(response))
+        # environment observation
+        messages.append(
+            llm.convert_observation_to_message(
+                obs.step_observation.observation,
+                obs.action_tool_call.id if obs.action_tool_call else None,
+                obs.action_tool_call.name if obs.action_tool_call else None,
+            )
+        )
+    return messages
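
Note on this change: `build_history_prompt` now lives in `froggy_agent.py`, takes an optional `history_cutoff` (wired to `self.args.memory_size` at the call site), and emits the LLM response before the matching environment observation. A small self-contained sketch of the cutoff behavior, using plain lists in place of the `HistoryTracker` data (illustrative, not repo code):

def truncate_history(observations: list, responses: list, history_cutoff: int | None = None):
    """Keep only the last `history_cutoff` steps; None means keep everything."""
    if history_cutoff is not None:
        observations = observations[-history_cutoff:]
        responses = responses[-history_cutoff:]
    return observations, responses


obs, resp = truncate_history(list(range(10)), list(range(10)), history_cutoff=3)
assert obs == [7, 8, 9] and resp == [7, 8, 9]
obs, resp = truncate_history(list(range(10)), list(range(10)))  # no cutoff keeps everything
assert len(obs) == 10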

debug_gym/agents/history_tracker.py

Lines changed: 12 additions & 12 deletions
@@ -30,33 +30,33 @@ def init(

     def step(
         self,
-        llm_response: LLMResponse,
         env_observation: EnvInfo,
+        llm_response: LLMResponse,
     ) -> None:
         """llm_responses can be None since the initial state does not have prompt and response"""
-        self.llm_responses.append(copy.deepcopy(llm_response))
         self.env_observations.append(copy.deepcopy(env_observation))
+        self.llm_responses.append(copy.deepcopy(llm_response))

     def get(self):
-        # return the history_steps latest steps
+        """Returns the full history of environment observations and LLM responses."""
         return (
             self.env_observations,
             self.llm_responses,
         )

-    def json(self, game_step=None):
-        if len(self.env_observations) == 0 and self.env_init is None:
+    def json(self, game_step: int | None = None):
+        if len(self.env_observations) == 0 and self.env_initial_observation is None:
             return {}

-        if game_step >= len(self.env_observations):
+        # Retrieve the most recent step by default.
+        game_step = (
+            game_step if game_step is not None else len(self.env_observations) - 1
+        )
+        if game_step < 0 or game_step >= len(self.env_observations):
             raise ValueError(
-                f"Invalid game_step: {game_step}. Max step: {len(self.env_observations)-1}"
+                f"Invalid game_step: {game_step}; should be between [0, {len(self.env_observations)-1}]."
             )

-        if game_step is None:
-            # retrieve the most recent step
-            game_step = len(self.env_observations) - 1
-
         if game_step == 0:
             # initial state
             json_out = {
9191
return json_out
9292

9393
def score(self):
94-
return sum([memory.score for memory in self.env_observations])
94+
return sum([obs.score for obs in self.env_observations])
9595

9696
def __len__(self):
9797
return len(self.env_observations)
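
Note on this change: `json()` now defaults `game_step` to the latest step before validating it; the old order compared `None >= len(...)` (a TypeError) and the defaulting branch after the check could never run. A standalone sketch of the new defaulting and bounds check, with `n_steps` standing in for `len(self.env_observations)` (helper name is illustrative):

def resolve_game_step(game_step: int | None, n_steps: int) -> int:
    """Default to the most recent step, then validate the requested index."""
    game_step = game_step if game_step is not None else n_steps - 1
    if game_step < 0 or game_step >= n_steps:
        raise ValueError(
            f"Invalid game_step: {game_step}; should be between [0, {n_steps - 1}]."
        )
    return game_step


assert resolve_game_step(None, 5) == 4  # defaults to the latest step
assert resolve_game_step(0, 5) == 0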

debug_gym/agents/utils.py

Lines changed: 2 additions & 2 deletions
@@ -157,10 +157,10 @@ def save_patch(env, problem_path: Path, logger: DebugGymLogger):
     logger.debug(f"Patch saved in {patch_path}")


-def save_trajectory(agent, problem: str, problem_path: Path, logger: DebugGymLogger):
+def save_trajectory(agent, problem_path: Path, logger: DebugGymLogger):
     """Persist the agent trajectory to disk."""
     problem_path.mkdir(parents=True, exist_ok=True)
-    trajectory = agent.build_trajectory(task_name=problem)
+    trajectory = agent._build_trajectory()
     json_file = problem_path / "trajectory.json"
     with open(json_file, "w") as f:
         json.dump(trajectory, f, indent=4)
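
Note on this change: `save_trajectory` drops the `problem` argument because the agent now builds its own trajectory via `_build_trajectory()`, so callers pass only the agent, the output directory, and the logger. A hedged standalone sketch of the updated helper (assuming `_build_trajectory()` returns a JSON-serializable dict, which is what the call implies):

import json
from pathlib import Path


def save_trajectory_sketch(agent, problem_path: Path) -> Path:
    """Mirror of the updated helper: the agent supplies its own trajectory dict."""
    problem_path.mkdir(parents=True, exist_ok=True)
    trajectory = agent._build_trajectory()  # assumed to be JSON-serializable
    json_file = problem_path / "trajectory.json"
    with open(json_file, "w") as f:
        json.dump(trajectory, f, indent=4)
    return json_file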

debug_gym/llms/openai.py

Lines changed: 6 additions & 6 deletions
@@ -191,25 +191,25 @@ def parse_tool_call_response(self, response) -> ToolCall:
             )

     def convert_response_to_message(self, response: LLMResponse) -> dict:
-        response = {
+        message = {
             "role": "assistant",
             "tool_calls": [
                 {
                     "type": "function",
                     "id": response.tool.id,
                     "function": {
-                        "name": response[0].tool.name,
-                        "arguments": json.dumps(response[0].tool.arguments),
+                        "name": response.tool.name,
+                        "arguments": json.dumps(response.tool.arguments),
                     },
                 },
             ],
-            "content": filter_non_utf8(f"{response[0].response}"),
+            "content": filter_non_utf8(f"{response.response}"),
         }
         if response.reasoning_response:
-            response["reasoning_content"] = filter_non_utf8(
+            message["reasoning_content"] = filter_non_utf8(
                 f"{response.reasoning_response}"
             )
-        return response
+        return message

     def convert_observation_to_message(
         self, observation: str, last_tool_call_id=None, last_tool_call_name=None
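
Note on this change: the converter previously shadowed its `response` argument with the output dict and indexed it as `response[0]`; building a separate `message` dict fixes both. The result follows the OpenAI chat format for assistant tool calls, roughly as below (all values are placeholders, not real tool-call data):

example_message = {
    "role": "assistant",
    "tool_calls": [
        {
            "type": "function",
            "id": "call_123",  # placeholder tool-call id
            "function": {
                "name": "pdb",
                "arguments": '{"command": "b main.py:10"}',
            },
        },
    ],
    "content": "Setting a breakpoint before re-running the tests.",
    # "reasoning_content" is only added when response.reasoning_response is set.
}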

scripts/run.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def run_agent(args, task_name: str, task_data: dict, config: dict):
         raise

     # save trajectory
-    save_trajectory(agent, task_name, task_path, task_logger)
+    save_trajectory(agent, task_path, task_logger)

     # optionally apply patch
     if config["save_patch"]:
