Skip to content

Commit 674c706

Browse files
Adds a bunch of logs to explain what's happening during filtering. (#980)
* Added logs to explain what's happening during filtering.
* Added logging.
* Cleaned up PR.
* Added filtering code.
* Added metrics to wandb.
1 parent 01daf56 commit 674c706

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

open_instruct/grpo_fast.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1598,6 +1598,16 @@ def data_preparation_thread(
15981598
real_batch_size_ratio = non_zero_std_mask.sum() * args.num_samples_per_prompt_rollout / len(scores)
15991599
expanded_mask = np.repeat(non_zero_std_mask, args.num_samples_per_prompt_rollout)
16001600
non_zero_gradient_index = np.where(expanded_mask)[0]
1601+
1602+
# Log zero-gradient filtering statistics
1603+
num_zero_std_prompts = (~non_zero_std_mask).sum()
1604+
num_filtered_responses = len(scores) - len(non_zero_gradient_index)
1605+
if num_filtered_responses > 0:
1606+
logger.info(
1607+
f"[Zero-gradient filtering] Filtered {num_zero_std_prompts} prompts with zero std "
1608+
f"({num_filtered_responses} responses). Retention rate: {len(non_zero_gradient_index) / len(scores):.2%}"
1609+
)
1610+
16011611
advantages = advantages[non_zero_gradient_index]
16021612
original_batch_size = len(scores)
16031613
scores = scores[non_zero_gradient_index]
@@ -1607,6 +1617,12 @@ def data_preparation_thread(
16071617
finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index]
16081618
if args.mask_truncated_completions:
16091619
stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"])
1620+
num_truncated = len(finish_reasons) - len(stop_idxes)
1621+
if num_truncated > 0:
1622+
logger.info(
1623+
f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. "
1624+
f"Retention rate: {len(stop_idxes) / len(finish_reasons):.2%}"
1625+
)
16101626
scores = scores[stop_idxes]
16111627
advantages = advantages[stop_idxes]
16121628
responses = [responses[i] for i in stop_idxes]
@@ -1627,6 +1643,11 @@ def data_preparation_thread(
16271643
stds = scores_matrix.std(axis=1) + 1e-8
16281644
probs = stds / stds.sum()
16291645

1646+
logger.info(
1647+
f"[Refill completions] Need to fill {need_to_fill_prompt} prompts to maintain batch size. "
1648+
f"Original: {original_prompt_cnt}, Current: {current_prompt_cnt}"
1649+
)
1650+
16301651
sampled_prompt_ids = np.random.choice(
16311652
current_prompt_cnt, size=need_to_fill_prompt, replace=True, p=probs
16321653
)
@@ -1652,10 +1673,18 @@ def data_preparation_thread(
16521673

16531674
finish_reasons += [finish_reasons[i] for i in sampled_indices]
16541675

1655-
print(
1676+
logger.info(
16561677
f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses"
16571678
)
16581679

1680+
# Count groups with all zero rewards
1681+
all_zero_groups = (scores_per_prompt == 0).all(axis=-1).sum()
1682+
total_groups = len(scores_per_prompt)
1683+
logger.info(
1684+
f"[Reward Summary] Groups with all zero rewards: {all_zero_groups}/{total_groups} "
1685+
f"({all_zero_groups / total_groups:.1%})"
1686+
)
1687+
16591688
with Timer("📦 [Data Preparation Thread] Packing sequences"):
16601689
packed_sequences = pack_sequences(
16611690
queries=batch.queries,
@@ -1767,11 +1796,18 @@ def data_preparation_thread(
17671796
sequence_length_unsolved = (
17681797
np.array([]) if np.all(scores == max_possible_score) else np.array(sequence_lengths[scores == 0])
17691798
)
1799+
1800+
# Use the already calculated reward summary metrics for wandb
1801+
all_zero_groups_ratio = all_zero_groups / total_groups if total_groups > 0 else 0
1802+
17701803
metrics = {
17711804
"scores": np.array(scores).mean(),
17721805
"real_batch_size_ratio": real_batch_size_ratio,
17731806
"unsolved_batch_size_ratio": unsolved_batch_size_ratio,
17741807
"packed_ratio": len(packed_sequences.query_responses) / len(responses) if len(responses) > 0 else 0,
1808+
"val/all_zero_reward_groups": all_zero_groups,
1809+
"val/all_zero_reward_groups_ratio": all_zero_groups_ratio,
1810+
"val/total_reward_groups": total_groups,
17751811
"val/sequence_lengths": sequence_lengths.mean(),
17761812
"val/sequence_lengths_min": sequence_lengths.min(),
17771813
"val/sequence_lengths_max": sequence_lengths.max(),

0 commit comments

Comments
 (0)