Skip to content

Commit ade7736

Browse files
committed
add aggressive mode
Signed-off-by: xinhe3 <[email protected]>
1 parent 92025ad commit ade7736

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

auto_round/utils.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2766,16 +2766,17 @@ def get_loss(q_block, q_input_ids):
27662766
logger.info(f"loss_ratio [mxfp4_loss / mxfp8_loss]: {mxfp4_loss/mxfp8_loss}")
27672767
if is_hpex_available():
27682768
htcore.mark_step()
2769-
if int(block_name.split(".")[-1]) == 0:
2770-
self.target_loss_ratio = (mxfp4_loss / mxfp8_loss) * (1 - mp_ratio)
2771-
logger.warning_once(f"[Recipe Mode] Based on the mp_ratio, we set the target_loss_ratio: {self.target_loss_ratio}")
2772-
if mxfp4_loss / mxfp8_loss > self.target_loss_ratio:
2773-
quantizable_num += 1
2774-
logger.warning(f"[Recipe Mode] Due to [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} > {self.target_loss_ratio}")
2769+
# aggressive mode
2770+
top_loss_ratio = 6
2771+
target_loss_ratio = 2
2772+
logger.warning_once(f"[Recipe Mode] Aggressive mode, we set the top_loss_ratio: {top_loss_ratio}, target_loss_ratio: {target_loss_ratio}")
2773+
if mxfp4_loss / mxfp8_loss > top_loss_ratio:
2774+
logger.warning(f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} > {top_loss_ratio}")
2775+
quantizable_num = len(quantizable_layers)
27752776
logger.warning(f"[Recipe Mode] Set {quantizable_num} layers using mixed precision for this block.")
2776-
elif mxfp4_loss / mxfp8_loss < 1: # special case for llama3.3 70B
2777-
quantizable_num -= 1
2778-
logger.warning(f"[Recipe Mode] Due to [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} < 1")
2777+
elif target_loss_ratio < mxfp4_loss / mxfp8_loss <= top_loss_ratio:
2778+
logger.warning(f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]:{target_loss_ratio} < {mxfp4_loss / mxfp8_loss} <= {top_loss_ratio}")
2779+
quantizable_num = quantizable_num + 1
27792780
logger.warning(f"[Recipe Mode] Set {quantizable_num} layers using mixed precision for this block.")
27802781
combination_list = []
27812782
avg_bits_list = []
@@ -2802,7 +2803,7 @@ def get_loss(q_block, q_input_ids):
28022803
best_avg_bits = avg_bits
28032804
best_combination = combination_list[i]
28042805

2805-
logger.info(f"[Recipe Mode] Recipe results of {block_name}:\nMix precision layers: {best_combination};\nAverage bits: {best_avg_bits}.")
2806+
logger.info(f"[Recipe Mode] Recipe results of {block_name}:\nMix precision layers: {best_combination};\nAverage bits: {best_avg_bits}; Loss ratio: {best_loss/mxfp8_loss}.")
28062807
# generate output of quantized block of sample input_ids
28072808
block = create_mp_block(block, best_combination, self.recipe_mp_dtype)
28082809
q_output = get_output(block, q_input_ids)

0 commit comments

Comments (0)