Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 179 additions & 1 deletion tests/test_environment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Tests for the base Environment class."""

from unittest.mock import AsyncMock, Mock
import json
from unittest.mock import AsyncMock, Mock, patch

import httpx
import pytest
from datasets import Dataset
from openai import BadRequestError

from verifiers import Environment, Parser, Rubric, ThinkParser
from verifiers.types import ChatCompletion, Completion, GenerateOutputs, RolloutScores
Expand Down Expand Up @@ -216,6 +219,181 @@ async def test_get_model_response_completion(self, mock_openai_client):
assert hasattr(response.choices[0], "text")
mock_openai_client.completions.create.assert_called_once()

@pytest.mark.asyncio
async def test_get_model_response_chat_context_length_guard(
    self, mock_openai_client
):
    """Chat requests that overflow the context window yield a stub response.

    The mocked client raises an OpenAI-style ``context_length_exceeded``
    ``BadRequestError``. The test expects a single logged warning and an
    empty ``ChatCompletion`` with ``finish_reason == "length"`` whose usage
    matches the token counts quoted in the upstream error message.
    """
    mock_openai_client.base_url = "https://api.openai.com/v1"
    dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
    env = SimpleEnvironment(
        client=mock_openai_client,
        model="gpt-test",
        eval_dataset=dataset,
        parser=Parser(),
        rubric=Rubric(),
    )

    # OpenAI-style wording: total requested, then the prompt/completion split.
    upstream_text = (
        "This model's maximum context length is 4097 tokens. "
        "However, you requested 4120 tokens (4020 in the messages, 100 in the completion). "
        "Please reduce the length of the messages or completion."
    )
    payload = {
        "error": {
            "message": upstream_text,
            "type": "invalid_request_error",
            "code": "context_length_exceeded",
        }
    }
    failed_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
    failed_response = httpx.Response(status_code=400, request=failed_request)
    mock_openai_client.chat.completions.create.side_effect = BadRequestError(
        message=f"Error code: 400 - {json.dumps(payload)}",
        response=failed_response,
        body=payload,
    )

    with patch.object(env.logger, "warning") as mock_warning:
        response = await env.get_model_response(
            prompt=[{"role": "user", "content": "Hello"}],
            client=mock_openai_client,
            model="gpt-test",
            message_type="chat",
        )

    # The stub must look like a truncated, empty chat completion.
    assert isinstance(response, ChatCompletion)
    first_choice = response.choices[0]
    assert first_choice.message.content == ""
    assert first_choice.finish_reason == "length"

    # Usage is reconstructed from the numbers in the error message.
    usage = response.usage
    assert usage.prompt_tokens == 4020
    assert usage.completion_tokens == 100
    assert usage.total_tokens == 4120

    # Exactly one upstream call: the request is not retried.
    assert mock_openai_client.chat.completions.create.await_count == 1

    mock_warning.assert_called_once()
    logged_text = mock_warning.call_args.args[0]
    assert "Context length exceeded" in logged_text
    assert "requested 4120 tokens" in logged_text

@pytest.mark.asyncio
async def test_get_model_response_completion_context_length_guard_vllm(
    self, mock_openai_client
):
    """Completion requests survive a vLLM-style ``max_tokens`` overflow error.

    The mocked client raises a ``BadRequestError`` whose message follows the
    vLLM wording ("'max_tokens' ... is too large"). The test expects one
    warning and an empty ``Completion`` stub with ``finish_reason == "length"``
    and usage derived from the numbers in that message.
    """
    mock_openai_client.base_url = "http://localhost:8000/v1"
    dataset = Dataset.from_dict({"prompt": ["test"], "answer": ["test"]})
    env = SimpleEnvironment(
        client=mock_openai_client,
        model="vllm-model",
        eval_dataset=dataset,
        message_type="completion",
        parser=Parser(),
        rubric=Rubric(),
    )

    # vLLM-style wording: requested completion budget, limit, and input size.
    upstream_text = (
        "'max_tokens' or 'max_completion_tokens' is too large: 400. "
        "This model's maximum context length is 8192 tokens and your request has "
        "8000 input tokens (400 > 8192 - 8000)."
    )
    payload = {
        "error": {
            "message": upstream_text,
            "type": "context_length_exceeded",
            "code": "context_length_exceeded",
        }
    }
    failed_request = httpx.Request("POST", "http://localhost:8000/v1/completions")
    failed_response = httpx.Response(status_code=400, request=failed_request)
    mock_openai_client.completions.create.side_effect = BadRequestError(
        message=f"Error code: 400 - {json.dumps(payload)}",
        response=failed_response,
        body=payload,
    )

    with patch.object(env.logger, "warning") as mock_warning:
        response = await env.get_model_response(
            prompt="Hello",
            client=mock_openai_client,
            model="vllm-model",
            message_type="completion",
        )

    # The stub must look like a truncated, empty text completion.
    assert isinstance(response, Completion)
    first_choice = response.choices[0]
    assert first_choice.text == ""
    assert first_choice.finish_reason == "length"

    usage = response.usage
    assert usage.prompt_tokens == 8000
    assert usage.completion_tokens == 400
    assert usage.total_tokens == 8400

    # Exactly one upstream call: the request is not retried.
    assert mock_openai_client.completions.create.await_count == 1

    mock_warning.assert_called_once()
    logged_text = mock_warning.call_args.args[0]
    assert "Context length exceeded" in logged_text
    assert "limit 8192" in logged_text
    # 8000 prompt + 400 completion - 8192 limit = 208 tokens over.
    assert "over by 208 tokens" in logged_text

@pytest.mark.asyncio
async def test_get_model_response_completion_context_length_guard_vllm_prompt(
    self, mock_openai_client
):
    """A prompt that alone exceeds the vLLM context limit yields a stub.

    The upstream message reports only the input token count (no completion
    budget). The test expects one warning and an empty ``Completion`` with
    ``finish_reason == "length"``, zero completion tokens, and a total equal
    to the prompt size.
    """
    mock_openai_client.base_url = "http://localhost:8000/v1"
    dataset = Dataset.from_dict({"prompt": ["test"], "answer": ["test"]})
    env = SimpleEnvironment(
        client=mock_openai_client,
        model="vllm-model",
        eval_dataset=dataset,
        message_type="completion",
        parser=Parser(),
        rubric=Rubric(),
    )

    # Prompt-only wording: the limit and the input size, nothing else.
    upstream_text = (
        "This model's maximum context length is 8192 tokens. "
        "However, your request has 9000 input tokens. "
        "Please reduce the length of the input messages."
    )
    payload = {
        "error": {
            "message": upstream_text,
            "type": "context_length_exceeded",
            "code": "context_length_exceeded",
        }
    }
    failed_request = httpx.Request("POST", "http://localhost:8000/v1/completions")
    failed_response = httpx.Response(status_code=400, request=failed_request)
    mock_openai_client.completions.create.side_effect = BadRequestError(
        message=f"Error code: 400 - {json.dumps(payload)}",
        response=failed_response,
        body=payload,
    )

    with patch.object(env.logger, "warning") as mock_warning:
        response = await env.get_model_response(
            prompt="Hello",
            client=mock_openai_client,
            model="vllm-model",
            message_type="completion",
        )

    # The stub must look like a truncated, empty text completion.
    assert isinstance(response, Completion)
    first_choice = response.choices[0]
    assert first_choice.text == ""
    assert first_choice.finish_reason == "length"

    # No completion budget was reported, so only prompt tokens are counted.
    usage = response.usage
    assert usage.prompt_tokens == 9000
    assert usage.completion_tokens == 0
    assert usage.total_tokens == 9000

    # Exactly one upstream call: the request is not retried.
    assert mock_openai_client.completions.create.await_count == 1

    mock_warning.assert_called_once()
    logged_text = mock_warning.call_args.args[0]
    assert "Context length exceeded" in logged_text
    assert "limit 8192" in logged_text
    assert "prompt 9000 tokens" in logged_text

def test_process_chat_format(self, mock_openai_client, sample_dataset):
"""Test processing chat format conversations."""
env = SimpleEnvironment(
Expand Down
2 changes: 1 addition & 1 deletion verifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
extract_hash_answer,
load_example_dataset,
)
from .utils.env_utils import load_environment
from .loader import load_environment
from .utils.logging_utils import print_prompt_completions_sample


Expand Down
20 changes: 20 additions & 0 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@
SamplingArgs,
State,
)
from verifiers.utils.env_utils import (
build_context_length_stub_response,
extract_context_length_error,
format_context_length_warning,
infer_provider_name,
)
from verifiers.utils.message_utils import cleanup_messages, sanitize_tool_calls


if TYPE_CHECKING:
from transformers.tokenization_utils_base import ( # type: ignore
PreTrainedTokenizerBase,
Expand Down Expand Up @@ -274,6 +281,19 @@ async def get_model_response(
)
return response
except Exception as e:
context_error = extract_context_length_error(e)
if context_error is not None:
provider = infer_provider_name(client)
warning_message = format_context_length_warning(
provider, model, context_error.details
)
message_text = context_error.message
if message_text and message_text not in warning_message:
warning_message = f"{warning_message} - upstream: {message_text}"
self.logger.warning(warning_message)
return build_context_length_stub_response(
message_type, model, context_error.details
)
self.logger.error(f"Error getting model response: {e} \n\nExiting...")
raise e

Expand Down
93 changes: 93 additions & 0 deletions verifiers/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import importlib
import inspect
import logging
from typing import Callable

from verifiers.envs.environment import Environment


def _format_default(value: object) -> str:
    """Render a parameter default for log output, quoting plain strings."""
    return f"'{value}'" if isinstance(value, str) else f"{value}"


def load_environment(env_id: str, **env_args) -> Environment:
    """Import an environment package and build its ``Environment`` instance.

    The module name is ``env_id`` with every ``-`` replaced by ``_``. The
    module must expose a ``load_environment`` callable, which is invoked with
    ``**env_args``. Declared defaults, provided arguments, and defaults that
    remain in effect are logged along the way.

    Args:
        env_id: Environment identifier, e.g. ``"my-env"`` (imported as
            ``my_env``).
        **env_args: Keyword arguments forwarded unchanged to the module's
            ``load_environment`` function.

    Returns:
        The ``Environment`` returned by the module's ``load_environment``.

    Raises:
        ValueError: If an ``ImportError`` is raised anywhere during loading
            (typically because the environment package is not installed).
        RuntimeError: For any other failure, including a module without a
            ``load_environment`` attribute or a loader that returns something
            that is not an ``Environment``. The original exception is chained
            as ``__cause__``.
    """
    logger = logging.getLogger("verifiers.loader")
    logger.info("Loading environment: %s", env_id)

    module_name = env_id.replace("-", "_")
    try:
        module = importlib.import_module(module_name)

        if not hasattr(module, "load_environment"):
            raise AttributeError(
                f"Module '{module_name}' does not have a 'load_environment' function. "
                "This usually means there's a package name collision. Please either:\n"
                "1. Rename your environment (e.g. suffix with '-env')\n"
                "2. Remove unneeded files with the same name\n"
                "3. Check that you've installed the correct environment package"
            )

        env_load_func: Callable[..., Environment] = getattr(
            module, "load_environment"
        )
        sig = inspect.signature(env_load_func)

        # Compare against the sentinel by identity: ``!=`` would invoke the
        # default value's own ``__ne__``, which may raise or return a
        # non-boolean (e.g. array-like defaults).
        empty = inspect.Parameter.empty

        defaults_info = [
            f"{param_name}=<required>"
            if param.default is empty
            else f"{param_name}={_format_default(param.default)}"
            for param_name, param in sig.parameters.items()
        ]
        if defaults_info:
            logger.debug("Environment defaults: %s", ", ".join(defaults_info))

        if env_args:
            # Keyword arguments keep call order, so the log line is stable.
            provided_values = [f"{name}={value}" for name, value in env_args.items()]
            logger.info("Using provided args: %s", ", ".join(provided_values))

        # Walk the signature in declaration order (not a set) so that the
        # logged defaults are deterministic between runs.
        default_values = [
            f"{param_name}={_format_default(param.default)}"
            for param_name, param in sig.parameters.items()
            if param_name not in env_args and param.default is not empty
        ]
        if default_values:
            logger.info("Using default args: %s", ", ".join(default_values))

        env_instance = env_load_func(**env_args)

        if not isinstance(env_instance, Environment):
            raise TypeError(
                f"Environment '{env_id}' returned {type(env_instance)} which is not a verifiers Environment"
            )

        logger.info("Successfully loaded environment '%s'", env_id)

        return env_instance

    except ImportError as error:
        # NOTE: this also catches ImportErrors raised inside the environment's
        # own loader, not only a missing top-level package.
        logger.error(
            "Failed to import environment module %s for env_id %s: %s",
            module_name,
            env_id,
            error,
        )
        raise ValueError(
            f"Could not import '{env_id}' environment. Ensure the package for the '{env_id}' environment is installed."
        ) from error
    except Exception as error:  # noqa: BLE001 - propagate structured message
        logger.error(
            "Failed to load environment %s with args %s: %s",
            env_id,
            env_args,
            error,
        )
        raise RuntimeError(f"Failed to load environment '{env_id}': {error}") from error
Loading
Loading