Commit 7a36ccc

Added support to export for BF16 weight and amax for vLLM fakequant QAT (#579)
## What does this PR do?

**Type of change:** New Feature

**Overview:** Support for vLLM fakequantize QAT/QAD checkpoint evaluation. This MR adds the ability to export a checkpoint as BF16 weights plus amax values, through `export_hf_checkpoint` for HF and `export_mcore_gpt_to_hf` for MCore, using the `export_bf16_weights_amax` option. The exported weights and amax values can be used with the [vllm_serve_fakequant.py](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/096ee13ea62bbb0ce0a4e4128c439651374d6235/examples/vllm_serve/vllm_serve_fakequant.py) script to run the saved checkpoint.

## Usage

Refer to [README.md](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/096ee13ea62bbb0ce0a4e4128c439651374d6235/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip)

## Testing

- Tested the HF approach by exporting a bf16 model with the QAT script and running the vLLM server; verified that the amax values match.
- Tested the MCore approach by quantizing and exporting a bf16 model with the quantize.sh and export.sh scripts and running the vLLM server; verified that the amax values match.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

The MCore export script does not currently expose an option to enable this export.

---------

Signed-off-by: Kinjal Patel <[email protected]>
1 parent 5842d73 commit 7a36ccc

9 files changed: +547 -240 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Model Optimizer Changelog (Linux)
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 - Add support for PyTorch Geometric quantization.
 - Add per tensor and per channel MSE calibrator support.
+- Added support for PTQ/QAT checkpoint export and loading for running fakequant evaluation in vLLM. See ``examples/vllm_serve/README.md#load-qatptq-model-and-serve-in-vllm-wip`` for more details.
 
 **Documentation**

examples/vllm_serve/README.md

Lines changed: 10 additions & 5 deletions
@@ -55,16 +55,19 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=<model_name>,
 
 ## Load QAT/PTQ model and serve in vLLM (WIP)
 
-Overwrite the calibrated amax value with prepared values from either PTQ/QAT. This is only tested for Llama3.1
+Overwrite the calibrated amax value with prepared values from either QAT/PTQ.
 
-Step 1: convert amax to merged amax, using llama3.1 as an example:
+Step 1: export the model with bf16 weights and amax values.
+
+- For HF model set `export_bf16_weights_amax` to export the model with function `modelopt.torch.export.unified_export_hf.export_hf_checkpoint`.
+- For MCore model use `export_bf16_weights_amax` to export the model with function `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`.
+
+Step 2: configure <quant_amax.pth> from exported model using AMAX_FILE_PATH environment variable in step 1. For example:
 
 ```bash
-python convert_amax_hf2vllm.py -i <amax.pth> -o <vllm_amax.pth>
+AMAX_FILE_PATH=<vllm_amax.pth> QUANT_CFG=<quant_config> python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000
 ```
 
-Step 2: add `<vllm_amax.pth>` to `quant_config` in `vllm_serve_fakequant.py`
-
 ## Important Notes
 
 **Amax Synchronization across Tensor Parallel (TP):**
@@ -85,3 +88,5 @@ torch.distributed.barrier()
 ## Known Problems
 
 1. AWQ is not yet supported in vLLM.
+2. PTQ/QAT checkpoint doesn't work with KV Cache quantization enabled.
+3. Mixed precision checkpoint doesn't work currently.
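
The Step 2 command in the README hunk above passes the amax file and quantization config through environment variables. The following is a minimal sketch of assembling the same invocation from Python. `build_fakequant_launch` is a hypothetical helper that is not part of this commit, and the model path, amax file and config name are placeholders supplied by the caller. Nothing is executed here.

```python
import os


def build_fakequant_launch(model_path, amax_file, quant_cfg, tp_size=8, port=8000):
    """Return (argv, env) mirroring the README command. Hypothetical helper."""
    # Copy the current environment and add the two variables the README sets.
    env = dict(os.environ)
    env["AMAX_FILE_PATH"] = amax_file
    env["QUANT_CFG"] = quant_cfg
    # Same positional argument and flags as the README example.
    argv = [
        "python", "vllm_serve_fakequant.py", model_path,
        "-tp", str(tp_size), "--host", "0.0.0.0", "--port", str(port),
    ]
    return argv, env
```

The returned pair could then be handed to `subprocess.run(argv, env=env, check=True)` from the `examples/vllm_serve` directory.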

examples/vllm_serve/convert_amax_hf2vllm.py

Lines changed: 0 additions & 213 deletions
This file was deleted.

examples/vllm_serve/fakequant_worker.py

Lines changed: 105 additions & 1 deletion
@@ -15,7 +15,9 @@
 
 import dataclasses
 import os
+import re
 import warnings
+from collections import defaultdict
 from contextlib import contextmanager
 from typing import Any
 
@@ -30,6 +32,99 @@
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 
 
+def convert_amax_hf2vllm(
+    hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False
+) -> dict[str, torch.Tensor]:
+    """
+    Convert amax values from HuggingFace format to vLLM format.
+
+    This function merges:
+    - q_proj, k_proj, v_proj amax values into qkv_proj (taking max)
+    - gate_proj, up_proj amax values into gate_up_proj (taking max)
+
+    Args:
+        hf_state_dict: HuggingFace state dict containing amax values
+
+    Returns:
+        vLLM format state dict with merged amax values
+    """
+    vllm_state_dict = {}
+
+    # Group keys by their base pattern (without the specific projection name)
+    merge_groups = defaultdict(list)
+
+    for key, value in hf_state_dict.items():
+        if "_amax" not in key:
+            # Copy non-amax keys as-is
+            vllm_state_dict[key] = value
+            continue
+
+        # Check if this is a q/k/v projection that needs merging
+        qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key)
+        if qkv_match:
+            base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Check if this is an expert gate/up projection
+        # Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and
+        # model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax
+        # Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax
+        expert_gate_up_match = (
+            "mixer" not in key
+            and fuse_experts
+            and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key)
+        )
+        if expert_gate_up_match:
+            base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Check if this is a non-expert gate/up projection that needs merging
+        gate_up_match = (
+            "mixer" not in key
+            and "experts" not in key
+            and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
+        )
+        if gate_up_match:
+            base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Check if this is an expert down_proj
+        # Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax
+        # Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax
+        expert_down_match = (
+            "mixer" not in key
+            and fuse_experts
+            and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key)
+        )
+        if expert_down_match:
+            base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Copy other amax keys as-is (like o_proj, down_proj)
+        vllm_state_dict[key] = value
+
+    # Merge grouped amax values by taking the maximum
+    for merged_key, key_value_pairs in merge_groups.items():
+        if len(key_value_pairs) > 1:
+            # Take the maximum across all values for this merged key
+            values = [value for _, value in key_value_pairs]
+            merged_value = torch.stack(values).max(dim=0)[0]
+            vllm_state_dict[merged_key] = merged_value
+            print(f"Merged {len(key_value_pairs)} keys into {merged_key}")
+            for orig_key, _ in key_value_pairs:
+                print(f"  - {orig_key}")
+        else:
+            # Single key, just rename it
+            _, value = key_value_pairs[0]
+            vllm_state_dict[merged_key] = value
+
+    return vllm_state_dict
+
+
 @contextmanager
 def disable_compilation(model):
     do_not_compile = True
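
To make the merging rule in `convert_amax_hf2vllm` concrete, here is a small torch-free sketch that applies the same q/k/v regular expression to plain floats. `merge_qkv_amax` and the sample keys are illustrative assumptions and are not part of this commit. The real function operates on tensors and also handles the gate/up and expert projections.

```python
import re
from collections import defaultdict

# Same q/k/v pattern as used in convert_amax_hf2vllm.
QKV_PATTERN = re.compile(r"(.*\.)([qkv])_proj(\..+_amax)$")


def merge_qkv_amax(hf_amax):
    """Merge per-projection amax floats into one fused qkv_proj entry (max wins)."""
    merged = {}
    groups = defaultdict(list)
    for key, value in hf_amax.items():
        match = QKV_PATTERN.search(key)
        if match:
            # Rewrite q_proj / k_proj / v_proj to the shared qkv_proj name.
            fused_key = match.group(1) + "qkv_proj" + match.group(3)
            groups[fused_key].append(value)
        else:
            # Keys such as o_proj are copied unchanged.
            merged[key] = value
    for fused_key, values in groups.items():
        merged[fused_key] = max(values)
    return merged


# Hypothetical sample keys, following the naming shown in the diff comments.
sample = {
    "model.layers.0.self_attn.q_proj.input_quantizer._amax": 1.5,
    "model.layers.0.self_attn.k_proj.input_quantizer._amax": 3.0,
    "model.layers.0.self_attn.v_proj.input_quantizer._amax": 2.0,
    "model.layers.0.self_attn.o_proj.input_quantizer._amax": 0.5,
}
result = merge_qkv_amax(sample)
```

Here `result` holds two entries: the fused `qkv_proj` key with the largest of the three values, and the untouched `o_proj` key. Taking the maximum is the choice made in the commit; a plausible reading is that a fused layer shares one quantizer, so its range has to cover every projection merged into it.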
@@ -154,8 +249,17 @@ def calibrate_loop(model: Any = None) -> None:
         if amax_file_path:
             print(f"Loading amax values from {amax_file_path}")
             saved_amax_dict = torch.load(amax_file_path)
-            current_state_dict = model.state_dict()
+            # convert amax keys to vLLM format
+            if hasattr(self.model_runner.model, "hf_to_vllm_mapper"):
+                saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict)
+            saved_amax_dict = {
+                key.replace("quantizer_amax", "quantizer._amax"): value
+                for key, value in saved_amax_dict.items()
+                if key.endswith("quantizer_amax")
+            }
+            saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True)
 
+            current_state_dict = model.state_dict()
             # Count amax keys in checkpoint and model
             checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")]
             model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")]
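
The loading path above first keeps only keys ending in `quantizer_amax`, renames them to the `quantizer._amax` form, and then calls `convert_amax_hf2vllm` with `fuse_experts=True`. The sketch below walks through those two steps without torch. `normalize_amax_keys`, `fused_expert_key` and the sample keys are hypothetical and serve only as illustration. The regular expression is the expert gate/up pattern from the diff.

```python
import re

# Expert gate/up pattern copied from convert_amax_hf2vllm.
EXPERT_GATE_UP = re.compile(
    r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$"
)


def normalize_amax_keys(exported):
    """Keep only exported amax entries and rename them to the `._amax` form."""
    return {
        key.replace("quantizer_amax", "quantizer._amax"): value
        for key, value in exported.items()
        if key.endswith("quantizer_amax")
    }


def fused_expert_key(key):
    """Return the fused w13 key for an expert gate/up amax key, else None."""
    match = EXPERT_GATE_UP.search(key)
    if match is None:
        return None
    return match.group(1) + ".w13_" + match.group(3)


# Hypothetical exported keys; the real checkpoint layout may differ.
exported = {
    "model.layers.0.mlp.experts.0.gate_proj.input_quantizer_amax": 2.0,
    "model.layers.0.mlp.experts.1.up_proj.input_quantizer_amax": 4.0,
    "model.layers.0.mlp.experts.0.gate_proj.weight": "not an amax entry",
}
normalized = normalize_amax_keys(exported)
fused = {fused_expert_key(key) for key in normalized}
```

Both per-expert keys collapse onto the single `model.layers.0.mlp.experts.w13_input_quantizer._amax` name, which is why the commit groups values by the fused key before taking the maximum. The weight entry is dropped by the `endswith` filter, so only amax values reach the conversion step.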
