@@ -2686,20 +2686,7 @@ def _generate_recipe(
2686
2686
def _generate_block_recipe (self , block , block_name , input_ids , q_input_ids , input_others ):
2687
2687
from itertools import combinations
2688
2688
2689
- # fetch mix-precision recipe configuration
2690
- sample_num = self .recipe_mp_config .get ("sample_num" , 8 )
2691
2689
quantizable_layers = [n for n , m in block .named_modules () if isinstance (m , SUPPORTED_LAYER_TYPES )]
2692
- target_bits = self .recipe_mp_config .get ("target_bits" , None )
2693
- if target_bits is None :
2694
- mp_ratio = self .recipe_mp_config .get ("mp_ratio" , 1 / 3 )
2695
-
2696
- # calculate the number of layers to use mix-precision
2697
- mp_ratio_list = [f"{ i } /{ len (quantizable_layers )} " for i in range (1 , len (quantizable_layers ))]
2698
- quantizable_num = int (mp_ratio * len (quantizable_layers )) # It's ceiling
2699
- logger .warning_once (
2700
- f"[Recipe Mode] { len (quantizable_layers )} layers are detected, so the available mp_ratio values are { mp_ratio_list } "
2701
- )
2702
- logger .warning_once (f"[Recipe Mode] { quantizable_num } layers of each block use the mixed precision." )
2703
2690
# fetch raw low-bits dtype of block for recovering mix-precision block
2704
2691
layer = get_module (block , quantizable_layers [0 ])
2705
2692
raw_dtype = {
@@ -2722,6 +2709,35 @@ def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, inpu
2722
2709
}
2723
2710
)
2724
2711
2712
+ # fetch mix-precision recipe configuration
2713
+ sample_num = self .recipe_mp_config .get ("sample_num" , 8 )
2714
+ target_bits = self .recipe_mp_config .get ("target_bits" , None )
2715
+ mp_ratio = self .recipe_mp_config .get ("mp_ratio" , None )
2716
+ assert target_bits or mp_ratio , "Either target_bits or mp_ratio should be set in recipe_mp_config."
2717
+ if target_bits and mp_ratio :
2718
+ logger .warning_once ("Both target_bits and mp_ratio are set in recipe_mp_config. target_bits will be used." )
2719
+ if target_bits :
2720
+ # get the average bits of all combinations
2721
+ bits_of_combination = {}
2722
+ for quantizable_num in range (len (quantizable_layers )):
2723
+ for mp_layers in combinations (quantizable_layers , quantizable_num ):
2724
+ block = create_mp_block (block , mp_layers , self .recipe_mp_dtype )
2725
+ # get average bits
2726
+ avg_bits = get_avg_bits (block )
2727
+ bits_of_combination [mp_layers ] = avg_bits
2728
+ block = recover_mp_block (block , mp_layers , raw_dtype )
2729
+ acceptable_combination_set = {i for i in bits_of_combination if bits_of_combination [i ] <= target_bits }
2730
+ else :
2731
+ mp_ratio = self .recipe_mp_config .get ("mp_ratio" , 1 / 3 )
2732
+ # calculate the number of layers to use mix-precision
2733
+ mp_ratio_list = [f"{ i } /{ len (quantizable_layers )} " for i in range (1 , len (quantizable_layers ))]
2734
+ quantizable_num = int (mp_ratio * len (quantizable_layers )) # It's ceiling
2735
+ logger .warning_once (
2736
+ f"[Recipe Mode] { len (quantizable_layers )} layers are detected, so the available mp_ratio values are { mp_ratio_list } "
2737
+ )
2738
+ logger .warning_once (f"[Recipe Mode] { quantizable_num } layers of each block use the mixed precision." )
2739
+ acceptable_combination_set = combinations (quantizable_layers , quantizable_num )
2740
+
2725
2741
# generate reference output of sample input_ids
2726
2742
def get_output (block , input_ids ):
2727
2743
output = self .get_block_outputs (
@@ -2756,19 +2772,31 @@ def get_loss(q_block, q_input_ids):
2756
2772
combination_list = []
2757
2773
avg_bits_list = []
2758
2774
loss_list = []
2759
- for hp_layers in combinations ( quantizable_layers , quantizable_num ) :
2760
- combination_list .append (hp_layers )
2775
+ for mp_layers in acceptable_combination_set :
2776
+ combination_list .append (mp_layers )
2761
2777
# get loss
2762
- block = create_mp_block (block , hp_layers , self .recipe_mp_dtype )
2778
+ block = create_mp_block (block , mp_layers , self .recipe_mp_dtype )
2763
2779
# get average bits
2764
2780
avg_bits = get_avg_bits (block )
2765
2781
avg_bits_list .append (avg_bits )
2766
2782
loss = get_loss (block , q_input_ids )
2767
2783
loss_list .append (loss )
2768
- block = recover_mp_block (block , hp_layers , raw_dtype )
2784
+ block = recover_mp_block (block , mp_layers , raw_dtype )
2769
2785
if is_hpex_available ():
2770
2786
htcore .mark_step ()
2771
- logger .debug (f"{ hp_layers } , { loss } , { avg_bits } " )
2787
+ logger .debug (f"{ mp_layers } , { loss } , { avg_bits } " )
2788
+
2789
+ # get the worst loss
2790
+ block = create_mp_block (block , mp_layers , self .recipe_mp_dtype )
2791
+ # get average bits
2792
+ avg_bits = get_avg_bits (block )
2793
+ avg_bits_list .append (avg_bits )
2794
+ loss = get_loss (block , q_input_ids )
2795
+ loss_list .append (loss )
2796
+ block = recover_mp_block (block , mp_layers , raw_dtype )
2797
+ if is_hpex_available ():
2798
+ htcore .mark_step ()
2799
+ logger .debug (f"{ mp_layers } , { loss } , { avg_bits } " )
2772
2800
2773
2801
# get combination with lowest loss
2774
2802
best_loss = float ("inf" )
0 commit comments