@@ -2766,16 +2766,17 @@ def get_loss(q_block, q_input_ids):
27662766 logger .info (f"loss_ratio [mxfp4_loss / mxfp8_loss]: { mxfp4_loss / mxfp8_loss } " )
27672767 if is_hpex_available ():
27682768 htcore .mark_step ()
2769- if int (block_name .split ("." )[- 1 ]) == 0 :
2770- self .target_loss_ratio = (mxfp4_loss / mxfp8_loss ) * (1 - mp_ratio )
2771- logger .warning_once (f"[Recipe Mode] Based on the mp_ratio, we set the target_loss_ratio: { self .target_loss_ratio } " )
2772- if mxfp4_loss / mxfp8_loss > self .target_loss_ratio :
2773- quantizable_num += 1
2774- logger .warning (f"[Recipe Mode] Due to [mxfp4_loss / mxfp8_loss]: { mxfp4_loss / mxfp8_loss } > { self .target_loss_ratio } " )
2769+ # aggressive mode
2770+ top_loss_ratio = 6
2771+ target_loss_ratio = 2
2772+ logger .warning_once (f"[Recipe Mode] Aggressive mode, we set the top_loss_ratio: { top_loss_ratio } , target_loss_ratio: { target_loss_ratio } " )
2773+ if mxfp4_loss / mxfp8_loss > top_loss_ratio :
2774+ logger .warning (f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]: { mxfp4_loss / mxfp8_loss } > { top_loss_ratio } " )
2775+ quantizable_num = len (quantizable_layers )
27752776 logger .warning (f"[Recipe Mode] Set { quantizable_num } layers using mixed precision for this block." )
2776- elif mxfp4_loss / mxfp8_loss < 1 : # special case for llama3.3 70B
2777- quantizable_num -= 1
2778- logger . warning ( f"[Recipe Mode] Due to [mxfp4_loss / mxfp8_loss]: { mxfp4_loss / mxfp8_loss } < 1" )
2777+ elif target_loss_ratio < mxfp4_loss / mxfp8_loss <= top_loss_ratio :
2778+ logger . warning ( f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]: { target_loss_ratio } < { mxfp4_loss / mxfp8_loss } <= { top_loss_ratio } " )
2779+ quantizable_num = quantizable_num + 1
27792780 logger .warning (f"[Recipe Mode] Set { quantizable_num } layers using mixed precision for this block." )
27802781 combination_list = []
27812782 avg_bits_list = []
@@ -2802,7 +2803,7 @@ def get_loss(q_block, q_input_ids):
28022803 best_avg_bits = avg_bits
28032804 best_combination = combination_list [i ]
28042805
2805- logger .info (f"[Recipe Mode] Recipe results of { block_name } :\n Mix precision layers: { best_combination } ;\n Average bits: { best_avg_bits } ." )
2806+ logger .info (f"[Recipe Mode] Recipe results of { block_name } :\n Mix precision layers: { best_combination } ;\n Average bits: { best_avg_bits } ; Loss ratio: { best_loss / mxfp8_loss } ." )
28062807 # generate output of quantized block of sample input_ids
28072808 block = create_mp_block (block , best_combination , self .recipe_mp_dtype )
28082809 q_output = get_output (block , q_input_ids )
0 commit comments