diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index b84e538f2a..259e75baaf 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -2919,6 +2919,8 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
             actor_manager,
             checkpoint_state,
         )
+    except Exception as e:
+        logger.error(f"Error in run_training: {e}", exc_info=True)
     finally:
         cleanup_training_resources(
             stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager
diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py
index 279b81345d..f68bd1b7b1 100644
--- a/open_instruct/vllm_utils3.py
+++ b/open_instruct/vllm_utils3.py
@@ -15,7 +15,6 @@
 """This file is copied from https://github.com/OpenRLHF/OpenRLHF"""
 
-import copy
 import os
 import queue
 import time
@@ -43,7 +42,7 @@
 from vllm.v1.core import kv_cache_utils
 
 from open_instruct import logger_utils
-from open_instruct.queue_types import GenerationResult, RequestInfo, TokenStatistics
+from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
 from open_instruct.tool_utils.tool_vllm import MaxCallsExceededTool, Tool
 from open_instruct.utils import ray_get_with_progress
 
@@ -93,7 +92,7 @@ def _handle_output(output, tools, tracking, sampling_params, max_tool_calls, exe
     if not tools:
         return output
 
-    assert len(output.outputs) <= 1  # In tool mode, sampling_params.n == 1
+    assert len(output.outputs) <= 1, f"{len(output.outputs)=}"  # In tool mode, sampling_params.n == 1
     o = output.outputs[0]
 
     # Update concatenated outputs
@@ -203,7 +202,6 @@ def _process_outputs_with_tools(
 def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=None, start_time=None):
     """Prepare final outputs based on whether tools were used."""
     if not tools:
-        outputs.sort(key=lambda x: int(x.request_id.split("_")[-1]))
         return _process_outputs(
             outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
         )
@@ -223,14 +221,14 @@ def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=
     # Merge n completions into the same outputs
     merged_outputs = {}
     for req_id in tracking["concat_outputs"]:
-        real_req_id, _ = req_id.split("-")
+        real_req_id = "_".join(req_id.split("_")[:-1])
         if real_req_id not in merged_outputs:
             merged_outputs[real_req_id] = tracking["concat_outputs"][req_id]
         else:
             merged_outputs[real_req_id].outputs.append(tracking["concat_outputs"][req_id].outputs[0])
 
     final_outputs = sorted(
-        merged_outputs.values(), key=lambda x: (int(x.request_id.split("-")[0]), int(x.request_id.split("-")[1]))
+        merged_outputs.values(), key=lambda x: (int(x.request_id.split("_")[1]), int(x.request_id.split("_")[2]))
     )
 
     return _process_outputs_with_tools(
@@ -317,6 +315,32 @@ def init_process_group(
     return pg
 
 
+def add_request(request: PromptRequest, llm_engine: vllm.LLMEngine, tools, request_metadata: dict):
+    """Add a request to the LLM engine."""
+    prefix = "eval" if request.is_eval else "train"
+
+    for batch_idx, prompt in enumerate(request.prompts):
+        request_id = f"{prefix}_{request.training_step}_{batch_idx}"
+        sampling_params = request.generation_config.clone()
+        sampling_params.n = 1  # Use n=1 for tool processing
+        request_metadata[request_id] = {
+            "is_eval": request.is_eval,
+            "dataset_index": request.dataset_index[batch_idx],
+            "training_step": request.training_step,
+            "sampling_params": sampling_params,
+            "prompt_tokens": len(prompt),
+            "start_time": time.perf_counter(),
+        }
+
+        tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
+
+        for j in range(request.generation_config.n):
+            sub_sampling_params = sampling_params.clone()  # Already has n=1
+            if request.generation_config.seed is not None:
+                sub_sampling_params.seed = request.generation_config.seed + j
+            llm_engine.add_request(f"{request_id}_{j}", tokens_prompt, sub_sampling_params)
+
+
 class LLMRayActor:
     """Ray actor for LLM generation with optional tool support."""
 
@@ -384,6 +408,15 @@ def _should_stop(self) -> bool:
                 ray.cancel(should_stop_ref)
         return self._should_stop_value
 
+    def _insert_result_to_queue(self, result, is_eval: bool):
+        """Insert result into the appropriate queue with error handling."""
+        try:
+            results_queue = self.eval_results_queue if is_eval else self.results_queue
+            results_queue.put(result, timeout=10)
+        except queue.Full:
+            queue_name = "eval" if is_eval else "train"
+            self.logger.warning(f"{queue_name} results queue is full, discarding result.")
+
     def process_from_queue(self, timeout: float = 60.0):
         """Run generation loop using LLMEngine directly, with optional tool support.
 
@@ -401,37 +434,20 @@ def process_from_queue(self, timeout: float = 60.0):
 
             result = self._process_request(request)
 
-            try:
-                if request.is_eval:
-                    self.eval_results_queue.put(result, timeout=10)
-                else:
-                    self.results_queue.put(result, timeout=10)
-                return 1  # Successfully processed one request
-            except queue.Full:
-                self.logger.warning("Results queue is full, discarding result.")
-                return 0
+            self._insert_result_to_queue(result, is_eval=request.is_eval)
+            return 1
 
     def _process_request(self, request):
         """Unified processing for both tool and non-tool generation."""
-        prompts = request.prompts
-        sampling_params = request.generation_config
-        start_time = request.start_time
-
-        self.logger.info(f"[LLMRayActor] Processing request with {len(prompts)} prompts, tools={bool(self.tools)}")
+        self.logger.info(
+            f"[LLMRayActor] Processing request with {len(request.prompts)} prompts, tools={bool(self.tools)}"
+        )
 
-        if self.tools:
-            # Need n=1 for individual tool tracking
-            sampling_params = copy.deepcopy(sampling_params)
-            original_n = request.generation_config.n
-            sampling_params.n = 1
-            tracking = _init_tool_tracking()
-            tokenizer = self.llm_engine.tokenizer
-        else:
-            original_n = 1
-            tracking = None
-            tokenizer = None
+        tracking = _init_tool_tracking() if self.tools else None
+        tokenizer = self.llm_engine.tokenizer
 
-        self._add_initial_requests(prompts, sampling_params, original_n, request.training_step)
+        add_request(request, self.llm_engine, self.tools, request_metadata=self.request_metadata)
 
         outputs = []
         iteration = 0
@@ -441,18 +457,19 @@ def _process_request(self, request):
 
             # Poll tool futures first (matching ToolUseLLM order)
             if tracking and tracking.get("pending_tool_futures"):
-                self._poll_tool_futures(tracking, sampling_params, tokenizer)
+                outputs.extend(self._poll_tool_futures(tracking, tokenizer))
 
             # Process engine steps - ONLY if there are unfinished requests (matching ToolUseLLM)
             if self.llm_engine.has_unfinished_requests():
-                step_outputs = list(self.llm_engine.step())
+                step_outputs = [o for o in self.llm_engine.step() if o.finished]
                 for output in step_outputs:
-                    if output.finished:
-                        result = _handle_output(
-                            output, self.tools, tracking, sampling_params, self.max_tool_calls, self.executor
-                        )
-                        if result is not None:
-                            outputs.append(result)
+                    self.logger.info(f"{len(output.outputs)=}")
+                    result = _handle_output(
+                        output, self.tools, tracking, request.generation_config, self.max_tool_calls, self.executor
+                    )
+                    # Result is None when we do more tool processing.
+                    if result is not None:
+                        outputs.append(result)
 
             # Check termination condition (matching ToolUseLLM exactly)
             pending_count = len(tracking["pending_tool_futures"]) if tracking else 0
@@ -465,23 +482,40 @@ def _process_request(self, request):
         total_generation_tokens = 0
         earliest_start_time = float("inf")
 
+        # Now, we combine outputs:
+        combined_outputs = defaultdict(list)
         for output in outputs:
-            request_id = output.request_id
-            if request_id in self.request_metadata:
-                metadata = self.request_metadata[request_id]
-                total_prompt_tokens += metadata["prompt_tokens"]
-                earliest_start_time = min(earliest_start_time, metadata["start_time"])
-
+            # Remove the sub_idx.
+            request_id = "_".join(output.request_id.split("_")[:-1])
+            combined_outputs[request_id].append(output)
+        # Preserve original order from request.dataset_index
+        prefix = "eval" if request.is_eval else "train"
+        # request_id is batch_num _ training_step _ within_batch_idx _ repetition_idx.
+        # we order by within_batch_idx.
+        ordered_ids = [f"{prefix}_{request.training_step}_{batch_idx}" for batch_idx in range(len(request.prompts))]
+        final_outputs = []
+        for request_id in ordered_ids:
+            outs = combined_outputs[request_id]
+            assert len(outs) == request.generation_config.n, f"{len(outs)=} != {request.generation_config.n=}"
+            final_outputs.append(
+                vllm.RequestOutput(
+                    request_id=request_id,
+                    prompt=outs[0].prompt,
+                    prompt_token_ids=outs[0].prompt_token_ids,
+                    prompt_logprobs=outs[0].prompt_logprobs,
+                    outputs=[completion for out in outs for completion in out.outputs],
+                    finished=outs[0].finished,
+                )
+            )
+            metadata = self.request_metadata.pop(request_id)
+            total_prompt_tokens += metadata["prompt_tokens"]
+            earliest_start_time = min(earliest_start_time, metadata["start_time"])
+            for output in outs:
                 for completion in output.outputs:
                     total_generation_tokens += len(completion.token_ids)
-
         generation_time = end_time - earliest_start_time
-
-        for output in outputs:
-            self.request_metadata.pop(output.request_id, None)
-
         result = _finalize_outputs(
-            outputs,
+            final_outputs,
             tracking,
             request.dataset_index,
             self.tools,
@@ -490,33 +524,17 @@ def _process_request(self, request):
                 num_response_tokens=total_generation_tokens,
                 generation_time=generation_time,
             ),
-            start_time=start_time,
+            start_time=request.start_time,
         )
         return result
 
-    def _add_initial_requests(self, prompts, sampling_params, n_samples, training_step):
-        """Add initial requests to the engine."""
-        for i, prompt in enumerate(prompts):
-            if self.tools:
-                # Create individual requests for each sample when using tools
-                for j in range(n_samples):
-                    request_id = f"{training_step}_{i}-{j}"
-                    self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
-                    tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=f"{training_step}_{i}")
-                    self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-            else:
-                # Standard request format for non-tool mode
-                request_id = f"batch_{training_step}_{i}"
-                self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
-                tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
-                self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-
-    def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
+    def _poll_tool_futures(self, tracking, tokenizer):
         """Poll and handle completed tool executions."""
         if not self.tools or not tracking["pending_tool_futures"]:
-            return
+            return []
 
         dict_keys_to_delete = []
+        completed_outputs = []
 
         for req_id, (future, last_o, last_output) in tracking["pending_tool_futures"].items():
             if not future.done():
@@ -525,6 +543,11 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
 
             # Tool future is done, process it
             tool_result = future.result()  # Get the tool result
+
+            # Get sampling params from request metadata for this request
+            # Extract the base request ID by removing the sub-request suffix
+            base_req_id = "_".join(req_id.split("_")[:-1])
+            sampling_params = self.request_metadata[base_req_id]["sampling_params"]
+
             last_prompt_token_ids = last_output.prompt_token_ids
             last_token_ids = last_o.token_ids
             tool_output_token_ids = tokenizer.encode(
@@ -559,7 +582,7 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
                 can_make_new_request = can_make_new_request and new_sample_tokens > 0
 
             if can_make_new_request:
-                new_sampling_params = copy.deepcopy(sampling_params)
+                new_sampling_params = sampling_params.clone()
                 new_sampling_params.max_tokens = new_sample_tokens
 
                 try:
@@ -569,12 +592,16 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
                 except Exception as e:
                     # Match original ToolUseLLM behavior - just log and continue
                     self.logger.error(f"[_poll_tool_futures] Error adding request {req_id}: {e}")
+            else:
+                # If we can't make a new request, this tool execution is complete
+                completed_outputs.append(tracking["concat_outputs"][req_id])
 
             dict_keys_to_delete.append(req_id)
 
         for req_id in dict_keys_to_delete:
-            if req_id in tracking["pending_tool_futures"]:
-                del tracking["pending_tool_futures"][req_id]
+            tracking["pending_tool_futures"].pop(req_id, None)
+
+        return completed_outputs
 
     def init_process_group(
         self,
diff --git a/scripts/train/debug/large_test_script.sh b/scripts/train/debug/large_test_script.sh
index 246bcce8f6..79704058e1 100755
--- a/scripts/train/debug/large_test_script.sh
+++ b/scripts/train/debug/large_test_script.sh
@@ -12,7 +12,7 @@ uv run python mason.py \
     --priority urgent \
     --preemptible \
     --num_nodes 2 \
-    --description "rlvr ace fn and og ocr stdio from base with perf penalty" \
+    --description "Large (multi-node) test script." \
     --max_retries 0 \
     --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
     --budget ai2/oe-adapt \
@@ -39,7 +39,7 @@ uv run python mason.py \
     --stop_strings "" \
     --non_stop_penalty False \
     --temperature 1.0 \
-    --verbose False \
+    --verbose False \
     --ground_truths_key ground_truth \
     --sft_messages_key messages \
     --total_episodes 10_000 \
diff --git a/scripts/train/debug/single_gpu_integration_test.sh b/scripts/train/debug/single_gpu_integration_test.sh
index 39839de27c..2da00af7b5 100755
--- a/scripts/train/debug/single_gpu_integration_test.sh
+++ b/scripts/train/debug/single_gpu_integration_test.sh
@@ -11,6 +11,7 @@ uv run python mason.py \
     --cluster ai2/augusta-google-1 \
     --cluster ai2/saturn-cirrascale \
     --image "$BEAKER_IMAGE" \
+    --description "Single GPU on Beaker integration test." \
     --pure_docker_mode \
     --workspace ai2/open-instruct-dev \
     --priority high \
diff --git a/scripts/train/debug/single_gpu_on_beaker.sh b/scripts/train/debug/single_gpu_on_beaker.sh
index edfb0da55f..dadd9099a7 100755
--- a/scripts/train/debug/single_gpu_on_beaker.sh
+++ b/scripts/train/debug/single_gpu_on_beaker.sh
@@ -11,6 +11,7 @@ uv run python mason.py \
     --cluster ai2/saturn-cirrascale \
     --cluster ai2/ceres-cirrascale \
     --image "$BEAKER_IMAGE" \
+    --description "Single GPU on Beaker test script." \
    --pure_docker_mode \
    --workspace ai2/open-instruct-dev \
    --priority urgent \
diff --git a/scripts/train/debug/tool_grpo_fast.sh b/scripts/train/debug/tool_grpo_fast.sh
index c319992970..958d16b0da 100755
--- a/scripts/train/debug/tool_grpo_fast.sh
+++ b/scripts/train/debug/tool_grpo_fast.sh
@@ -14,6 +14,7 @@ uv run python mason.py \
    --cluster ai2/augusta-google-1 \
    --cluster ai2/saturn-cirrascale \
    --image "$BEAKER_IMAGE" \
+    --description "Single GPU on Beaker with tool use test script." \
    --pure_docker_mode \
    --workspace ai2/tulu-thinker \
    --priority high \
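
Not part of the patch above: a minimal, self-contained sketch of the request-ID convention the refactored code relies on, using only the Python standard library. The diff builds one base ID per prompt as "{prefix}_{training_step}_{batch_idx}", adds n sub-requests with a trailing "_{j}", and later regroups results by stripping that suffix. The helper names below are illustrative only and do not exist in the repository.

# Hypothetical helpers illustrating the ID round-trip used by add_request /
# _process_request in the diff above (names are not from the codebase).
def build_sub_request_ids(is_eval: bool, training_step: int, batch_idx: int, n: int) -> list:
    prefix = "eval" if is_eval else "train"
    base_id = f"{prefix}_{training_step}_{batch_idx}"
    # One engine request per sample, mirroring the n=1 sub-requests added above.
    return [f"{base_id}_{j}" for j in range(n)]

def base_request_id(sub_request_id: str) -> str:
    # Drop the trailing "_{j}" sub-request index, mirroring
    # "_".join(request_id.split("_")[:-1]) in the patch.
    return "_".join(sub_request_id.split("_")[:-1])

assert build_sub_request_ids(False, 7, 2, 3) == ["train_7_2_0", "train_7_2_1", "train_7_2_2"]
assert base_request_id("train_7_2_1") == "train_7_2"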