
Commit f46e41d

Added support for Qwen3-MoE
Signed-off-by: Kinjal Patel <[email protected]>
1 parent b0f78c8 commit f46e41d

3 files changed: +87 additions, −49 deletions

examples/vllm_serve/fakequant_worker.py

Lines changed: 35 additions & 4 deletions
@@ -33,7 +33,7 @@


 def convert_amax_hf2vllm(
-    hf_state_dict: dict[str, torch.Tensor],
+    hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False
 ) -> dict[str, torch.Tensor]:
     """
     Convert amax values from HuggingFace format to vLLM format.
@@ -66,13 +66,44 @@ def convert_amax_hf2vllm(
             merge_groups[base_pattern].append((key, value))
             continue

-        # Check if this is a gate/up projection that needs merging
-        gate_up_match = "mixer" not in key and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
+        # Check if this is an expert gate/up projection
+        # Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and
+        # model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax
+        # Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax
+        expert_gate_up_match = (
+            "mixer" not in key
+            and fuse_experts
+            and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key)
+        )
+        if expert_gate_up_match:
+            base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Check if this is a non-expert gate/up projection that needs merging
+        gate_up_match = (
+            "mixer" not in key
+            and "experts" not in key
+            and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
+        )
         if gate_up_match:
             base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3)
             merge_groups[base_pattern].append((key, value))
             continue

+        # Check if this is an expert down_proj
+        # Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax
+        # Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax
+        expert_down_match = (
+            "mixer" not in key
+            and fuse_experts
+            and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key)
+        )
+        if expert_down_match:
+            base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
         # Copy other amax keys as-is (like o_proj, down_proj)
         vllm_state_dict[key] = value

@@ -226,7 +257,7 @@ def calibrate_loop(model: Any = None) -> None:
        for key, value in saved_amax_dict.items()
        if key.endswith("quantizer_amax")
    }
-   saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict)
+   saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True)

    current_state_dict = model.state_dict()
    # Count amax keys in checkpoint and model
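
For reference, a minimal standalone sketch (not part of the commit) of what the new expert-key handling computes: the two regexes and the `w13_`/`w2_` target names are copied from the diff above, while the helper function and the sample keys are made up for illustration.

```python
import re

# Regexes taken verbatim from convert_amax_hf2vllm above; this helper and the
# sample keys below are illustrative only.
_EXPERT_GATE_UP = re.compile(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$")
_EXPERT_DOWN = re.compile(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$")


def fused_expert_key(hf_key: str):
    """Return the fused vLLM amax key for a per-expert HF key, or None if no match."""
    if m := _EXPERT_GATE_UP.search(hf_key):
        return m.group(1) + ".w13_" + m.group(3)
    if m := _EXPERT_DOWN.search(hf_key):
        return m.group(1) + ".w2_" + m.group(2)
    return None


# Per-expert gate/up keys collapse onto a single w13 entry per layer,
# down_proj keys onto a single w2 entry:
assert (
    fused_expert_key("model.layers.0.mlp.experts.3.gate_proj.input_quantizer._amax")
    == "model.layers.0.mlp.experts.w13_input_quantizer._amax"
)
assert (
    fused_expert_key("model.layers.0.mlp.experts.7.down_proj.input_quantizer._amax")
    == "model.layers.0.mlp.experts.w2_input_quantizer._amax"
)
```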

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 12 additions & 5 deletions
@@ -21,14 +21,21 @@
 import vllm.model_executor.layers.fused_moe.layer as vllm_fused_moe_layer
 import vllm.model_executor.layers.linear as vllm_linear

-try:
-    import vllm.model_executor.layers.fused_moe.shared_fused_moe as vllm_shared_fused_moe_layer
-except ImportError:
-    vllm_shared_fused_moe_layer = None
-
 from ...utils.distributed import ParallelState
 from ..nn import QuantLinearConvBase, QuantModule, QuantModuleRegistry, TensorQuantizer

+# Try multiple import paths for vLLM compatibility across versions
+vllm_shared_fused_moe_layer = None
+for module_path in [
+    "vllm.model_executor.layers.fused_moe.shared_fused_moe",  # 0.11.0+
+    "vllm.model_executor.layers.shared_fused_moe.shared_fused_moe",  # 0.10.2
+]:
+    try:
+        vllm_shared_fused_moe_layer = importlib.import_module(module_path)
+        break
+    except ImportError:
+        continue
+
 vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")

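The loop replaces the earlier single try/except because, per the comments in the diff, the shared fused MoE module lives at a different path in vLLM 0.10.2 than in 0.11.0+; probing both paths keeps the plugin importable across those versions. A small, purely illustrative sketch (not part of the commit) for checking which path an installed vLLM exposes, using only the candidate paths from the diff:

```python
import importlib.util

# Illustrative only: report which shared-fused-MoE module path the installed
# vLLM provides. Candidate paths come from the diff above; everything else is
# an assumption for this sketch.
CANDIDATES = [
    "vllm.model_executor.layers.fused_moe.shared_fused_moe",  # vLLM 0.11.0+
    "vllm.model_executor.layers.shared_fused_moe.shared_fused_moe",  # vLLM 0.10.2
]

resolved = None
for path in CANDIDATES:
    try:
        if importlib.util.find_spec(path) is not None:
            resolved = path
            break
    except ModuleNotFoundError:  # parent package missing, e.g. vLLM not installed
        continue

print(f"shared fused MoE module: {resolved or 'not found'}")
```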
tests/gpu/torch/export/test_vllm_fakequant_export.py

Lines changed: 40 additions & 40 deletions
@@ -13,94 +13,97 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import pytest
-import torch
+import json
 from copy import deepcopy
 from functools import partial
-import modelopt.torch.quantization as mtq
-from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
-from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf
-from _test_utils.torch.transformers_models import create_tiny_llama_dir
+
+import pytest
+import torch
+from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
 from _test_utils.torch.megatron.models import get_mcore_gpt_model
-from _test_utils.import_helper import skip_if_no_megatron
+from _test_utils.torch.transformers_models import create_tiny_llama_dir
 from transformers import AutoModelForCausalLM

-import os
-import json
+import modelopt.torch.quantization as mtq
+from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
+from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf

 skip_if_no_megatron(apex_or_te_required=True)

+
 @pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
 def test_hf_vllm_export(tmp_path, quant_cfg):
     """Test HuggingFace model export for vLLM with fake quantization.
-
+
     This test verifies:
     1. Model weights match before and after export
     2. quant_amax.pth file is created, huggingface config file does not exist
     3. Amax values are correctly extracted and saved in quant_amax.pth file
     """
-
+
     # Create a tiny LLaMA model for testing
     tiny_model_dir = create_tiny_llama_dir(tmp_path, with_tokenizer=True, num_hidden_layers=2)
-
+
     # Load the model
     model = AutoModelForCausalLM.from_pretrained(tiny_model_dir)
     model = model.cuda()
     model.eval()
-
+
     # Quantize the model
     def forward_loop(model):
         input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda()
         with torch.no_grad():
             model(input_ids)
-
+
     model = mtq.quantize(model, quant_cfg, forward_loop)
-
+
     model_state_dict = deepcopy(model.state_dict())

     # Export directory
     export_dir = tmp_path / "vllm_export"
     export_dir.mkdir(exist_ok=True)
-
+
     # Export for vLLM
     export_hf_checkpoint(model, export_dir=export_dir, export_vllm_fq_weights_qstate=True)

     # check if quant_amax.pth file exists
     quant_amax_file = export_dir / "quant_amax.pth"
     assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}"
-
+
     # make sure hf_quant_config.json file does not exist
     hf_quant_config_file = export_dir / "hf_quant_config.json"
-    assert not hf_quant_config_file.exists(), f"hf_quant_config.json file should not be created in {export_dir}"
+    assert not hf_quant_config_file.exists(), (
+        f"hf_quant_config.json file should not be created in {export_dir}"
+    )

     # check weights match before and after export
     model_after = AutoModelForCausalLM.from_pretrained(export_dir)
     model_after = model_after.cuda()
     model_after.eval()
     model_after_state_dict = model_after.state_dict()
     amax_state_dict = {}
-    for key in model_state_dict.keys():
+    for key, param in model_state_dict.items():
         if key.endswith("_amax"):
-            amax_state_dict[key] = model_state_dict[key]
+            amax_state_dict[key] = param
             continue
-
-        assert torch.allclose(model_state_dict[key], model_after_state_dict[key], atol=1e-6), (
+
+        assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), (
             f"Weight mismatch for {key}: "
-            f"before shape={model_state_dict[key].shape}, after shape={model_after_state_dict[key].shape}, "
-            f"max diff={torch.abs(model_state_dict[key] - model_after_state_dict[key]).max()}"
+            f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, "
+            f"max diff={torch.abs(param - model_after_state_dict[key]).max()}"
         )

     # Verify amax values are correct
     amax_dict = torch.load(quant_amax_file)
     assert len(amax_dict) > 0, "amax_dict should not be empty"
-    assert amax_dict.keys() == amax_state_dict.keys(), f"amax keys mismatch between before and after export"
+    assert amax_dict.keys() == amax_state_dict.keys(), (
+        "amax keys mismatch between before and after export"
+    )


 def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size):
-    """Test megatron-core model export for vLLM with fake quantization.
-
-    """
+    """Test megatron-core model export for vLLM with fake quantization."""
     # Create a tiny mcore GPT model
     num_layers = 2
     hidden_size = 64
@@ -109,7 +112,7 @@ def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size):
     ffn_hidden_size = 128
     max_sequence_length = 32
     vocab_size = 64
-
+
     model = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         pipeline_model_parallel_size=1,
@@ -126,7 +129,7 @@ def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size):
         transformer_impl="modelopt",
     ).cuda()
     model.eval()
-
+
     # Quantize the model
     def forward_loop(model):
         batch_size = 1
@@ -138,11 +141,8 @@ def forward_loop(model):
         attention_mask = attention_mask < 0.5 # Convert to boolean mask
         with torch.no_grad():
             model(input_ids, position_ids, attention_mask)
-
-    model = mtq.quantize(model, quant_cfg, forward_loop)
-
-    model_state_dict = deepcopy(model.state_dict())

+    model = mtq.quantize(model, quant_cfg, forward_loop)
     # Create HF config for export
     pretrained_config = {
         "architectures": ["LlamaForCausalLM"],
@@ -156,14 +156,14 @@ def forward_loop(model):
         "num_key_value_heads": num_query_groups,
         "torch_dtype": "bfloat16",
     }
-
+
     with open(tmp_path / "config.json", "w") as f:
         json.dump(pretrained_config, f)

     # Export directory
     export_dir = tmp_path / "vllm_export"
     export_dir.mkdir(exist_ok=True)
-
+
     # Export for vLLM
     export_mcore_gpt_to_hf(
         model,
@@ -176,10 +176,12 @@ def forward_loop(model):
     # check if quant_amax.pth file exists
     quant_amax_file = export_dir / "quant_amax.pth"
     assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}"
-
+
     # make sure hf_quant_config.json file does not exist
     hf_quant_config_file = export_dir / "hf_quant_config.json"
-    assert not hf_quant_config_file.exists(), f"hf_quant_config.json file should not be created in {export_dir}"
+    assert not hf_quant_config_file.exists(), (
+        f"hf_quant_config.json file should not be created in {export_dir}"
+    )


 @pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
@@ -190,5 +192,3 @@ def test_mcore_vllm_export(tmp_path, quant_cfg):
         job=partial(_test_mcore_vllm_export, tmp_path, quant_cfg),
         backend="nccl",
     )
-
-
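
One possible way to exercise the updated tests (not part of the commit): run them from the repository root in an environment with a CUDA GPU and the Megatron/TE dependencies that the file checks for via `skip_if_no_megatron`.

```python
# Illustrative only: invoke the updated export tests programmatically.
# Assumes the working directory is the repository root and GPU deps are present.
import pytest

pytest.main(
    [
        "tests/gpu/torch/export/test_vllm_fakequant_export.py",
        "-k",
        "test_hf_vllm_export or test_mcore_vllm_export",
        "-v",
    ]
)
```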
