diff --git a/environments/simpleqa/simpleqa.py b/environments/simpleqa/simpleqa.py index f72b085c..70e4b400 100644 --- a/environments/simpleqa/simpleqa.py +++ b/environments/simpleqa/simpleqa.py @@ -123,16 +123,16 @@ def correct_answer_reward_func( def incorrect_answer_reward_func( prompt, completion, answer, state, **kwargs ) -> float: - judge_response = rubric.judge(prompt, completion, answer, state, **kwargs) - match = re.search(r"(A|B|C)", judge_response) + resp = list(state["judge_response"].values())[-1] + match = re.search(r"(A|B|C)", resp) result = match.group(0) if match else "C" return 1.0 if result == "B" else 0.0 def not_attempted_answer_reward_func( prompt, completion, answer, state, **kwargs ) -> float: - judge_response = rubric.judge(prompt, completion, answer, state, **kwargs) - match = re.search(r"(A|B|C)", judge_response) + resp = list(state["judge_response"].values())[-1] + match = re.search(r"(A|B|C)", resp) result = match.group(0) if match else "C" return 1.0 if result == "C" else 0.0