@@ -2600,56 +2600,6 @@ def recover_mp_block(block, mp_layers, raw_dtype):
2600
2600
return block
2601
2601
2602
2602
2603
- def get_numel (block , hp_layers ):
2604
- """Get the total number of elements in the specified layers of a block.
2605
-
2606
- Args:
2607
- block (torch.nn.Module): The model block.
2608
- hp_layers (list): List of layer names to include.
2609
-
2610
- Returns:
2611
- int: Total number of elements in the specified layers.
2612
- """
2613
- numel = 0
2614
- for layer_name in hp_layers :
2615
- layer = get_module (block , layer_name )
2616
- numel += layer .weight .numel ()
2617
- return numel
2618
-
2619
-
2620
- def get_best_combination (combination_list , numel_list , loss_list , loss_numel_ratio = 2.0 ):
2621
- """Selects the best combination from a list based on the ranks of two criteria: numel_list and loss_list.
2622
-
2623
- Each combination is ranked by its position in the sorted numel_list and loss_list. The loss rank is scaled
2624
- by `loss_numel_ratio` to adjust its importance relative to the numel rank. The combination with the lowest
2625
- sum of ranks is selected as the best.
2626
-
2627
- Args:
2628
- combination_list (list): List of candidate combinations.
2629
- numel_list (list): List of numerical values representing the size or complexity of each combination.
2630
- loss_list (list): List of loss values associated with each combination.
2631
- loss_numel_ratio (float, optional): Scaling factor for the loss rank relative to the numel rank. Default is 2.0.
2632
-
2633
- Returns:
2634
- The combination from `combination_list` with the lowest combined rank based on numel and loss.
2635
- """
2636
- # Get ranks for numel_list and
2637
- numel_ranks = [sorted (numel_list ).index (x ) for x in numel_list ]
2638
- loss_ranks = [(sorted (loss_list ).index (x )) * loss_numel_ratio for x in loss_list ]
2639
-
2640
- # Calculate rank sums
2641
- rank_sums = [x + y for x , y in zip (numel_ranks , loss_ranks )]
2642
- logger .debug (f"numel_ranks: { numel_ranks } " )
2643
- logger .debug (f"loss_ranks: { loss_ranks } " )
2644
- logger .debug (f"rank sum: { rank_sums } " )
2645
-
2646
- # Find the index of the smallest rank sum
2647
- best_index = rank_sums .index (min (rank_sums ))
2648
-
2649
- # Return the best index
2650
- return best_index
2651
-
2652
-
2653
2603
def get_avg_bits (module ):
2654
2604
"""
2655
2605
Calculates the average number of bits per weight element for supported layers in a given module.
@@ -2699,8 +2649,6 @@ def _generate_recipe(
2699
2649
# special mix-precision configuration
2700
2650
mp_config = {
2701
2651
"mp_ratio" : 1 / 3 ,
2702
- "loss_weight" : 2.0 ,
2703
- "numel_weight" : 1.0 ,
2704
2652
},
2705
2653
):
2706
2654
"""
@@ -2710,7 +2658,7 @@ def _generate_recipe(
2710
2658
mp_dtype (dict, optional): Dictionary specifying the mixed-precision data types for weights and activations.
2711
2659
Defaults to {"data_type": "mx_fp8", "act_data_type": "mx_fp8"}.
2712
2660
mp_config (dict, optional): Dictionary specifying the mixed-precision configuration parameters such as
2713
- ratio, loss weight, and numel weight. Defaults to {"mp_ratio": 1/3, "loss_weight": 2.0, "numel_weight": 1.0 }.
2661
+ ratio, loss weight, and numel weight. Defaults to {"mp_ratio": 1/3}.
2714
2662
2715
2663
Returns:
2716
2664
dict: A dictionary containing the quantization recipe for each layer, excluding the "lm_head" layer.
@@ -2738,29 +2686,20 @@ def _generate_recipe(
2738
2686
def _generate_block_recipe (self , block , block_name , input_ids , q_input_ids , input_others ):
2739
2687
from itertools import combinations
2740
2688
2741
- from auto_round .utils import (
2742
- DTYPE_INFO_MAPPING ,
2743
- create_mp_block ,
2744
- get_best_combination ,
2745
- get_numel ,
2746
- recover_mp_block ,
2747
- )
2748
-
2749
2689
# fetch mix-precision recipe configuration
2750
2690
sample_num = self .recipe_mp_config .get ("sample_num" , 8 )
2751
- mp_ratio = self .recipe_mp_config .get ("mp_ratio" , 1 / 3 )
2752
- loss_weight = float (self .recipe_mp_config .get ("loss_weight" , 2.0 ))
2753
- numel_weight = float (self .recipe_mp_config .get ("numel_weight" , 1.0 ))
2754
- loss_numel_ratio = loss_weight / numel_weight
2755
-
2756
- # calculate the number of layers to use mix-precision
2757
2691
quantizable_layers = [n for n , m in block .named_modules () if isinstance (m , SUPPORTED_LAYER_TYPES )]
2758
- mp_ratio_list = [f"{ i } /{ len (quantizable_layers )} " for i in range (1 , len (quantizable_layers ))]
2759
- quantizable_num = int (mp_ratio * len (quantizable_layers )) # It's ceiling
2760
- logger .warning_once (
2761
- f"[Recipe Mode] { len (quantizable_layers )} layers are detected, so the available mp_ratio values are { mp_ratio_list } "
2762
- )
2763
- logger .warning_once (f"[Recipe Mode] { quantizable_num } layers of each block use the mixed precision." )
2692
+ target_bits = self .recipe_mp_config .get ("target_bits" , None )
2693
+ if target_bits is None :
2694
+ mp_ratio = self .recipe_mp_config .get ("mp_ratio" , 1 / 3 )
2695
+
2696
+ # calculate the number of layers to use mix-precision
2697
+ mp_ratio_list = [f"{ i } /{ len (quantizable_layers )} " for i in range (1 , len (quantizable_layers ))]
2698
+ quantizable_num = int (mp_ratio * len (quantizable_layers )) # It's ceiling
2699
+ logger .warning_once (
2700
+ f"[Recipe Mode] { len (quantizable_layers )} layers are detected, so the available mp_ratio values are { mp_ratio_list } "
2701
+ )
2702
+ logger .warning_once (f"[Recipe Mode] { quantizable_num } layers of each block use the mixed precision." )
2764
2703
# fetch raw low-bits dtype of block for recovering mix-precision block
2765
2704
layer = get_module (block , quantizable_layers [0 ])
2766
2705
raw_dtype = {
@@ -2812,7 +2751,7 @@ def get_loss(q_block, q_input_ids):
2812
2751
total_loss += float (loss )
2813
2752
if is_hpex_available ():
2814
2753
htcore .mark_step ()
2815
- return total_loss
2754
+ return round ( total_loss , 6 )
2816
2755
2817
2756
combination_list = []
2818
2757
avg_bits_list = []
@@ -2831,22 +2770,27 @@ def get_loss(q_block, q_input_ids):
2831
2770
htcore .mark_step ()
2832
2771
logger .debug (f"{ hp_layers } , { loss } , { avg_bits } " )
2833
2772
2834
- # get target hp layers
2835
- best_index = get_best_combination (combination_list , avg_bits_list , loss_list , loss_numel_ratio )
2836
- target_hp_layers , target_avg_bits = combination_list [best_index ], avg_bits_list [best_index ]
2837
- logger .info (f"[Recipe Mode] Recipe results of { block_name } :\n Mix precision layers: { target_hp_layers } ; \n Average bits: { target_avg_bits } ." )
2773
+ # get combination with lowest loss
2774
+ best_loss = float ("inf" )
2775
+ for i , (loss , avg_bits ) in enumerate (zip (loss_list , avg_bits_list )):
2776
+ if best_loss > loss :
2777
+ best_loss = loss
2778
+ best_avg_bits = avg_bits
2779
+ best_combination = combination_list [i ]
2780
+
2781
+ logger .info (f"[Recipe Mode] Recipe results of { block_name } :\n Mix precision layers: { best_combination } ;\n Average bits: { best_avg_bits } ." )
2838
2782
# generate output of quantized block of sample input_ids
2839
- block = create_mp_block (block , target_hp_layers , self .recipe_mp_dtype )
2783
+ block = create_mp_block (block , best_combination , self .recipe_mp_dtype )
2840
2784
q_output = get_output (block , q_input_ids )
2841
- block = recover_mp_block (block , target_hp_layers , raw_dtype )
2785
+ block = recover_mp_block (block , best_combination , raw_dtype )
2842
2786
# update recipe and results
2843
- for layer_name in target_hp_layers :
2787
+ for layer_name in best_combination :
2844
2788
self .recipe_results ["recipe" ].update ({block_name + "." + layer_name : self .recipe_mp_dtype })
2845
2789
self .recipe_results ["results" ].update (
2846
2790
{
2847
2791
block_name : {
2848
- "mp_layers" : target_hp_layers ,
2849
- "bits" : target_avg_bits ,
2792
+ "mp_layers" : best_combination ,
2793
+ "bits" : best_avg_bits ,
2850
2794
}
2851
2795
}
2852
2796
)
0 commit comments