Skip to content

Commit 580ec8b

Browse files
committed
add ref_q and tight threshold
Signed-off-by: xinhe3 <[email protected]>
1 parent ade7736 commit 580ec8b

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

auto_round/autoround.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2642,13 +2642,13 @@ def quantize_blocks(
26422642
input_others = to_device(input_others, self.cache_device)
26432643
if self.recipe_mode:
26442644
logger.info("[Recipe Mode] starts")
2645-
q_input_ids = None # init value
2645+
q_input_ids, ref_q_input_ids = None, None # init value
26462646
for block_name in tqdm(block_names):
26472647
block = get_module(model, block_name)
26482648
if not self.model.device.type == "meta" or self.low_cpu_mem_usage:
26492649
block = block.to(device)
2650-
input_ids, q_input_ids = self._generate_block_recipe(
2651-
block, block_name, input_ids, q_input_ids, input_others
2650+
input_ids, q_input_ids, ref_q_input_ids = self._generate_block_recipe(
2651+
block, block_name, input_ids, q_input_ids, ref_q_input_ids, input_others
26522652
)
26532653
if is_hpex_available():
26542654
htcore.mark_step()

auto_round/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2683,7 +2683,7 @@ def _generate_recipe(
26832683
return self.recipe_results
26842684

26852685

2686-
def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, input_others):
2686+
def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, ref_q_input_ids, input_others):
26872687
from itertools import combinations
26882688

26892689
# fetch mix-precision recipe configuration
@@ -2739,6 +2739,7 @@ def get_output(block, input_ids):
27392739

27402740
reference_output = get_output(block, input_ids)
27412741
q_input_ids = input_ids if q_input_ids is None else q_input_ids
2742+
ref_q_input_ids = input_ids if ref_q_input_ids is None else ref_q_input_ids
27422743
# generate q_output of sample input_ids and get loss
27432744
@torch.no_grad()
27442745
def get_loss(q_block, q_input_ids):
@@ -2757,6 +2758,7 @@ def get_loss(q_block, q_input_ids):
27572758
# get mxfp8 loss
27582759
hp_layers = quantizable_layers
27592760
block = create_mp_block(block, hp_layers, self.recipe_mp_dtype)
2761+
ref_q_output = get_output(block, ref_q_input_ids)
27602762
mxfp8_loss = get_loss(block, q_input_ids)
27612763
block = recover_mp_block(block, hp_layers, raw_dtype)
27622764
hp_layers = []
@@ -2767,8 +2769,8 @@ def get_loss(q_block, q_input_ids):
27672769
if is_hpex_available():
27682770
htcore.mark_step()
27692771
# aggressive mode
2770-
top_loss_ratio = 6
2771-
target_loss_ratio = 2
2772+
top_loss_ratio = 4
2773+
target_loss_ratio = 1.5
27722774
logger.warning_once(f"[Recipe Mode] Aggressive mode, we set the top_loss_ratio: {top_loss_ratio}, target_loss_ratio: {target_loss_ratio}")
27732775
if mxfp4_loss / mxfp8_loss > top_loss_ratio:
27742776
logger.warning(f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} > {top_loss_ratio}")
@@ -2778,6 +2780,10 @@ def get_loss(q_block, q_input_ids):
27782780
logger.warning(f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]:{target_loss_ratio} < {mxfp4_loss / mxfp8_loss} <= {top_loss_ratio}")
27792781
quantizable_num = quantizable_num + 1
27802782
logger.warning(f"[Recipe Mode] Set {quantizable_num} layers using mixed precision for this block.")
2783+
elif mxfp4_loss / mxfp8_loss <= 1.0:
2784+
logger.warning(f"[Recipe Mode] Due to loss_ratio [mxfp4_loss / mxfp8_loss]: {mxfp4_loss / mxfp8_loss} <= 1.0")
2785+
quantizable_num = 0
2786+
logger.warning(f"[Recipe Mode] Set {quantizable_num} layers using mixed precision for this block.")
27812787
combination_list = []
27822788
avg_bits_list = []
27832789
loss_list = []
@@ -2822,7 +2828,7 @@ def get_loss(q_block, q_input_ids):
28222828
if is_hpex_available():
28232829
htcore.mark_step()
28242830

2825-
return reference_output, q_output
2831+
return reference_output, q_output, ref_q_output
28262832

28272833

28282834
###############################################################################################

0 commit comments

Comments
 (0)