@@ -2683,7 +2683,7 @@ def _generate_recipe(
         return self.recipe_results
 
 
-    def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, input_others):
+    def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, ref_q_input_ids, input_others):
         from itertools import combinations
 
         # fetch mix-precision recipe configuration
@@ -2739,6 +2739,7 @@ def get_output(block, input_ids):
 
         reference_output = get_output(block, input_ids)
         q_input_ids = input_ids if q_input_ids is None else q_input_ids
+        ref_q_input_ids = input_ids if ref_q_input_ids is None else ref_q_input_ids
         # generate q_output of sample input_ids and get loss
         @torch.no_grad()
         def get_loss(q_block, q_input_ids):
@@ -2757,6 +2758,7 @@ def get_loss(q_block, q_input_ids):
         # get mxfp8 loss
         hp_layers = quantizable_layers
         block = create_mp_block(block, hp_layers, self.recipe_mp_dtype)
+        ref_q_output = get_output(block, ref_q_input_ids)
         mxfp8_loss = get_loss(block, q_input_ids)
         block = recover_mp_block(block, hp_layers, raw_dtype)
         hp_layers = []
@@ -2767,8 +2769,8 @@ def get_loss(q_block, q_input_ids):
         if is_hpex_available():
             htcore.mark_step()
         # aggressive mode
-        top_loss_ratio = 6
-        target_loss_ratio = 2
+        top_loss_ratio = 4
+        target_loss_ratio = 1.5
         logger.warning_once(f"[Recipe Mode] Aggressive mode, we set the top_loss_ratio: {top_loss_ratio}, target_loss_ratio: {target_loss_ratio}")
         if mxfp4_loss / mxfp8_loss > top_loss_ratio:
             logger.warning(f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} > {top_loss_ratio}")
@@ -2778,6 +2780,10 @@ def get_loss(q_block, q_input_ids):
             logger.warning(f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]: {target_loss_ratio} < {mxfp4_loss / mxfp8_loss} <= {top_loss_ratio}")
             quantizable_num = quantizable_num + 1
             logger.warning(f"[Recipe Mode] Set {quantizable_num} layers using mixed precision for this block.")
+        elif mxfp4_loss / mxfp8_loss <= 1.0:
+            logger.warning(f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} <= 1.0")
+            quantizable_num = 0
+            logger.warning(f"[Recipe Mode] Set {quantizable_num} layers using mixed precision for this block.")
         combination_list = []
         avg_bits_list = []
         loss_list = []
@@ -2822,7 +2828,7 @@ def get_loss(q_block, q_input_ids):
         if is_hpex_available():
             htcore.mark_step()
 
-        return reference_output, q_output
+        return reference_output, q_output, ref_q_output
 
 
 ###############################################################################################