Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions tests/test_multiturn_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,98 @@ async def test_responses_stored_in_state(self, mock_multiturn_env):
for response in state["responses"]:
assert hasattr(response, "choices")
assert len(response.choices) > 0

@pytest.mark.asyncio
async def test_strip_think_string_content_preserves_tail_and_tools(
    self, mock_openai_client, sample_chat_dataset
):
    """Check think-stripping on a string-content assistant turn with tool calls.

    With ``exclude_think=True``, the context built from the state is expected
    to carry only the text that follows ``</think>`` for the assistant turn,
    while its ``tool_calls`` and the subsequent tool message come back
    unchanged.
    """
    from tests.conftest import SimpleMultiTurnEnv

    # Messages used by the scenario; the assistant turn mixes a reasoning
    # prefix with a visible tail and a single tool call.
    prompt = [{"role": "user", "content": "What is 2+2?"}]
    requested_calls = [
        {
            "id": "id1",
            "type": "function",
            "function": {"name": "toolA", "arguments": "{}"},
        }
    ]
    assistant_msg = {
        "role": "assistant",
        "content": "<think>\nprivate reasoning</think>\n\nCall tool A",
        "tool_calls": requested_calls,
    }
    tool_msg = {"role": "tool", "content": "resultA", "tool_call_id": "id1"}

    env = SimpleMultiTurnEnv(
        client=mock_openai_client,
        model="test-model",
        dataset=sample_chat_dataset,
        parser=Parser(),
        rubric=Rubric(),
        exclude_think=True,
    )
    state = await env.init_state(
        prompt=prompt,
        completion=[],
        answer="",
        task="default",
        info={},
        example_id=0,
    )

    # Simulate one finished assistant/tool exchange in the rollout so far.
    history = state["completion"]
    history.append(assistant_msg)
    history.append(tool_msg)

    context = await env.get_context_messages(state)
    assert isinstance(context, list)

    # The user prompt leads the context untouched.
    assert context[0] == prompt[0]

    # Assistant turn: reasoning removed, visible tail and tool calls retained.
    replayed_assistant = context[1]
    assert replayed_assistant["role"] == "assistant"
    assert replayed_assistant["content"] == "Call tool A"
    assert replayed_assistant.get("tool_calls") == assistant_msg["tool_calls"]

    # Tool result is forwarded verbatim.
    assert context[2] == tool_msg

@pytest.mark.asyncio
async def test_no_think_content_is_passthrough(
    self, mock_openai_client, sample_chat_dataset
):
    """Check that assistant content without ``</think>`` is left as-is.

    Even with ``exclude_think=True``, an assistant turn that contains no
    closing think tag is expected to reach the context with its original
    content and ``tool_calls``, followed by the unchanged tool message.
    """
    from tests.conftest import SimpleMultiTurnEnv

    # Assistant turn with plain text (no reasoning block) plus one tool call.
    prompt = [{"role": "user", "content": "Q"}]
    requested_calls = [
        {
            "id": "id3",
            "type": "function",
            "function": {"name": "toolC", "arguments": "{}"},
        }
    ]
    assistant_msg = {
        "role": "assistant",
        "content": "No CoT here, proceed to tool",
        "tool_calls": requested_calls,
    }
    tool_msg = {"role": "tool", "content": "resultC", "tool_call_id": "id3"}

    env = SimpleMultiTurnEnv(
        client=mock_openai_client,
        model="test-model",
        dataset=sample_chat_dataset,
        parser=Parser(),
        rubric=Rubric(),
        exclude_think=True,
    )
    state = await env.init_state(
        prompt=prompt,
        completion=[],
        answer="",
        task="default",
        info={},
        example_id=0,
    )

    # Simulate one finished assistant/tool exchange in the rollout so far.
    history = state["completion"]
    history.append(assistant_msg)
    history.append(tool_msg)

    context = await env.get_context_messages(state)
    assert isinstance(context, list)

    # Nothing to strip: content and tool calls must match the input message.
    replayed_assistant = context[1]
    assert replayed_assistant["content"] == assistant_msg["content"]
    assert replayed_assistant.get("tool_calls") == assistant_msg["tool_calls"]

    # Tool result is forwarded verbatim.
    assert context[2] == tool_msg
Loading
Loading