
Commit 819fa22

add target_bits but not good
Signed-off-by: xinhe3 <[email protected]>
1 parent f02331d commit 819fa22

File tree

2 files changed: +48 -19 lines changed


auto_round/utils.py

Lines changed: 46 additions & 18 deletions
@@ -2686,20 +2686,7 @@ def _generate_recipe(
     def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, input_others):
         from itertools import combinations
 
-        # fetch mix-precision recipe configuration
-        sample_num = self.recipe_mp_config.get("sample_num", 8)
         quantizable_layers = [n for n, m in block.named_modules() if isinstance(m, SUPPORTED_LAYER_TYPES)]
-        target_bits = self.recipe_mp_config.get("target_bits", None)
-        if target_bits is None:
-            mp_ratio = self.recipe_mp_config.get("mp_ratio", 1 / 3)
-
-        # calculate the number of layers to use mix-precision
-        mp_ratio_list = [f"{i}/{len(quantizable_layers)}" for i in range(1, len(quantizable_layers))]
-        quantizable_num = int(mp_ratio * len(quantizable_layers))  # It's ceiling
-        logger.warning_once(
-            f"[Recipe Mode] {len(quantizable_layers)} layers are detected, so the available mp_ratio values are {mp_ratio_list}"
-        )
-        logger.warning_once(f"[Recipe Mode] {quantizable_num} layers of each block use the mixed precision.")
         # fetch raw low-bits dtype of block for recovering mix-precision block
         layer = get_module(block, quantizable_layers[0])
         raw_dtype = {
@@ -2722,6 +2709,35 @@ def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, inpu
             }
         )
 
+        # fetch mix-precision recipe configuration
+        sample_num = self.recipe_mp_config.get("sample_num", 8)
+        target_bits = self.recipe_mp_config.get("target_bits", None)
+        mp_ratio = self.recipe_mp_config.get("mp_ratio", None)
+        assert target_bits or mp_ratio, "Either target_bits or mp_ratio should be set in recipe_mp_config."
+        if target_bits and mp_ratio:
+            logger.warning_once("Both target_bits and mp_ratio are set in recipe_mp_config. target_bits will be used.")
+        if target_bits:
+            # get the average bits of all combinations
+            bits_of_combination = {}
+            for quantizable_num in range(len(quantizable_layers)):
+                for mp_layers in combinations(quantizable_layers, quantizable_num):
+                    block = create_mp_block(block, mp_layers, self.recipe_mp_dtype)
+                    # get average bits
+                    avg_bits = get_avg_bits(block)
+                    bits_of_combination[mp_layers] = avg_bits
+                    block = recover_mp_block(block, mp_layers, raw_dtype)
+            acceptable_combination_set = {i for i in bits_of_combination if bits_of_combination[i] <= target_bits}
+        else:
+            mp_ratio = self.recipe_mp_config.get("mp_ratio", 1 / 3)
+            # calculate the number of layers to use mix-precision
+            mp_ratio_list = [f"{i}/{len(quantizable_layers)}" for i in range(1, len(quantizable_layers))]
+            quantizable_num = int(mp_ratio * len(quantizable_layers))  # It's ceiling
+            logger.warning_once(
+                f"[Recipe Mode] {len(quantizable_layers)} layers are detected, so the available mp_ratio values are {mp_ratio_list}"
+            )
+            logger.warning_once(f"[Recipe Mode] {quantizable_num} layers of each block use the mixed precision.")
+            acceptable_combination_set = combinations(quantizable_layers, quantizable_num)
+
         # generate reference output of sample input_ids
         def get_output(block, input_ids):
             output = self.get_block_outputs(
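The `target_bits` branch added above enumerates layer subsets (sizes 0 through n−1, so the full set is never tried), prices each one via `get_avg_bits`, and keeps only those within the bit budget. A minimal self-contained sketch of that filter, where `acceptable_combinations` and the toy cost model are hypothetical stand-ins for the commit's `create_mp_block`/`get_avg_bits` round-trip:

```python
from itertools import combinations

def acceptable_combinations(layers, bits_fn, target_bits):
    """Collect every subset of `layers` whose promotion to the
    mixed-precision dtype keeps the average bit-width within budget."""
    acceptable = []
    for k in range(len(layers)):  # sizes 0..n-1, as in the commit
        for subset in combinations(layers, k):
            if bits_fn(subset) <= target_bits:
                acceptable.append(subset)
    return acceptable

# Toy cost model (hypothetical numbers): 4-bit base dtype, and each
# promoted layer adds one bit to the block's average.
layers = ["q_proj", "k_proj", "v_proj", "o_proj"]
toy_bits = lambda subset: 4 + len(subset)

print(acceptable_combinations(layers, toy_bits, target_bits=5))
# -> the empty subset plus every single-layer subset fits a 5-bit budget
```

Since every subset of every size is priced, the search grows exponentially in the number of quantizable layers, and in the real code each candidate also pays a `create_mp_block`/`recover_mp_block` round-trip on the block.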
@@ -2756,19 +2772,31 @@ def get_loss(q_block, q_input_ids):
         combination_list = []
         avg_bits_list = []
         loss_list = []
-        for hp_layers in combinations(quantizable_layers, quantizable_num):
-            combination_list.append(hp_layers)
+        for mp_layers in acceptable_combination_set:
+            combination_list.append(mp_layers)
             # get loss
-            block = create_mp_block(block, hp_layers, self.recipe_mp_dtype)
+            block = create_mp_block(block, mp_layers, self.recipe_mp_dtype)
             # get average bits
             avg_bits = get_avg_bits(block)
             avg_bits_list.append(avg_bits)
             loss = get_loss(block, q_input_ids)
             loss_list.append(loss)
-            block = recover_mp_block(block, hp_layers, raw_dtype)
+            block = recover_mp_block(block, mp_layers, raw_dtype)
             if is_hpex_available():
                 htcore.mark_step()
-            logger.debug(f"{hp_layers}, {loss}, {avg_bits}")
+            logger.debug(f"{mp_layers}, {loss}, {avg_bits}")
+
+        # get the worst loss
+        block = create_mp_block(block, mp_layers, self.recipe_mp_dtype)
+        # get average bits
+        avg_bits = get_avg_bits(block)
+        avg_bits_list.append(avg_bits)
+        loss = get_loss(block, q_input_ids)
+        loss_list.append(loss)
+        block = recover_mp_block(block, mp_layers, raw_dtype)
+        if is_hpex_available():
+            htcore.mark_step()
+        logger.debug(f"{mp_layers}, {loss}, {avg_bits}")
 
         # get combination with lowest loss
         best_loss = float("inf")
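The scoring loop walks `acceptable_combination_set` and records per-candidate loss and average bits in index-aligned lists, feeding the lowest-loss scan that begins at `best_loss = float("inf")`. A small sketch of that final selection, assuming the three lists stay aligned as in the loop above (`pick_best` is a hypothetical helper, not the commit's code):

```python
def pick_best(combination_list, loss_list, avg_bits_list):
    """Return the combination with the lowest measured loss, plus its
    loss and average bit-width; the three lists are index-aligned."""
    best_idx, best_loss = 0, float("inf")
    for i, loss in enumerate(loss_list):
        if loss < best_loss:
            best_idx, best_loss = i, loss
    return combination_list[best_idx], best_loss, avg_bits_list[best_idx]

combos = [("q_proj",), ("k_proj",), ("q_proj", "k_proj")]
losses = [0.12, 0.09, 0.15]
bits   = [4.25, 4.25, 4.50]
print(pick_best(combos, losses, bits))  # (('k_proj',), 0.09, 4.25)
```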

workspace/quantize.py

Lines changed: 2 additions & 1 deletion
@@ -72,7 +72,6 @@ def initialize_model_and_tokenizer(model_name_or_path):
 parser.add_argument("--iters", default=None, type=int, help="iters for autoround.")
 parser.add_argument("--seqlen", default=None, type=int, help="sequence length for autoround.")
 parser.add_argument("--nsamples", default=None, type=int, help="number of samples for autoround.")
-parser.add_argument("--target_bits", default=5, type=float, help="number of samples for autoround.")
 parser.add_argument("--target_loss_ratio", default=1.2, type=float, help="number of samples for autoround.")
 parser.add_argument(
     "--use_hpu_graph", action="store_true", help="whether to use hpu graph mode to accelerate performance"
@@ -83,6 +82,7 @@ def initialize_model_and_tokenizer(model_name_or_path):
 parser.add_argument(
     "--disable_optimum_habana", action="store_true", help="whether to use adapt_transformers_to_gaudi"
 )
+parser.add_argument("--target_bits", default=5, type=float, help="number of samples for autoround.")
 parser.add_argument("--mp_ratio", default="1/3", type=str, help="number of samples for autoround.")
 parser.add_argument("--save", action="store_true", help="whether to save the quantized model")
 parser.add_argument("--load", action="store_true", help="whether to load the quantized model")
@@ -226,6 +226,7 @@ def match_pattern(name, pattern):
 
 recipe_results = autoround._generate_recipe(
     mp_config={
+        # "target_bits": float(args.target_bits),
         "mp_ratio": float(eval(args.mp_ratio)),
     },
 )
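The driver passes `--mp_ratio` as a string such as `1/3` and converts it with `float(eval(args.mp_ratio))`. A sketch of a safer equivalent (a suggestion, not what the commit uses) based on `fractions.Fraction`, which parses the same ratio strings without executing arbitrary code:

```python
from fractions import Fraction

def parse_mp_ratio(text: str) -> float:
    """Parse ratio strings such as "1/3" or "0.25" without eval()."""
    return float(Fraction(text))

assert abs(parse_mp_ratio("1/3") - 1 / 3) < 1e-12
assert parse_mp_ratio("0.25") == 0.25
```

Note that `--target_bits` is parsed by argparse but still commented out of the `mp_config` passed to `_generate_recipe`, consistent with the commit message "add target_bits but not good".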
