@@ -1598,6 +1598,16 @@ def data_preparation_thread(
         real_batch_size_ratio = non_zero_std_mask.sum() * args.num_samples_per_prompt_rollout / len(scores)
         expanded_mask = np.repeat(non_zero_std_mask, args.num_samples_per_prompt_rollout)
         non_zero_gradient_index = np.where(expanded_mask)[0]
+
+        # Log zero-gradient filtering statistics
+        num_zero_std_prompts = (~non_zero_std_mask).sum()
+        num_filtered_responses = len(scores) - len(non_zero_gradient_index)
+        if num_filtered_responses > 0:
+            logger.info(
+                f"[Zero-gradient filtering] Filtered {num_zero_std_prompts} prompts with zero std "
+                f"({num_filtered_responses} responses). Retention rate: {len(non_zero_gradient_index) / len(scores):.2%}"
+            )
+
         advantages = advantages[non_zero_gradient_index]
         original_batch_size = len(scores)
         scores = scores[non_zero_gradient_index]
@@ -1607,6 +1617,12 @@ def data_preparation_thread(
         finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index]
         if args.mask_truncated_completions:
             stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"])
+            num_truncated = len(finish_reasons) - len(stop_idxes)
+            if num_truncated > 0:
+                logger.info(
+                    f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. "
+                    f"Retention rate: {len(stop_idxes) / len(finish_reasons):.2%}"
+                )
             scores = scores[stop_idxes]
             advantages = advantages[stop_idxes]
             responses = [responses[i] for i in stop_idxes]
@@ -1627,6 +1643,11 @@ def data_preparation_thread(
             stds = scores_matrix.std(axis=1) + 1e-8
             probs = stds / stds.sum()
 
+            logger.info(
+                f"[Refill completions] Need to fill {need_to_fill_prompt} prompts to maintain batch size. "
+                f"Original: {original_prompt_cnt}, Current: {current_prompt_cnt}"
+            )
+
             sampled_prompt_ids = np.random.choice(
                 current_prompt_cnt, size=need_to_fill_prompt, replace=True, p=probs
             )
@@ -1652,10 +1673,18 @@ def data_preparation_thread(
 
             finish_reasons += [finish_reasons[i] for i in sampled_indices]
 
-            print(
+            logger.info(
                 f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses"
             )
 
+        # Count groups with all zero rewards
+        all_zero_groups = (scores_per_prompt == 0).all(axis=-1).sum()
+        total_groups = len(scores_per_prompt)
+        logger.info(
+            f"[Reward Summary] Groups with all zero rewards: {all_zero_groups}/{total_groups} "
+            f"({all_zero_groups / total_groups:.1%})"
+        )
+
         with Timer("📦 [Data Preparation Thread] Packing sequences"):
             packed_sequences = pack_sequences(
                 queries=batch.queries,
@@ -1767,11 +1796,18 @@ def data_preparation_thread(
         sequence_length_unsolved = (
             np.array([]) if np.all(scores == max_possible_score) else np.array(sequence_lengths[scores == 0])
         )
+
+        # Use the already calculated reward summary metrics for wandb
+        all_zero_groups_ratio = all_zero_groups / total_groups if total_groups > 0 else 0
+
         metrics = {
             "scores": np.array(scores).mean(),
             "real_batch_size_ratio": real_batch_size_ratio,
             "unsolved_batch_size_ratio": unsolved_batch_size_ratio,
             "packed_ratio": len(packed_sequences.query_responses) / len(responses) if len(responses) > 0 else 0,
+            "val/all_zero_reward_groups": all_zero_groups,
+            "val/all_zero_reward_groups_ratio": all_zero_groups_ratio,
+            "val/total_reward_groups": total_groups,
             "val/sequence_lengths": sequence_lengths.mean(),
             "val/sequence_lengths_min": sequence_lengths.min(),
             "val/sequence_lengths_max": sequence_lengths.max(),
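
For readers skimming the diff, a minimal standalone sketch of the zero-std filtering that the new "[Zero-gradient filtering]" log line reports on; the scores array and the num_samples_per_prompt_rollout value below are illustrative assumptions, not values from the actual training config.

import numpy as np

# Illustrative setup: 3 prompts x 4 samples per prompt (assumed rollout size).
num_samples_per_prompt_rollout = 4
scores = np.array([1.0, 0.0, 1.0, 0.0,   # prompt 0: mixed rewards -> kept
                   0.0, 0.0, 0.0, 0.0,   # prompt 1: zero std -> filtered
                   1.0, 1.0, 0.0, 1.0])  # prompt 2: mixed rewards -> kept
scores_per_prompt = scores.reshape(-1, num_samples_per_prompt_rollout)

# Prompts whose reward std is zero contribute no gradient signal.
non_zero_std_mask = scores_per_prompt.std(axis=1) != 0
expanded_mask = np.repeat(non_zero_std_mask, num_samples_per_prompt_rollout)
non_zero_gradient_index = np.where(expanded_mask)[0]

num_zero_std_prompts = (~non_zero_std_mask).sum()                     # 1
num_filtered_responses = len(scores) - len(non_zero_gradient_index)   # 4
retention_rate = len(non_zero_gradient_index) / len(scores)           # 8/12, i.e. 66.67%
print(f"Filtered {num_zero_std_prompts} prompts ({num_filtered_responses} responses), "
      f"retention rate {retention_rate:.2%}")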