 from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler
 from transformers.integrations import HfDeepSpeedConfig
 
-from open_instruct import logger_utils, rl_utils2, vllm_utils3
+from open_instruct import logger_utils, vllm_utils3
 from open_instruct.actor_manager import ActorManager
 from open_instruct.dataset_transformation import (
     GROUND_TRUTHS_KEY,
@@ -111,6 +111,7 @@
     push_folder_to_hub,
 )
 from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
+from open_instruct.rl_utils2 import PackedSequences, Timer, pack_sequences
 from open_instruct.utils import (
     ArgumentParserPlus,
     BeakerRuntimeConfig,
@@ -149,7 +150,7 @@ class ShutdownSentinel:
 class PackedData:
     """Container for packed sequences and associated metadata."""
 
-    packed_sequences: rl_utils2.PackedSequences
+    packed_sequences: PackedSequences
     collated_data: list  # Collated training data for each device
     metrics: dict  # Training metrics
     responses_count: int  # Number of responses
@@ -931,7 +932,7 @@ def train(
 
         # Calculate the logprob of the reference policy
         collated_ref_logprobs = []
-        with rl_utils2.Timer("Inference Calculation", noop=self.rank != 0):
+        with Timer("Inference Calculation", noop=self.rank != 0):
             with torch.no_grad():
                 for i in range(len(collated_query_responses)):
                     query_response = collated_query_responses[i]
@@ -961,7 +962,7 @@ def train(
         # from the generator (note that async mode means these are a bit diff!)
         old_logprobs = [None for _ in range(len(collated_query_responses))]
         if num_mini_batches > 1:
-            with rl_utils2.Timer("Old logprobs Calculation", noop=self.rank != 0):
+            with Timer("Old logprobs Calculation", noop=self.rank != 0):
                 with torch.no_grad():
                     for i in range(len(collated_query_responses)):
                         query_response = collated_query_responses[i]
@@ -988,7 +989,7 @@ def train(
 
         local_step = 0
         # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
-        with rl_utils2.Timer("[Training Processes] Loss calculation", noop=self.rank != 0):
+        with Timer("[Training Processes] Loss calculation", noop=self.rank != 0):
             kl1_stats = torch.zeros(len(collated_query_responses))
             kl2_stats = torch.zeros(len(collated_query_responses))
             kl3_stats = torch.zeros(len(collated_query_responses))
@@ -1411,7 +1412,6 @@ def accumulate_inference_batches(
     all_ground_truths = []
     all_datasets = []
     all_raw_queries = []
-    all_indices = []
     for i in tqdm(
         range(num_prompts),
         total=num_prompts,
@@ -1438,7 +1438,6 @@ def accumulate_inference_batches(
         all_ground_truths.append(ground_truth)
         all_datasets.append(dataset)
         all_raw_queries.append(raw_query)
-        all_indices.append(result.dataset_index)
 
     # Combine all results into a single GenerationResult
     combined_responses = []
@@ -1495,13 +1494,13 @@ def accumulate_inference_batches(
     if actor_manager is not None:
         ray.get(actor_manager.report_token_statistics.remote(accumulated_stats))
 
-    # Create batch with preserved dataset indices
+    # Note: We don't have dataset_indices here, but they're not needed for the returned batch
     batch = Batch(
         queries=all_queries,
         ground_truths=all_ground_truths,
         datasets=all_datasets,
         raw_queries=all_raw_queries,
-        indices=all_indices,  # Preserve the dataset indices for MFU/MBU calculations
+        indices=None,  # Not meaningful for combined results
     )
     return combined_result, batch
 
@@ -1520,7 +1519,7 @@ def data_preparation_thread(
 ):
     for training_step in range(resume_training_step, num_training_steps + 1):
         # Streaming accumulation: collect results as they arrive
-        with rl_utils2.Timer("🚀 [Data Preparation Thread] Getting response ids") as timer:
+        with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer:
             result, batch = accumulate_inference_batches(
                 inference_results_Q,
                 pending_queries_map,
@@ -1562,14 +1561,14 @@ def data_preparation_thread(
             ):
                 result.responses[i].append(tokenizer.eos_token_id)
                 result.masks[i].append(1)  # never mask the eos token for
-        with rl_utils2.Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
+        with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
             decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
             decoded_queries = batch.raw_queries
             stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
                 result.finish_reasons
             )
 
-        with rl_utils2.Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
+        with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
             scores, reward_metrics = asyncio.run(
                 reward_fn(
                     result.responses,
@@ -1593,7 +1592,7 @@ def data_preparation_thread(
             else:
                 raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}")
 
-        with rl_utils2.Timer("📦 [Data Preparation Thread] Filtering sequences"):
+        with Timer("📦 [Data Preparation Thread] Filtering sequences"):
             # Here we get the max possible score for each prompt, and see how many prompts are unsolved
             max_possible_score = 0
             if args.apply_verifiable_reward:
@@ -1640,7 +1639,7 @@ def data_preparation_thread(
             finish_reasons = [finish_reasons[i] for i in stop_idxes]
 
         if args.fill_completions:
-            with rl_utils2.Timer("⏱ [Data Preparation Thread] Refill completions"):
+            with Timer("⏱ [Data Preparation Thread] Refill completions"):
                 current_batch_size = len(scores)
                 original_prompt_cnt = original_batch_size // args.num_samples_per_prompt_rollout
                 current_prompt_cnt = current_batch_size // args.num_samples_per_prompt_rollout
@@ -1694,8 +1693,8 @@ def data_preparation_thread(
                 f"({all_zero_groups / total_groups:.1%})"
             )
 
-        with rl_utils2.Timer("📦 [Data Preparation Thread] Packing sequences"):
-            packed_sequences = rl_utils2.pack_sequences(
+        with Timer("📦 [Data Preparation Thread] Packing sequences"):
+            packed_sequences = pack_sequences(
                 queries=batch.queries,
                 responses=responses,
                 masks=masks,
@@ -1716,7 +1715,7 @@ def data_preparation_thread(
         # if we have less batches than world size, we need to pad out so each world is fine
         # ideally, you should avoid this since its wasting computation.
         if args.allow_world_padding:
-            with rl_utils2.Timer("🤺 [Data Preparation Thread] Padding sequences for world size"):
+            with Timer("🤺 [Data Preparation Thread] Padding sequences for world size"):
                 shortfall = args.world_size - len(packed_sequences.query_responses)
                 if shortfall > 0:
                     logger.warning(
@@ -1738,7 +1737,7 @@ def data_preparation_thread(
                         packed_sequences.response_masks.append(dummy_response_mask)
                         packed_sequences.advantages.append(dummy_advantage)
 
-        with rl_utils2.Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
+        with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
             B = (
                 len(packed_sequences.query_responses) // args.world_size
             )  # essentially doing `drop_last=True`, which is fine.
@@ -1861,15 +1860,9 @@ def data_preparation_thread(
             logger.warning(f"No responses in batch {training_step}.")
 
         # Put the packed sequences and metrics into the output queue
-        # For MFU/MBU calculations, we need unique prompt lengths, not repeated ones
-        # Use indices to identify unique prompts
-        seen_indices = set()
-        unique_queries = []
-        for idx, query in zip(batch.indices, batch.queries):
-            if idx not in seen_indices:
-                seen_indices.add(idx)
-                unique_queries.append(query)
-        unique_prompt_lengths = [len(q) for q in unique_queries]
+        # For training MFU, we need all prompt lengths (including repeated ones)
+        # since we're calculating total tokens processed during training
+        prompt_lengths = [len(q) for q in batch.queries]
 
         packed_sequences_Q.put(
             PackedData(
@@ -1879,7 +1872,7 @@ def data_preparation_thread(
                 responses_count=len(responses),
                 num_new_tokens=num_new_tokens,
                 batch_size=B,
-                prompt_lengths=unique_prompt_lengths,
+                prompt_lengths=prompt_lengths,
                 response_lengths=[len(r) for r in responses],
             )
         )
@@ -2143,7 +2136,7 @@ def load_data_from_packing_thread(
     Returns:
         Tuple of (collated_data, data_thread_metrics, num_total_tokens, packed_data)
     """
-    with rl_utils2.Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer:
+    with Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer:
         while True:
             if stop_event.is_set():
                 logger.warning("[Main Thread] Stop event detected while waiting for packed sequences")
@@ -2189,7 +2182,7 @@ def weight_sync_thread(
         # Clear the event for next iteration
         weight_sync_trigger_event.clear()
 
-        with rl_utils2.Timer("[Weight Sync]") as timer:
+        with Timer("[Weight Sync]") as timer:
             logger.debug("[Weight Sync Thread] Starting weight sync")
 
             # Set actors to stop
@@ -2223,7 +2216,7 @@ def generate_thread(args, vllm_engines, resume_training_step, stop_event, genera
     """Thread function that repeatedly calls process_from_queue on vllm engines."""
     logger.info("[Generate Thread] 🚀 Starting generation thread")
     while not stop_event.is_set():
-        with rl_utils2.Timer("🔥 Generation time") as timer:
+        with Timer("🔥 Generation time") as timer:
             processed_results = ray_get_with_progress(
                 [engine.process_from_queue.remote(timeout=20) for engine in vllm_engines],
                 desc="[Generate Thread] Waiting for vLLM engines to process",
@@ -2267,14 +2260,10 @@ def calculate_utilization_metrics(
 
     # For training, we need to calculate total sequence lengths (prompt + response)
     # This represents the full sequence that the model is trained on
+    # Since we now have all prompts (including repeated ones), we can directly zip
     total_sequence_lengths = []
-    response_idx = 0
-    for prompt_len in packed_data.prompt_lengths:
-        # For each unique prompt, get all its response lengths
-        for _ in range(args.num_samples_per_prompt_rollout):
-            response_len = packed_data.response_lengths[response_idx]
-            total_sequence_lengths.append(prompt_len + response_len)
-            response_idx += 1
+    for prompt_len, response_len in zip(packed_data.prompt_lengths, packed_data.response_lengths):
+        total_sequence_lengths.append(prompt_len + response_len)
 
     # Create a new ModelDims instance with is_training=True
     training_model_dims = utils.ModelDims(
@@ -2299,7 +2288,8 @@ def calculate_utilization_metrics(
     flops_per_second = total_flops / train_duration
     mfu = 100 * flops_per_second / device_flops
 
-    # MBU is not reported during training as requested
+    # We currently only report a single metric. This will expand to include actor MFU/MBU.
+    # We don't include MBU as it's currently broken for training.
    return {"mfu": mfu}
 
 
@@ -2324,7 +2314,7 @@ def one_training_step(
 ) -> None:
     """Train the model for one step."""
     update_ref_policy_future = []
-    with rl_utils2.Timer("[Main Thread] 🗡️ Training") as train_timer:
+    with Timer("[Main Thread] 🗡️ Training") as train_timer:
         metrics_list: List[dict[str, float]] = ray_get_with_progress(
             [
                 policy_group.models[i].train.remote(
@@ -2345,7 +2335,7 @@ def one_training_step(
 
     save_time = 0
     if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1):
-        with rl_utils2.Timer("[Main Thread] 🗡️ Saving model") as timer:
+        with Timer("[Main Thread] 🗡️ Saving model") as timer:
             checkpoint_dir = f"{args.output_dir}_checkpoints"
             step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
             logger.info(f"Saving model at step {training_step} to {step_dir}")
@@ -2365,7 +2355,7 @@ def one_training_step(
             save_time += timer.duration
 
     if len(update_ref_policy_future) > 0:
-        with rl_utils2.Timer("[Main Thread] 🔃 Updating reference policy"):
+        with Timer("[Main Thread] 🔃 Updating reference policy"):
             ray_get_with_progress(update_ref_policy_future, desc="Updating reference policy")
 
     ray.get(actor_manager.report_training_step_time.remote(train_timer.duration))
@@ -2499,7 +2489,7 @@ def save_final_model(
 ):
     """Save the final model and launch evaluation jobs if configured."""
     logger.info(f"Saving final model at step {training_step} to {args.output_dir}")
-    with rl_utils2.Timer("[Main Thread] 🗡️ Saving model"):
+    with Timer("[Main Thread] 🗡️ Saving model"):
         ray_get_with_progress(
             [
                 policy_group.models[i].save_model.remote(args.output_dir, chat_template_name, tokenizer)
@@ -2558,7 +2548,7 @@ async def reward_fn(
     metrics = {}
 
     if args.apply_r1_style_format_reward:
-        with rl_utils2.Timer("[Data Preparation Thread] Calculating rewards -- 🧮 Calculating format reward"):
+        with Timer("[Data Preparation Thread] Calculating rewards -- 🧮 Calculating format reward"):
             format_scores = soft_format_reward_func(decoded_responses, args.r1_style_format_reward)
             if len(format_scores) != len(scores):
                 raise ValueError(f"{len(format_scores)=} != {len(scores)=}")
@@ -2567,7 +2557,7 @@ async def reward_fn(
             metrics["val/format_scores"] = np.array(format_scores).mean()
 
     if args.apply_verifiable_reward:
-        with rl_utils2.Timer("[Data Preparation Thread] Calculating rewards -- 🏆 Applying verifiable reward"):
+        with Timer("[Data Preparation Thread] Calculating rewards -- 🏆 Applying verifiable reward"):
             verifiable_rewards, per_func_rewards = await apply_verifiable_reward(
                 reward_fn_mapping,
                 responses,
@@ -2603,7 +2593,7 @@ async def reward_fn(
 
     # this gets applied at the very end since it replaces (rather than adds to) the existing reward.
     if args.non_stop_penalty:
-        with rl_utils2.Timer("[Data Preparation Thread] Calculating rewards -- 🦖 Applying non stop penalty"):
+        with Timer("[Data Preparation Thread] Calculating rewards -- 🦖 Applying non stop penalty"):
             assert len(finish_reasons) == len(scores)
             for i in range(len(finish_reasons)):
                 if finish_reasons[i] != "stop":
@@ -2845,7 +2835,7 @@ def health_check_fn():
             and training_step % args.checkpoint_state_freq == 0
             and args.checkpoint_state_dir is not None
         ):
-            with rl_utils2.Timer("[Main Thread] 🗡️ Saving checkpoint state"):
+            with Timer("[Main Thread] 🗡️ Saving checkpoint state"):
                 # Save comprehensive client state including ShufflingIterator state
                 client_state = {
                     "training_step": training_step,
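The prompt-length change above can be illustrated with a small, self-contained sketch. This is only an illustration under the assumption that `batch.queries` already contains each prompt repeated once per sampled response, so `prompt_lengths` and `response_lengths` line up one-to-one; the `total_sequence_lengths` helper and the example data below are hypothetical and not part of the codebase.

```python
# Hypothetical sketch: with one prompt length per response, pairing prompt and
# response lengths is a direct zip; no index bookkeeping per unique prompt
# (the removed response_idx loop) is needed.
def total_sequence_lengths(prompt_lengths: list, response_lengths: list) -> list:
    assert len(prompt_lengths) == len(response_lengths), "one prompt length per response"
    return [p + r for p, r in zip(prompt_lengths, response_lengths)]

# Example with 2 prompts and 2 samples per prompt (prompts already repeated per sample).
queries = [[1, 2, 3], [1, 2, 3], [7, 8], [7, 8]]   # token ids, one entry per response
responses = [[9], [9, 9], [9], [9, 9, 9]]          # one sampled response each
prompt_lengths = [len(q) for q in queries]         # [3, 3, 2, 2]
response_lengths = [len(r) for r in responses]     # [1, 2, 1, 3]
print(total_sequence_lengths(prompt_lengths, response_lengths))  # [4, 5, 3, 5]
```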