47 changes: 35 additions & 12 deletions verifiers/envs/environment.py
@@ -248,7 +248,18 @@ async def get_model_response(
)
return response
except Exception as e:
self.logger.error(f"Error getting model response: {e} \n\nExiting...")
error_msg = f"Error getting model response: {e}"
# Check if this is a timeout error and provide specific guidance
if "timeout" in str(e).lower():
error_msg += (
". Consider reducing max_concurrent, increasing timeout values, "
"or checking your model server health. "
"Try reducing max_concurrent parameter, increasing async_generation_timeout in GRPOConfig, "
"verifying vLLM server is running and responsive, considering reducing max_tokens or using a smaller model, "
"and increasing system limits with 'ulimit -n 4096'."
)
self.logger.error(f"{error_msg} \n\nExiting...")
# Re-raise the exception so it can be properly handled upstream
raise e

@abstractmethod
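
The guidance above names several tuning knobs without showing them in context. As a rough, hypothetical sketch (whether each setting lives on GRPOConfig or is passed to generation is an assumption here, not something this diff confirms):

# Hypothetical values for the knobs named in the error message above; the
# numbers and the exact home of each setting are assumptions, not library facts.
tuning_overrides = {
    "async_generation_timeout": 1200,  # seconds; raise if rollouts are legitimately slow
    "max_concurrent": 16,              # lower if the vLLM server is saturated
    "max_tokens": 512,                 # shorter completions cut per-request latency
}
# config = GRPOConfig(**tuning_overrides)  # assuming these are GRPOConfig fields
# File-descriptor headroom is set in the shell before launching: ulimit -n 4096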
@@ -406,19 +417,29 @@ async def a_generate(
reward=[],
metrics={},
)
rollouts = await self.run_rollouts(
prompts=results.prompt,
answers=results.answer,
tasks=results.task,
infos=results.info,
client=client,
model=model,
sampling_args=gen_sampling_args,
max_concurrent=max_concurrent,
**kwargs,
)

# Run rollouts with proper error handling
try:
rollouts = await self.run_rollouts(
prompts=results.prompt,
answers=results.answer,
tasks=results.task,
infos=results.info,
client=client,
model=model,
sampling_args=gen_sampling_args,
max_concurrent=max_concurrent,
**kwargs,
)
except Exception as e:
self.logger.error(f"Error during rollouts: {e}")
# Re-raise to let the calling function handle it appropriately
raise e

results.completion = [rollout[0] for rollout in rollouts]
results.state = [rollout[1] for rollout in rollouts]

# Score rollouts if requested
if score_rollouts:
rollout_scores = await self.rubric.score_rollouts(
prompts=results.prompt,
@@ -429,8 +450,10 @@
infos=results.info,
apply_weights=True,
)

results.reward = rollout_scores.reward
results.metrics = rollout_scores.metrics

return results

def generate(
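
Because a_generate now re-raises rollout failures instead of logging and continuing, callers own the error handling. A minimal caller-side sketch, using placeholder names for objects the trainer already holds:

async def generate_with_handling(env, inputs, client, model_name):
    # 'env', 'inputs', 'client', and 'model_name' are placeholders; a_generate
    # is the method changed above and now propagates rollout errors.
    try:
        return await env.a_generate(
            inputs,
            client=client,
            model=model_name,
            score_rollouts=True,
        )
    except TimeoutError:
        # Timeouts surface here instead of being swallowed; a caller might
        # retry with a lower max_concurrent or abort the run.
        raise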
49 changes: 37 additions & 12 deletions verifiers/trainers/async_batch_generator.py
@@ -177,9 +177,20 @@ def get_batch(self, batch_id: int, timeout: float | None = None) -> BatchResult:
pass

# Check timeout
if time.time() - start_time > timeout:
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
# Clean up: drop the timed-out batch from the pending set before raising
with self._lock:
self.pending_batches.discard(batch_id)
self.logger.error(f"Batch {batch_id} generation timed out after {timeout}s")
raise TimeoutError(
f"Batch {batch_id} generation timed out after {timeout}s"
f"Batch {batch_id} generation timed out after {timeout}s. "
f"Consider reducing max_concurrent, increasing timeout values, "
f"or checking your model server health. "
f"Try reducing max_concurrent parameter, increasing async_generation_timeout in GRPOConfig, "
f"verifying vLLM server is running and responsive, considering reducing max_tokens or using a smaller model, "
f"and increasing system limits with 'ulimit -n 4096'."
)

def get_pending_count(self) -> int:
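
On the consumer side, the enriched TimeoutError from get_batch can simply be logged and re-raised so the run fails fast with the tuning hints intact; a hypothetical sketch:

import logging

logger = logging.getLogger(__name__)

def fetch_batch(generator, batch_id: int):
    # 'generator' stands in for the AsyncBatchGenerator instance a trainer holds.
    try:
        return generator.get_batch(batch_id, timeout=600.0)
    except TimeoutError:
        # The message already carries the tuning advice (max_concurrent,
        # async_generation_timeout, vLLM health, ulimit), so log it verbatim.
        logger.exception("Batch %d generation failed", batch_id)
        raise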
@@ -264,15 +275,22 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult:
"""
# Call environment generation
self.is_generating = True
env_results = await self.env.a_generate(
request.env_inputs,
client=self.client,
model=self.model_name,
sampling_args=self.sampling_args,
score_rollouts=True,
max_concurrent=request.max_concurrent,
)
self.is_generating = False
try:
env_results = await self.env.a_generate(
request.env_inputs,
client=self.client,
model=self.model_name,
sampling_args=self.sampling_args,
score_rollouts=True,
max_concurrent=request.max_concurrent,
)
except Exception as e:
self.logger.error(f"Error during batch generation: {e}")
# Re-raise the exception to be handled by the calling function
raise e
finally:
# Always clear the flag, whether generation succeeded or raised
self.is_generating = False

# Extract all reward-related keys
all_reward_dict = {
@@ -399,7 +417,14 @@ async def run_eval():
eval_thread.join(timeout=self.generation_timeout)

if eval_thread.is_alive():
raise TimeoutError(f"Evaluation timed out after {self.generation_timeout}s")
raise TimeoutError(
f"Evaluation timed out after {self.generation_timeout}s. "
f"Consider reducing max_concurrent, increasing timeout values, "
f"or checking your model server health. "
f"Try reducing max_concurrent parameter, increasing async_generation_timeout in GRPOConfig, "
f"verifying vLLM server is running and responsive, considering reducing max_tokens or using a smaller model, "
f"and increasing system limits with 'ulimit -n 4096'."
)

if exception_container:
raise exception_container[0]
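
The run_eval path above bounds an async evaluation by running it on a helper thread and joining with a timeout. A self-contained sketch of that pattern, with illustrative names rather than the library's exact implementation:

import asyncio
import threading

def run_with_timeout(coro_factory, timeout: float):
    # Run an async callable on a worker thread and bound the wait, surfacing
    # either a TimeoutError or the exception raised inside the coroutine.
    exception_container: list[BaseException] = []
    result_container: list[object] = []

    def worker():
        try:
            result_container.append(asyncio.run(coro_factory()))
        except BaseException as e:  # captured for re-raise on the caller's thread
            exception_container.append(e)

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    thread.join(timeout=timeout)

    if thread.is_alive():
        raise TimeoutError(f"Evaluation timed out after {timeout}s")
    if exception_container:
        raise exception_container[0]
    return result_container[0]

Usage might look like run_with_timeout(lambda: some_async_eval(), timeout=600), where some_async_eval is whatever coroutine the caller wants bounded.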