Commit 956076b

Cleaned up PR significantly.
1 parent 28b8ba0 commit 956076b

File tree

1 file changed: +37 −47 lines


open_instruct/grpo_fast.py

Lines changed: 37 additions & 47 deletions
@@ -82,7 +82,7 @@
 from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler
 from transformers.integrations import HfDeepSpeedConfig
 
-from open_instruct import logger_utils, rl_utils2, vllm_utils3
+from open_instruct import logger_utils, vllm_utils3
 from open_instruct.actor_manager import ActorManager
 from open_instruct.dataset_transformation import (
     GROUND_TRUTHS_KEY,
@@ -111,6 +111,7 @@
     push_folder_to_hub,
 )
 from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
+from open_instruct.rl_utils2 import PackedSequences, Timer, pack_sequences
 from open_instruct.utils import (
     ArgumentParserPlus,
     BeakerRuntimeConfig,
@@ -149,7 +150,7 @@ class ShutdownSentinel:
 class PackedData:
     """Container for packed sequences and associated metadata."""
 
-    packed_sequences: rl_utils2.PackedSequences
+    packed_sequences: PackedSequences
     collated_data: list  # Collated training data for each device
     metrics: dict  # Training metrics
     responses_count: int  # Number of responses
@@ -931,7 +932,7 @@ def train(
 
         # Calculate the logprob of the reference policy
         collated_ref_logprobs = []
-        with rl_utils2.Timer("Inference Calculation", noop=self.rank != 0):
+        with Timer("Inference Calculation", noop=self.rank != 0):
             with torch.no_grad():
                 for i in range(len(collated_query_responses)):
                     query_response = collated_query_responses[i]
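
The hunks above and below replace every `rl_utils2.Timer(...)` call with the directly imported `Timer`. The pattern being timed is a context manager that can be turned into a no-op on non-zero ranks so only one process logs durations, and that exposes a `duration` attribute (used later via `timer.duration`). The real implementation lives in `open_instruct.rl_utils2`; the following is only a minimal sketch of that pattern, not the library's code.

```python
import time


class Timer:
    """Minimal sketch of a no-op-able timing context manager (interface assumed from the diff)."""

    def __init__(self, name: str, noop: bool = False):
        self.name = name
        self.noop = noop      # when True (e.g. rank != 0), skip logging but still record duration
        self.duration = 0.0

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.duration = time.perf_counter() - self._start
        if not self.noop:
            print(f"{self.name}: {self.duration:.3f}s")


# Usage mirroring the diff: only rank 0 would log the timing.
with Timer("Inference Calculation", noop=False) as timer:
    time.sleep(0.01)
print(timer.duration > 0)
```
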
@@ -961,7 +962,7 @@ def train(
         # from the generator (note that async mode means these are a bit diff!)
         old_logprobs = [None for _ in range(len(collated_query_responses))]
         if num_mini_batches > 1:
-            with rl_utils2.Timer("Old logprobs Calculation", noop=self.rank != 0):
+            with Timer("Old logprobs Calculation", noop=self.rank != 0):
                 with torch.no_grad():
                     for i in range(len(collated_query_responses)):
                         query_response = collated_query_responses[i]
@@ -988,7 +989,7 @@ def train(
 
         local_step = 0
         # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
-        with rl_utils2.Timer("[Training Processes] Loss calculation", noop=self.rank != 0):
+        with Timer("[Training Processes] Loss calculation", noop=self.rank != 0):
            kl1_stats = torch.zeros(len(collated_query_responses))
            kl2_stats = torch.zeros(len(collated_query_responses))
            kl3_stats = torch.zeros(len(collated_query_responses))
@@ -1411,7 +1412,6 @@ def accumulate_inference_batches(
     all_ground_truths = []
     all_datasets = []
     all_raw_queries = []
-    all_indices = []
     for i in tqdm(
         range(num_prompts),
         total=num_prompts,
@@ -1438,7 +1438,6 @@
         all_ground_truths.append(ground_truth)
         all_datasets.append(dataset)
         all_raw_queries.append(raw_query)
-        all_indices.append(result.dataset_index)
 
     # Combine all results into a single GenerationResult
     combined_responses = []
@@ -1495,13 +1494,13 @@
     if actor_manager is not None:
         ray.get(actor_manager.report_token_statistics.remote(accumulated_stats))
 
-    # Create batch with preserved dataset indices
+    # Note: We don't have dataset_indices here, but they're not needed for the returned batch
     batch = Batch(
         queries=all_queries,
         ground_truths=all_ground_truths,
         datasets=all_datasets,
         raw_queries=all_raw_queries,
-        indices=all_indices,  # Preserve the dataset indices for MFU/MBU calculations
+        indices=None,  # Not meaningful for combined results
     )
     return combined_result, batch
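
With dataset indices no longer accumulated, `indices` has to be optional on the returned container. A minimal sketch of the shape this implies, assuming `Batch` is a plain dataclass (the field names are taken from the diff; the element types are assumptions):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Batch:
    """Sketch of the container built above (types assumed, not the project's definition)."""

    queries: List[List[int]]             # tokenized prompts
    ground_truths: List[str]
    datasets: List[str]
    raw_queries: List[str]
    indices: Optional[List[int]] = None  # None after this change: combined results no longer carry dataset indices


# Downstream code should no longer rely on batch.indices; prompt lengths come from batch.queries directly.
batch = Batch(queries=[[1, 2, 3]], ground_truths=["42"], datasets=["example"], raw_queries=["What is 6 * 7?"])
assert batch.indices is None
```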

@@ -1520,7 +1519,7 @@ def data_preparation_thread(
 ):
     for training_step in range(resume_training_step, num_training_steps + 1):
         # Streaming accumulation: collect results as they arrive
-        with rl_utils2.Timer("🚀 [Data Preparation Thread] Getting response ids") as timer:
+        with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer:
            result, batch = accumulate_inference_batches(
                inference_results_Q,
                pending_queries_map,
@@ -1562,14 +1561,14 @@
        ):
            result.responses[i].append(tokenizer.eos_token_id)
            result.masks[i].append(1)  # never mask the eos token for
-        with rl_utils2.Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
+        with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
            decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
            decoded_queries = batch.raw_queries
        stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
            result.finish_reasons
        )
 
-        with rl_utils2.Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
+        with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"):
            scores, reward_metrics = asyncio.run(
                reward_fn(
                    result.responses,
@@ -1593,7 +1592,7 @@
        else:
            raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}")
 
-        with rl_utils2.Timer("📦 [Data Preparation Thread] Filtering sequences"):
+        with Timer("📦 [Data Preparation Thread] Filtering sequences"):
            # Here we get the max possible score for each prompt, and see how many prompts are unsolved
            max_possible_score = 0
            if args.apply_verifiable_reward:
@@ -1640,7 +1639,7 @@
            finish_reasons = [finish_reasons[i] for i in stop_idxes]
 
        if args.fill_completions:
-            with rl_utils2.Timer("⏱ [Data Preparation Thread] Refill completions"):
+            with Timer("⏱ [Data Preparation Thread] Refill completions"):
                current_batch_size = len(scores)
                original_prompt_cnt = original_batch_size // args.num_samples_per_prompt_rollout
                current_prompt_cnt = current_batch_size // args.num_samples_per_prompt_rollout
@@ -1694,8 +1693,8 @@
                f"({all_zero_groups / total_groups:.1%})"
            )
 
-        with rl_utils2.Timer("📦 [Data Preparation Thread] Packing sequences"):
-            packed_sequences = rl_utils2.pack_sequences(
+        with Timer("📦 [Data Preparation Thread] Packing sequences"):
+            packed_sequences = pack_sequences(
                queries=batch.queries,
                responses=responses,
                masks=masks,
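
Packing concatenates variable-length query/response pairs into dense buffers so training steps do not burn FLOPs on padding. The real `pack_sequences` (now imported from `open_instruct.rl_utils2`) also tracks masks, response masks, and advantages per packed buffer; the sketch below only illustrates the core idea with a greedy first-fit bin packer and is not the project's implementation.

```python
from typing import List


def pack_sequences_sketch(
    queries: List[List[int]], responses: List[List[int]], max_tokens: int = 512
) -> List[List[int]]:
    """Greedy first-fit sketch: concatenate each query+response into the first bin with room."""
    bins: List[List[int]] = []
    for query, response in zip(queries, responses):
        seq = query + response
        for packed in bins:
            if len(packed) + len(seq) <= max_tokens:
                packed.extend(seq)
                break
        else:
            bins.append(list(seq))
    return bins


# Three short sequences end up in one packed buffer instead of three padded rows.
packed = pack_sequences_sketch([[1, 2], [3], [4, 5, 6]], [[7], [8, 9], [10]], max_tokens=16)
assert len(packed) == 1
```
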
@@ -1716,7 +1715,7 @@
         # if we have less batches than world size, we need to pad out so each world is fine
         # ideally, you should avoid this since its wasting computation.
         if args.allow_world_padding:
-            with rl_utils2.Timer("🤺 [Data Preparation Thread] Padding sequences for world size"):
+            with Timer("🤺 [Data Preparation Thread] Padding sequences for world size"):
                shortfall = args.world_size - len(packed_sequences.query_responses)
                if shortfall > 0:
                    logger.warning(
@@ -1738,7 +1737,7 @@
                    packed_sequences.response_masks.append(dummy_response_mask)
                    packed_sequences.advantages.append(dummy_advantage)
 
-        with rl_utils2.Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
+        with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
            B = (
                len(packed_sequences.query_responses) // args.world_size
            )  # essentially doing `drop_last=True`, which is fine.
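
The collation step splits the packed buffers evenly across `args.world_size` training workers and drops any remainder (the `drop_last=True` comment above). A tiny sketch of that split, under the assumption that each worker takes a contiguous slice:

```python
from typing import List, TypeVar

T = TypeVar("T")


def split_for_workers(items: List[T], world_size: int) -> List[List[T]]:
    """Give each worker B = len(items) // world_size items; anything left over is dropped."""
    per_worker = len(items) // world_size  # the `B` computed in the diff
    return [items[rank * per_worker : (rank + 1) * per_worker] for rank in range(world_size)]


# 10 packed buffers across 4 workers -> 2 each, 2 dropped.
shards = split_for_workers(list(range(10)), world_size=4)
assert [len(shard) for shard in shards] == [2, 2, 2, 2]
```
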
@@ -1861,15 +1860,9 @@ def data_preparation_thread(
            logger.warning(f"No responses in batch {training_step}.")
 
        # Put the packed sequences and metrics into the output queue
-        # For MFU/MBU calculations, we need unique prompt lengths, not repeated ones
-        # Use indices to identify unique prompts
-        seen_indices = set()
-        unique_queries = []
-        for idx, query in zip(batch.indices, batch.queries):
-            if idx not in seen_indices:
-                seen_indices.add(idx)
-                unique_queries.append(query)
-        unique_prompt_lengths = [len(q) for q in unique_queries]
+        # For training MFU, we need all prompt lengths (including repeated ones)
+        # since we're calculating total tokens processed during training
+        prompt_lengths = [len(q) for q in batch.queries]
 
        packed_sequences_Q.put(
            PackedData(
@@ -1879,7 +1872,7 @@ def data_preparation_thread(
                responses_count=len(responses),
                num_new_tokens=num_new_tokens,
                batch_size=B,
-                prompt_lengths=unique_prompt_lengths,
+                prompt_lengths=prompt_lengths,
                response_lengths=[len(r) for r in responses],
            )
        )
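
The behavioral core of this commit: `PackedData.prompt_lengths` now carries one entry per sampled rollout rather than one per unique prompt, because training processes every repeated prompt. A small sketch of the difference, using two prompts sampled twice each (the old deduplication is reproduced from the removed lines above):

```python
# Two unique prompts, each sampled twice (num_samples_per_prompt_rollout = 2).
queries = [[1, 2, 3], [1, 2, 3], [4, 5], [4, 5]]
indices = [0, 0, 1, 1]

# Old behavior (removed above): deduplicate by dataset index -> 2 entries.
seen, unique_prompt_lengths = set(), []
for idx, query in zip(indices, queries):
    if idx not in seen:
        seen.add(idx)
        unique_prompt_lengths.append(len(query))

# New behavior: every repeated prompt counts -> 4 entries, parallel to the responses,
# so prompt and response lengths can simply be zipped later.
prompt_lengths = [len(query) for query in queries]

assert unique_prompt_lengths == [3, 2]
assert prompt_lengths == [3, 3, 2, 2]
```
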
@@ -2143,7 +2136,7 @@ def load_data_from_packing_thread(
     Returns:
         Tuple of (collated_data, data_thread_metrics, num_total_tokens, packed_data)
     """
-    with rl_utils2.Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer:
+    with Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer:
        while True:
            if stop_event.is_set():
                logger.warning("[Main Thread] Stop event detected while waiting for packed sequences")
@@ -2189,7 +2182,7 @@ def weight_sync_thread(
        # Clear the event for next iteration
        weight_sync_trigger_event.clear()
 
-        with rl_utils2.Timer("[Weight Sync]") as timer:
+        with Timer("[Weight Sync]") as timer:
            logger.debug("[Weight Sync Thread] Starting weight sync")
 
            # Set actors to stop
@@ -2223,7 +2216,7 @@ def generate_thread(args, vllm_engines, resume_training_step, stop_event, genera
     """Thread function that repeatedly calls process_from_queue on vllm engines."""
     logger.info("[Generate Thread] 🚀 Starting generation thread")
     while not stop_event.is_set():
-        with rl_utils2.Timer("🔥 Generation time") as timer:
+        with Timer("🔥 Generation time") as timer:
            processed_results = ray_get_with_progress(
                [engine.process_from_queue.remote(timeout=20) for engine in vllm_engines],
                desc="[Generate Thread] Waiting for vLLM engines to process",
@@ -2267,14 +2260,10 @@ def calculate_utilization_metrics(
 
     # For training, we need to calculate total sequence lengths (prompt + response)
     # This represents the full sequence that the model is trained on
+    # Since we now have all prompts (including repeated ones), we can directly zip
     total_sequence_lengths = []
-    response_idx = 0
-    for prompt_len in packed_data.prompt_lengths:
-        # For each unique prompt, get all its response lengths
-        for _ in range(args.num_samples_per_prompt_rollout):
-            response_len = packed_data.response_lengths[response_idx]
-            total_sequence_lengths.append(prompt_len + response_len)
-            response_idx += 1
+    for prompt_len, response_len in zip(packed_data.prompt_lengths, packed_data.response_lengths):
+        total_sequence_lengths.append(prompt_len + response_len)
 
     # Create a new ModelDims instance with is_training=True
     training_model_dims = utils.ModelDims(
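
Because `prompt_lengths` and `response_lengths` are now parallel per-rollout lists, the pairing above becomes a plain elementwise sum and the old `response_idx` bookkeeping disappears. A short sketch of the new pairing, plus a length check the PR itself does not add (that check is only a suggestion):

```python
prompt_lengths = [3, 3, 2, 2]
response_lengths = [5, 4, 6, 7]

# Equivalent to the new loop in the diff.
total_sequence_lengths = [p + r for p, r in zip(prompt_lengths, response_lengths)]
assert total_sequence_lengths == [8, 7, 8, 9]

# zip() silently truncates on mismatched lengths; an explicit check (or strict=True
# on Python 3.10+) would make the "parallel lists" assumption fail loudly.
assert len(prompt_lengths) == len(response_lengths)
```
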
@@ -2299,7 +2288,8 @@
     flops_per_second = total_flops / train_duration
     mfu = 100 * flops_per_second / device_flops
 
-    # MBU is not reported during training as requested
+    # We currently only report a single metric. This will expand to include actor MFU/MBU.
+    # We don't include MBU as it's currently broken for training.
     return {"mfu": mfu}
 
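For reference, the MFU reported above is achieved FLOPs per second expressed as a percentage of the device peak. A worked example with illustrative numbers (none of these values come from the PR; the peak is roughly an A100's BF16 tensor-core rating):

```python
total_flops = 150e12     # FLOPs spent on the training step (illustrative)
train_duration = 2.0     # seconds
device_flops = 312e12    # peak FLOP/s, ~A100 BF16

flops_per_second = total_flops / train_duration   # 75e12 achieved FLOP/s
mfu = 100 * flops_per_second / device_flops       # ~24% model FLOPs utilization
print(f"MFU: {mfu:.1f}%")
```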

@@ -2324,7 +2314,7 @@
 ) -> None:
     """Train the model for one step."""
     update_ref_policy_future = []
-    with rl_utils2.Timer("[Main Thread] 🗡️ Training") as train_timer:
+    with Timer("[Main Thread] 🗡️ Training") as train_timer:
        metrics_list: List[dict[str, float]] = ray_get_with_progress(
            [
                policy_group.models[i].train.remote(
@@ -2345,7 +2335,7 @@
 
     save_time = 0
     if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1):
-        with rl_utils2.Timer("[Main Thread] 🗡️ Saving model") as timer:
+        with Timer("[Main Thread] 🗡️ Saving model") as timer:
            checkpoint_dir = f"{args.output_dir}_checkpoints"
            step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
            logger.info(f"Saving model at step {training_step} to {step_dir}")
@@ -2365,7 +2355,7 @@
            save_time += timer.duration
 
     if len(update_ref_policy_future) > 0:
-        with rl_utils2.Timer("[Main Thread] 🔃 Updating reference policy"):
+        with Timer("[Main Thread] 🔃 Updating reference policy"):
            ray_get_with_progress(update_ref_policy_future, desc="Updating reference policy")
 
     ray.get(actor_manager.report_training_step_time.remote(train_timer.duration))
@@ -2499,7 +2489,7 @@ def save_final_model(
 ):
     """Save the final model and launch evaluation jobs if configured."""
     logger.info(f"Saving final model at step {training_step} to {args.output_dir}")
-    with rl_utils2.Timer("[Main Thread] 🗡️ Saving model"):
+    with Timer("[Main Thread] 🗡️ Saving model"):
        ray_get_with_progress(
            [
                policy_group.models[i].save_model.remote(args.output_dir, chat_template_name, tokenizer)
@@ -2558,7 +2548,7 @@ async def reward_fn(
     metrics = {}
 
     if args.apply_r1_style_format_reward:
-        with rl_utils2.Timer("[Data Preparation Thread] Calculating rewards -- 🧮 Calculating format reward"):
+        with Timer("[Data Preparation Thread] Calculating rewards -- 🧮 Calculating format reward"):
            format_scores = soft_format_reward_func(decoded_responses, args.r1_style_format_reward)
            if len(format_scores) != len(scores):
                raise ValueError(f"{len(format_scores)=} != {len(scores)=}")
@@ -2567,7 +2557,7 @@
            metrics["val/format_scores"] = np.array(format_scores).mean()
 
     if args.apply_verifiable_reward:
-        with rl_utils2.Timer("[Data Preparation Thread] Calculating rewards -- 🏆 Applying verifiable reward"):
+        with Timer("[Data Preparation Thread] Calculating rewards -- 🏆 Applying verifiable reward"):
            verifiable_rewards, per_func_rewards = await apply_verifiable_reward(
                reward_fn_mapping,
                responses,
@@ -2603,7 +2593,7 @@
 
     # this gets applied at the very end since it replaces (rather than adds to) the existing reward.
     if args.non_stop_penalty:
-        with rl_utils2.Timer("[Data Preparation Thread] Calculating rewards -- 🦖 Applying non stop penalty"):
+        with Timer("[Data Preparation Thread] Calculating rewards -- 🦖 Applying non stop penalty"):
            assert len(finish_reasons) == len(scores)
            for i in range(len(finish_reasons)):
                if finish_reasons[i] != "stop":
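
The context lines show the non-stop penalty pattern: any response whose finish reason is not "stop" has its score replaced rather than adjusted. A minimal sketch of that rule; the penalty value is a placeholder, since the hunk ends before the actual assignment:

```python
from typing import List


def apply_non_stop_penalty(scores: List[float], finish_reasons: List[str], penalty: float = 0.0) -> List[float]:
    """Replace (not add to) the score of any response that was cut off before emitting a stop token."""
    assert len(finish_reasons) == len(scores)
    return [score if reason == "stop" else penalty for score, reason in zip(scores, finish_reasons)]


# A truncated ("length") response loses its reward entirely under the assumed penalty of 0.0.
assert apply_non_stop_penalty([1.0, 1.0], ["stop", "length"]) == [1.0, 0.0]
```
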
@@ -2845,7 +2835,7 @@ def health_check_fn():
        and training_step % args.checkpoint_state_freq == 0
        and args.checkpoint_state_dir is not None
     ):
-        with rl_utils2.Timer("[Main Thread] 🗡️ Saving checkpoint state"):
+        with Timer("[Main Thread] 🗡️ Saving checkpoint state"):
            # Save comprehensive client state including ShufflingIterator state
            client_state = {
                "training_step": training_step,
