Commit f02331d

use loss only

Signed-off-by: xinhe3 <[email protected]>
1 parent 2a0e0b8

2 files changed, +38 -91 lines

auto_round/utils.py

Lines changed: 27 additions & 83 deletions
@@ -2600,56 +2600,6 @@ def recover_mp_block(block, mp_layers, raw_dtype):
     return block
 
 
-def get_numel(block, hp_layers):
-    """Get the total number of elements in the specified layers of a block.
-
-    Args:
-        block (torch.nn.Module): The model block.
-        hp_layers (list): List of layer names to include.
-
-    Returns:
-        int: Total number of elements in the specified layers.
-    """
-    numel = 0
-    for layer_name in hp_layers:
-        layer = get_module(block, layer_name)
-        numel += layer.weight.numel()
-    return numel
-
-
-def get_best_combination(combination_list, numel_list, loss_list, loss_numel_ratio=2.0):
-    """Selects the best combination from a list based on the ranks of two criteria: numel_list and loss_list.
-
-    Each combination is ranked by its position in the sorted numel_list and loss_list. The loss rank is scaled
-    by `loss_numel_ratio` to adjust its importance relative to the numel rank. The combination with the lowest
-    sum of ranks is selected as the best.
-
-    Args:
-        combination_list (list): List of candidate combinations.
-        numel_list (list): List of numerical values representing the size or complexity of each combination.
-        loss_list (list): List of loss values associated with each combination.
-        loss_numel_ratio (float, optional): Scaling factor for the loss rank relative to the numel rank. Default is 2.0.
-
-    Returns:
-        The combination from `combination_list` with the lowest combined rank based on numel and loss.
-    """
-    # Get ranks for numel_list and loss_list
-    numel_ranks = [sorted(numel_list).index(x) for x in numel_list]
-    loss_ranks = [(sorted(loss_list).index(x)) * loss_numel_ratio for x in loss_list]
-
-    # Calculate rank sums
-    rank_sums = [x + y for x, y in zip(numel_ranks, loss_ranks)]
-    logger.debug(f"numel_ranks: {numel_ranks}")
-    logger.debug(f"loss_ranks: {loss_ranks}")
-    logger.debug(f"rank sum: {rank_sums}")
-
-    # Find the index of the smallest rank sum
-    best_index = rank_sums.index(min(rank_sums))
-
-    # Return the best index
-    return best_index
-
-
 def get_avg_bits(module):
     """
     Calculates the average number of bits per weight element for supported layers in a given module.
@@ -2699,8 +2649,6 @@ def _generate_recipe(
         # special mix-precision configuration
         mp_config={
             "mp_ratio": 1 / 3,
-            "loss_weight": 2.0,
-            "numel_weight": 1.0,
         },
     ):
         """
@@ -2710,7 +2658,7 @@ def _generate_recipe(
             mp_dtype (dict, optional): Dictionary specifying the mixed-precision data types for weights and activations.
                 Defaults to {"data_type": "mx_fp8", "act_data_type": "mx_fp8"}.
             mp_config (dict, optional): Dictionary specifying the mixed-precision configuration parameters such as
-                ratio, loss weight, and numel weight. Defaults to {"mp_ratio": 1/3, "loss_weight": 2.0, "numel_weight": 1.0}.
+                ratio, loss weight, and numel weight. Defaults to {"mp_ratio": 1/3}.
 
         Returns:
             dict: A dictionary containing the quantization recipe for each layer, excluding the "lm_head" layer.
@@ -2738,29 +2686,20 @@ def _generate_recipe(
     def _generate_block_recipe(self, block, block_name, input_ids, q_input_ids, input_others):
         from itertools import combinations
 
-        from auto_round.utils import (
-            DTYPE_INFO_MAPPING,
-            create_mp_block,
-            get_best_combination,
-            get_numel,
-            recover_mp_block,
-        )
-
         # fetch mix-precision recipe configuration
         sample_num = self.recipe_mp_config.get("sample_num", 8)
-        mp_ratio = self.recipe_mp_config.get("mp_ratio", 1 / 3)
-        loss_weight = float(self.recipe_mp_config.get("loss_weight", 2.0))
-        numel_weight = float(self.recipe_mp_config.get("numel_weight", 1.0))
-        loss_numel_ratio = loss_weight / numel_weight
-
-        # calculate the number of layers to use mix-precision
         quantizable_layers = [n for n, m in block.named_modules() if isinstance(m, SUPPORTED_LAYER_TYPES)]
-        mp_ratio_list = [f"{i}/{len(quantizable_layers)}" for i in range(1, len(quantizable_layers))]
-        quantizable_num = int(mp_ratio * len(quantizable_layers))  # It's ceiling
-        logger.warning_once(
-            f"[Recipe Mode] {len(quantizable_layers)} layers are detected, so the available mp_ratio values are {mp_ratio_list}"
-        )
-        logger.warning_once(f"[Recipe Mode] {quantizable_num} layers of each block use the mixed precision.")
+        target_bits = self.recipe_mp_config.get("target_bits", None)
+        if target_bits is None:
+            mp_ratio = self.recipe_mp_config.get("mp_ratio", 1 / 3)
+
+            # calculate the number of layers to use mix-precision
+            mp_ratio_list = [f"{i}/{len(quantizable_layers)}" for i in range(1, len(quantizable_layers))]
+            quantizable_num = int(mp_ratio * len(quantizable_layers))  # It's ceiling
+            logger.warning_once(
+                f"[Recipe Mode] {len(quantizable_layers)} layers are detected, so the available mp_ratio values are {mp_ratio_list}"
+            )
+            logger.warning_once(f"[Recipe Mode] {quantizable_num} layers of each block use the mixed precision.")
         # fetch raw low-bits dtype of block for recovering mix-precision block
         layer = get_module(block, quantizable_layers[0])
         raw_dtype = {
@@ -2812,7 +2751,7 @@ def get_loss(q_block, q_input_ids):
                 total_loss += float(loss)
                 if is_hpex_available():
                     htcore.mark_step()
-            return total_loss
+            return round(total_loss, 6)
 
         combination_list = []
         avg_bits_list = []
@@ -2831,22 +2770,27 @@ def get_loss(q_block, q_input_ids):
                 htcore.mark_step()
             logger.debug(f"{hp_layers}, {loss}, {avg_bits}")
 
-        # get target hp layers
-        best_index = get_best_combination(combination_list, avg_bits_list, loss_list, loss_numel_ratio)
-        target_hp_layers, target_avg_bits = combination_list[best_index], avg_bits_list[best_index]
-        logger.info(f"[Recipe Mode] Recipe results of {block_name}:\n Mix precision layers: {target_hp_layers}; \nAverage bits: {target_avg_bits}.")
+        # get combination with lowest loss
+        best_loss = float("inf")
+        for i, (loss, avg_bits) in enumerate(zip(loss_list, avg_bits_list)):
+            if best_loss > loss:
+                best_loss = loss
+                best_avg_bits = avg_bits
+                best_combination = combination_list[i]
+
+        logger.info(f"[Recipe Mode] Recipe results of {block_name}:\nMix precision layers: {best_combination};\nAverage bits: {best_avg_bits}.")
         # generate output of quantized block of sample input_ids
-        block = create_mp_block(block, target_hp_layers, self.recipe_mp_dtype)
+        block = create_mp_block(block, best_combination, self.recipe_mp_dtype)
         q_output = get_output(block, q_input_ids)
-        block = recover_mp_block(block, target_hp_layers, raw_dtype)
+        block = recover_mp_block(block, best_combination, raw_dtype)
         # update recipe and results
-        for layer_name in target_hp_layers:
+        for layer_name in best_combination:
             self.recipe_results["recipe"].update({block_name + "." + layer_name: self.recipe_mp_dtype})
             self.recipe_results["results"].update(
                 {
                     block_name: {
-                        "mp_layers": target_hp_layers,
-                        "bits": target_avg_bits,
+                        "mp_layers": best_combination,
+                        "bits": best_avg_bits,
                     }
                 }
             )
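
The core of this commit is visible in the last two hunks: the deleted `get_numel`/`get_best_combination` helpers blended a size rank and a loss rank, while the replacement keeps only the loss, as the commit title says. Below is a minimal standalone sketch of both selection rules side by side; the candidate layer names and all numbers are made-up illustrations, not values from the repository:

```python
# Sketch of the selection change in this commit. combination_list, avg_bits_list
# and loss_list are hypothetical; in _generate_block_recipe they come from
# evaluating each candidate set of high-precision layers in a block.
combination_list = [["q_proj"], ["k_proj"], ["v_proj"], ["o_proj"]]
avg_bits_list = [4.0, 4.5, 5.0, 5.5]      # average weight bits per candidate
loss_list = [0.010, 0.011, 0.012, 0.008]  # block reconstruction loss per candidate

# Old rule (removed): rank candidates by size and by loss, scale the loss rank
# by loss_weight / numel_weight (2.0 / 1.0), and take the lowest rank sum.
loss_numel_ratio = 2.0
numel_ranks = [sorted(avg_bits_list).index(x) for x in avg_bits_list]
loss_ranks = [sorted(loss_list).index(x) * loss_numel_ratio for x in loss_list]
rank_sums = [x + y for x, y in zip(numel_ranks, loss_ranks)]
old_best = combination_list[rank_sums.index(min(rank_sums))]

# New rule ("use loss only"): ignore average bits and take the lowest loss.
best_loss = float("inf")
for i, loss in enumerate(loss_list):
    if best_loss > loss:
        best_loss = loss
        new_best = combination_list[i]

print(old_best)  # ['q_proj']: smallest candidate wins despite a higher loss
print(new_best)  # ['o_proj']: lowest loss wins regardless of size
```

The worked numbers show why the two rules can disagree: the rank-sum rule could prefer a smaller mixed-precision set with a slightly worse loss, a trade-off this commit drops in favor of loss alone.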

workspace/README.md

Lines changed: 11 additions & 8 deletions
@@ -1,25 +1,28 @@
 
 ```bash
 ############################### Gaudi model path #############################################
-deepspeed --include="localhost:0,1,2,3" --master_port=29500 quantize.py --autoround --batch_size 16 --accuracy --dtype mx_fp4 --mp_ratio 4/7 2>&1 |tee mxfp4_ratio_2_8b.log
+deepspeed --include="localhost:2,3" --master_port=29520 quantize.py --autoround --batch_size 16 --accuracy --dtype mx_fp4 --mp_ratio 5/7 2>&1 |tee mxfp4_op_5_8b.log
 
-deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --autoround --batch_size 16 --accuracy --dtype mx_fp4 --mp_ratio 3/7 2>&1 |tee mxfp4_ratio_3_8b.log
+deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --autoround --batch_size 16 --accuracy --dtype mx_fp4 --mp_ratio 3/7 2>&1 |tee mxfp4_op_3_8b.log
 
-deepspeed --include="localhost:0,1,2,3,4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /git_lfs/data/pytorch/llama3.3/Meta-Llama-3.3-70B-Instruct/ --batch_size 8 --accuracy --dtype mx_fp4 --mp_ratio 4/7 --autoround 2>&1 |tee mxfp4_op_4_3.3_70b.log
+deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /git_lfs/data/pytorch/llama3.3/Meta-Llama-3.3-70B-Instruct/ --batch_size 8 --accuracy --dtype mx_fp4 --mp_ratio 4/7 --autoround 2>&1 |tee mxfp4_op_4_3.3_70b.log
 
-deepspeed --include="localhost:0,1,2,3,4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /git_lfs/data/pytorch/llama3.3/Meta-Llama-3.3-70B-Instruct/ --batch_size 8 --accuracy --dtype mx_fp4 --mp_ratio 5/7 --autoround 2>&1 |tee mxfp4_op_5_3.3_70b.log
+deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /git_lfs/data/pytorch/llama3.3/Meta-Llama-3.3-70B-Instruct/ --batch_size 8 --accuracy --dtype mx_fp4 --mp_ratio 5/7 --autoround 2>&1 |tee mxfp4_op_5_3.3_70b.log
 
 
 ############################### H20 model path #############################################
-deepspeed --include="localhost:4,5,6,7" --master_port=29500 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.1-8B-Instruct/ --autoround --batch_size 64 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 3/7 2>&1 |tee mxfp4_ratio_3_8b.log
+deepspeed --include="localhost:4,5,6,7" --master_port=29500 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.1-8B-Instruct/ --autoround --batch_size 64 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 3/7 2>&1 |tee mxfp4_op_3_8b.log
 
-deepspeed --include="localhost:2,3" --master_port=29510 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.1-8B-Instruct/ --autoround --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 3/7 2>&1 |tee mxfp4_ratio_3_8b.log
+deepspeed --include="localhost:2,3" --master_port=29520 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.1-8B-Instruct/ --autoround --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 5/7 2>&1 |tee mxfp4_op_5_8b.log
 
+deepspeed --include="localhost:2,3" --master_port=29520 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.1-8B-Instruct/ --autoround --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 6/7 2>&1 |tee mxfp4_op_6_8b.log
 
-deepspeed --include="localhost:0,1,2,3" --master_port=29500 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.3-70B-Instruct/ --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 4/7 --autoround 2>&1 |tee mxfp4_ratio_2_3.3_70b.log
+deepspeed --include="localhost:0,1,2,3" --master_port=29500 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.3-70B-Instruct/ --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 5/7 --autoround 2>&1 |tee mxfp4_op_5_3.3_70b.log
 
+deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.3-70B-Instruct/ --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 3/7 --autoround 2>&1 |tee mxfp4_op_3_3.3_70b.log
 
-deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.3-70B-Instruct/ --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 3/7 --autoround 2>&1 |tee mxfp4_ratio_3_3.3_70b.log
+H20-2-1
+deepspeed --include="localhost:4,5,6,7" --master_port=29510 quantize.py --model_name_or_path /ssd/xinhe/Llama-3.3-70B-Instruct/ --batch_size 32 --accuracy --dtype mx_fp4 --device cuda --mp_ratio 4/7 --autoround 2>&1 |tee mxfp4_op_4_3.3_70b.log
 
 
 ```
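
The commands in this README all follow one pattern and differ only in the device list, master port, model path, batch size, mp_ratio, and log name. A hedged template distilled from them (angle-bracket values are placeholders; the flags are exactly those used above, and the /7 denominator presumably reflects the seven quantizable linear layers per block that recipe mode reports):

```bash
# Template for the runs above; replace <...> placeholders per experiment.
# --device cuda appears in the H20 (NVIDIA) runs and is omitted on Gaudi.
deepspeed --include="localhost:<gpu_ids>" --master_port=<port> quantize.py \
    --model_name_or_path <model_dir> \
    --autoround --batch_size <n> --accuracy \
    --dtype mx_fp4 --device cuda --mp_ratio <k>/7 \
    2>&1 | tee <run_name>.log
```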
