Support W4A8 method of AngelSlim tool #6857

Open
wants to merge 1 commit into base: main
80 changes: 79 additions & 1 deletion tensorrt_llm/_torch/model_config.py
@@ -15,7 +15,7 @@
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.quantization.mode import QuantAlgo, ActivationScheme

TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)

@@ -238,6 +238,47 @@ def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend):
]
return quant_config, layer_quant_config

@staticmethod
def load_angelslim_quant_config(quant_config_file, model_dir, moe_backend):
quant_config = QuantConfig()
layer_quant_config = None

with open(quant_config_file) as f:
quant_config_dict = json.load(f)

json_quant_configs = quant_config_dict['quantization']

quant_config.quant_algo = QuantAlgo(
json_quant_configs.get('quant_algo', None).upper()) if json_quant_configs.get("quant_algo") else None
# fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
if quant_config.quant_algo == "fp8_pb_wo":
quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES')
Comment on lines +251 to +255
⚠️ Potential issue

Logical issue: string comparison against enum value.

Line 254 compares quant_config.quant_algo (now a QuantAlgo enum member constructed from the uppercased string) against the lowercase string "fp8_pb_wo". This check can never match, so the FP8_BLOCK_SCALES remapping is dead code.

Fix the comparison:

-    quant_config.quant_algo = QuantAlgo(
-        json_quant_configs.get('quant_algo', None).upper()) if json_quant_configs.get("quant_algo") else None
-    # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
-    if quant_config.quant_algo == "fp8_pb_wo":
-        quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES')
+    algo_str = json_quant_configs.get('quant_algo')
+    if algo_str:
+        algo_str = algo_str.upper()
+        # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
+        if algo_str == "FP8_PB_WO":
+            quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES')
+        else:
+            quant_config.quant_algo = QuantAlgo(algo_str)
+    else:
+        quant_config.quant_algo = None
🤖 Prompt for AI Agents
tensorrt_llm/_torch/model_config.py around lines 251 to 255: the code currently
compares the QuantAlgo enum instance to the string "fp8_pb_wo", which always
fails; change the check to compare the enum's name or value to that string (for
example: if quant_config.quant_algo.name.lower() == "fp8_pb_wo") and then set
quant_config.quant_algo = QuantAlgo("FP8_BLOCK_SCALES") as before.


quant_config.kv_cache_quant_algo = QuantAlgo(
json_quant_configs.get("kv_cache_quant_algo").upper()
) if json_quant_configs.get("kv_cache_quant_algo") else None
quant_config.group_size = json_quant_configs.get('group_size', None)
quant_config.exclude_modules = json_quant_configs.get(
'exclude_modules', None)
quant_config.activation_scheme = ActivationScheme(
json_quant_configs.get('activation_scheme', None).upper()
) if json_quant_configs.get("activation_scheme") else None

json_exclude_quant_configs = json_quant_configs.get('exclude_quantization', None)
if json_exclude_quant_configs:
quant_config.exclude_quant_config = {
"quant_algo": QuantAlgo(
json_exclude_quant_configs.get('quant_algo', None).upper()
) if json_exclude_quant_configs.get("quant_algo") else None,
"kv_cache_quant_algo": QuantAlgo(
json_exclude_quant_configs.get("kv_cache_quant_algo").upper()
) if json_exclude_quant_configs.get("kv_cache_quant_algo") else None,
"activation_scheme": ActivationScheme(
json_exclude_quant_configs.get('activation_scheme', None).upper()
) if json_exclude_quant_configs.get("activation_scheme") else None,
}
return quant_config, layer_quant_config

Comment on lines +241 to +281
⚠️ Potential issue

Add error handling for malformed JSON config files.

The new load_angelslim_quant_config method assumes the JSON file is well-formed and contains the expected structure. Add error handling for cases where required keys might be missing or the JSON is malformed.

Consider adding validation like this:

 def load_angelslim_quant_config(quant_config_file, model_dir, moe_backend):
     quant_config = QuantConfig()
     layer_quant_config = None
 
-    with open(quant_config_file) as f:
-        quant_config_dict = json.load(f)
+    try:
+        with open(quant_config_file) as f:
+            quant_config_dict = json.load(f)
+    except (json.JSONDecodeError, IOError) as e:
+        raise ValueError(f"Failed to load angelslim config from {quant_config_file}: {e}")
 
-    json_quant_configs = quant_config_dict['quantization']
+    json_quant_configs = quant_config_dict.get('quantization', {})
+    if not json_quant_configs:
+        raise ValueError(f"Missing 'quantization' section in {quant_config_file}")
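If the suggested error handling is adopted, a rough unit-test sketch along these lines could exercise the failure path (the ModelConfig owner class and the pytest setup are assumptions based on the surrounding diff, not part of this PR):

# Hypothetical test sketch; assumes the ValueError-raising handling above is merged.
import json
import pytest
from tensorrt_llm._torch.model_config import ModelConfig

def test_angelslim_config_missing_quantization_section(tmp_path):
    # Valid JSON, but without the required 'quantization' section.
    cfg = tmp_path / "angelslim_hf_quant_config.json"
    cfg.write_text(json.dumps({}))
    with pytest.raises(ValueError, match="quantization"):
        ModelConfig.load_angelslim_quant_config(cfg, tmp_path, moe_backend=None)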

@staticmethod
def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False):
quant_algo = ModelConfig.override_quant_algo()
@@ -282,6 +323,40 @@ def load_hf_quant_config(hf_quant_config, moe_backend):
'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
'embedding', 'unembedding'
]
elif hf_quant_config.get("quant_method") == "w4a8_awq":
quant_config.quant_algo = QuantAlgo.W4A8_AWQ
else:
raise NotImplementedError(f"Unsupported quantization_config: {hf_quant_config}.")

# set kv_cache_quant_algo
quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \
if hf_quant_config.get("kv_cache_quant_method") else None
# set activation_scheme
quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \
if hf_quant_config.get("activation_scheme") else None
# set exclude_modules
if quant_config.exclude_modules:
if hf_quant_config.get("ignored_modules"):
quant_config.exclude_modules += hf_quant_config.get("ignored_modules")
else:
quant_config.exclude_modules = hf_quant_config.get("ignored_modules")

# set exclude_quant_config
hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config")
if hf_ignored_quantization_config:
quant_config.exclude_quant_config = {
"quant_algo": QuantAlgo(
hf_ignored_quantization_config.get("quant_method").upper()
) if hf_ignored_quantization_config.get("quant_method") else None,
"kv_cache_quant_algo": QuantAlgo(
hf_ignored_quantization_config.get("kv_cache_quant_method").upper()
) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None,
"activation_scheme": ActivationScheme(
hf_ignored_quantization_config.get("activation_scheme").upper()
) if hf_ignored_quantization_config.get("activation_scheme") else None,
}

logger.info(f"Load quantization config from pretrained config, quant_config: {quant_config}")

return quant_config, layer_quant_config

@@ -360,6 +435,9 @@ def from_pretrained(cls,
if (quant_config_file := model_dir / 'hf_quant_config.json').exists():
quant_config, layer_quant_config = cls.load_modelopt_quant_config(
quant_config_file, model_dir, moe_backend)
elif (quant_config_file := model_dir / 'angelslim_hf_quant_config.json').exists():
quant_config, layer_quant_config = cls.load_angelslim_quant_config(
quant_config_file, model_dir, moe_backend)
Comment on lines +438 to +440
💡 Verification agent

🧩 Analysis chain

Verify the consistency of the naming convention for config files.

The code now supports both angelslim_hf_quant_config.json and hf_quant_config.json. Ensure this naming is consistent with the tool's documentation and expected usage.


🏁 Script executed:

#!/bin/bash
# Check for any existing references to these config file names in the codebase
echo "Searching for references to angelslim_hf_quant_config.json:"
rg "angelslim_hf_quant_config" --type py -A 2 -B 2

echo -e "\nSearching for references to hf_quant_config.json:"
rg "hf_quant_config\.json" --type py -A 2 -B 2

echo -e "\nChecking for any documentation about these config files:"
fd -e md -e rst | xargs rg -l "angelslim_hf_quant_config|hf_quant_config"

Length of output: 8264


Update documentation to include the new AngelSlim config filename

The code now recognizes both hf_quant_config.json and angelslim_hf_quant_config.json, but our docs and examples only mention the former. Please update all references so users know both are supported:

• docs/source/performance/perf-benchmarking.md
• examples/models/core/deepseek_v3/README.md

– Add a note explaining that if an angelslim_hf_quant_config.json is present in the model directory, it will be loaded via load_angelslim_quant_config()
– Ensure any code snippets or CLI examples show both filenames where applicable

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/model_config.py around lines 438 to 440, the code now
accepts an additional AngelSlim config filename
(angelslim_hf_quant_config.json); update the documentation and example README to
mention both hf_quant_config.json and angelslim_hf_quant_config.json. Edit
docs/source/performance/perf-benchmarking.md and
examples/models/core/deepseek_v3/README.md to add a short note that if
angelslim_hf_quant_config.json exists in the model directory it will be loaded
via load_angelslim_quant_config(), and update any code snippets or CLI examples
to show both filenames where applicable (e.g., list both filenames in examples
and usage text).
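For the documentation update, an illustrative angelslim_hf_quant_config.json matching the keys that load_angelslim_quant_config() reads could look like the following (a hypothetical example written as a small Python snippet; all field values are made up, not taken from a real AngelSlim export):

# Hypothetical example config; keys mirror what the loader parses, values are illustrative only.
import json
from pathlib import Path

example_config = {
    "quantization": {
        "quant_algo": "w4a8_awq",
        "kv_cache_quant_algo": None,
        "group_size": 128,
        "exclude_modules": ["lm_head"],
        "activation_scheme": "static",
        # Optional: per the loader, excluded modules may carry their own quantization settings.
        "exclude_quantization": {
            "quant_algo": "fp8",
            "kv_cache_quant_algo": None,
            "activation_scheme": "dynamic",
        },
    }
}
Path("angelslim_hf_quant_config.json").write_text(json.dumps(example_config, indent=2))

Either the raw JSON or a snippet like this could be dropped into the two docs listed above.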

# quantized ckpt in other formats
elif hasattr(pretrained_config, "quantization_config"):
hf_quant_config = pretrained_config.quantization_config
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -1360,7 +1360,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
if names[-1] == "kv_b_proj":
# TODO: remove weight_dequant after enabling fp8_bmm
dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization(
names[-1])
names[-1]) and self.model_config.quant_config.exclude_quant_config is None
if dequant_kv_b_proj:
kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant(
name)
9 changes: 8 additions & 1 deletion tensorrt_llm/_torch/models/modeling_utils.py
@@ -477,7 +477,14 @@ def apply_quant_config_exclude_modules(self):
kv_cache_quant_algo = None
if quant_config:
kv_cache_quant_algo = quant_config.kv_cache_quant_algo
new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
quant_algo = None
activation_scheme = None
exclude_quant_config = quant_config.exclude_quant_config
if exclude_quant_config:
quant_algo = exclude_quant_config.get("quant_algo", None)
activation_scheme = exclude_quant_config.get("activation_scheme", None)
new_config = QuantConfig(
quant_algo=quant_algo, kv_cache_quant_algo=kv_cache_quant_algo, activation_scheme=activation_scheme)

if quant_config is not None:
if quant_config.exclude_modules is not None:
50 changes: 36 additions & 14 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -206,17 +206,18 @@ def load_expert_weights_to_dst(
load_expert_ids: List[int], dst_w3_w1_weights_tensor: torch.Tensor,
dst_w2_weights_tensor: torch.Tensor,
dst_w3_w1_bias_tensor: Optional[torch.Tensor],
dst_w2_bias_tensor: Optional[torch.Tensor]):
dst_w2_bias_tensor: Optional[torch.Tensor],
weight_name: str = "weight"):
# Multithread weight load is superseded by prefetch_files() in model_engine.py
# Also, threading adds overhead in order to protect shuffle index cache with critical section.
for local_slot_id, expert_id in enumerate(load_expert_ids):
# expert_idx is the local slot index of current rank
expert_idx = local_slot_id

if weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_weight = weights[f"{expert_id}.w1.weight"]
w3_weight = weights[f"{expert_id}.w3.weight"]
w2_weight = weights[f"{expert_id}.w2.weight"]
w1_weight = weights[f"{expert_id}.w1.{weight_name}"]
w3_weight = weights[f"{expert_id}.w3.{weight_name}"]
w2_weight = weights[f"{expert_id}.w2.{weight_name}"]
if module.bias:
w1_bias = weights[f"{expert_id}.w1.bias"]
w3_bias = weights[f"{expert_id}.w3.bias"]
@@ -251,14 +252,16 @@ def load_expert_weights_to_dst(
dst_w2_bias_tensor.data[expert_idx])

def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
weight_loading_mode: MoEWeightLoadingMode,
weight_name: str = "weight"):

self.load_expert_weights_to_dst(
module, weights, weight_loading_mode,
module.initial_local_expert_ids, module.w3_w1_weight.data,
module.w2_weight.data,
module.w3_w1_bias.data if module.bias else None,
module.w2_bias.data if module.bias else None)
module.w2_bias.data if module.bias else None,
weight_name)

self.load_quant_scales(module, weights)
# Re-setup quant scales after loading weights as the tensors may have been modified.
@@ -953,6 +956,11 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
non_blocking=True)

def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode,
weight_name: str = "qweight"):
super().load_weights(module, weights, weight_loading_mode, weight_name)

def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
assert self.device.type == "cuda"

@@ -974,7 +982,13 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
module.fc31_act_scale.data.copy_(
torch.ones_like(module.fc31_act_scale) *
(1 / all_w3_w1_input_scales_max))
module.fc31_alpha.data.copy_((torch.ones_like(module.fc31_alpha) *
all_w3_w1_scales_fp8_max = []
for expert_id in module.initial_local_expert_ids:
w1_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w1.weight_scale"])
w3_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w3.weight_scale"])
all_w3_w1_scales_fp8_max.append(torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8))
all_w3_w1_scales_fp8_max = torch.stack(all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape)
module.fc31_alpha.data.copy_((all_w3_w1_scales_fp8_max *
all_w3_w1_input_scales_max).float())

all_w3_scales = [
@@ -985,17 +999,19 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
all_w1_scales = [
load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"],
load_weight_shard(weights[f"{expert_id}.w1.weight_scale.int4"],
⚠️ Potential issue

Inconsistent weight scale key naming: weight_scale.int4 vs weight_scale_inv.

Line 1004 uses f"{expert_id}.w1.weight_scale.int4" while line 995 uses f"{expert_id}.w3.weight_scale_inv". This inconsistency could indicate different weight formats or be a bug.

Please verify if this naming difference is intentional:

#!/bin/bash
# Check for both naming patterns in the codebase
echo "Searching for weight_scale.int4 pattern:"
rg "weight_scale\.int4" --type py -B 2 -A 2

echo -e "\nSearching for weight_scale_inv pattern:"
rg "weight_scale_inv" --type py -B 2 -A 2
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/quantization.py around line 1004 there
is an inconsistent key name: line 1004 uses f"{expert_id}.w1.weight_scale.int4"
while earlier (line ~995) uses f"{expert_id}.w3.weight_scale_inv"; confirm
whether the correct stored key is weight_scale_inv or weight_scale.int4 by
searching the repo for both patterns, then make the keys consistent (prefer
using the canonical weight_scale_inv if other shards use that naming), update
the load_weight_shard call to use the canonical key across all shards, and add a
brief inline comment explaining the chosen convention so future readers know
which format is expected.
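If both spellings turn out to exist across checkpoints, one defensive option (a sketch only; the helper name and fallback order are assumptions, not part of this PR) is to try the known key variants before failing:

# Hypothetical helper: resolve the per-expert weight scale, trying both key spellings.
from typing import Dict
import torch

def get_expert_weight_scale(weights: Dict[str, torch.Tensor], expert_id: int, proj: str) -> torch.Tensor:
    """Return the weight scale tensor for one expert/projection, trying both key spellings."""
    for suffix in ("weight_scale.int4", "weight_scale_inv"):
        key = f"{expert_id}.{proj}.{suffix}"
        if key in weights:
            return weights[key]
    raise KeyError(f"No weight scale found for expert {expert_id}, projection '{proj}'")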

module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
all_w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)], dim=-2)
[all_w3_scales,
all_w1_scales], dim=-2)
Comment on lines +1002 to +1014
⚠️ Potential issue

Potential division by zero when scales are zero.

The code divides by all_w3_w1_scales_fp8_max and all_w2_scales_fp8 without checking if these values are zero. This could lead to runtime errors or inf/nan values.

Add validation for zero scales:

 all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
+if torch.any(all_w3_w1_scales_fp8_max == 0):
+    raise ValueError("Found zero FP8 weight scales, which would cause division by zero")
 
 all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
 
 all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2)
+if torch.any(all_w2_scales_fp8 == 0):
+    raise ValueError("Found zero FP8 weight scales for w2, which would cause division by zero")

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/quantization.py around lines 1002 to
1014, the code divides tensors by all_w3_w1_scales_fp8_max and all_w2_scales_fp8
without guarding against zeros; add a defensive check to avoid divide-by-zero by
replacing zero (or near-zero) elements with a small safe epsilon (e.g., 1e-6) or
use torch.clamp_min / torch.where to ensure the denominators are >= epsilon
before performing the divisions, and optionally log or assert if any
replacements occurred to aid debugging.
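A minimal standalone sketch of the clamp-based variant (the epsilon value and tensor shapes are stand-ins, not the module's real ones):

# Stand-in tensors; demonstrates clamping the denominator before the flagged divisions.
import torch

eps = 1e-6  # assumed floor; choose a value appropriate for the FP8 scale range in use
scales_fp8_max = torch.tensor([[0.5], [0.0], [0.25]])  # per-expert FP8 maxima, one zero on purpose
block_scales = torch.rand(3, 4, 8)                     # per-expert block scales

safe_max = torch.clamp_min(scales_fp8_max, eps)        # keep the denominator strictly positive
normalized = block_scales / safe_max.unsqueeze(2)      # mirrors the division flagged above
assert torch.isfinite(normalized).all()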

if module.sm_version == 89:
w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype)
else:
@@ -1023,22 +1039,28 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
module.fc2_act_scale.data.copy_(
torch.ones_like(module.fc2_act_scale) *
(1 / all_w2_input_scales_max))
module.fc2_alpha.data.copy_((torch.ones_like(module.fc2_alpha) *
all_w2_scales_fp8 = [
load_weight_shard(weights[f"{expert_id}.w2.weight_scale"])
for expert_id in module.initial_local_expert_ids
]
all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(module.fc2_alpha.shape)
module.fc2_alpha.data.copy_((all_w2_scales_fp8 *
all_w2_input_scales_max).float())

all_w2_scales = [
load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],
load_weight_shard(weights[f"{expert_id}.w2.weight_scale.int4"],
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2)
if module.sm_version == 89:
w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(
w2_scales = all_w2_scales.to(torch.float16).view(
module.dtype)
else:
w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view(
w2_scales = all_w2_scales.to(torch.bfloat16).view(
module.dtype)
w2_s_shape = w2_scales.shape
w2_scales_interleaved = w2_scales.reshape(
35 changes: 34 additions & 1 deletion tensorrt_llm/llmapi/llm_utils.py
@@ -25,7 +25,7 @@
from ..logger import logger
from ..mapping import Mapping
from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM
from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig
from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig, ActivationScheme
from ..module import Module
from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
get_build_cache_config_from_env)
@@ -435,10 +435,43 @@ def _update_from_hf_quant_config(self) -> bool:
'block.*.attn.out', 'block.*.mlp.gate',
'block.*.attn.qkv', 'embedding', 'unembedding'
]
elif hf_quant_config.get("quant_method") == "w4a8_awq":
quant_config.quant_algo = QuantAlgo.W4A8_AWQ
else:
raise NotImplementedError(
f"Unsupported quantization_config: {hf_quant_config}.")

# set kv_cache_quant_algo
quant_config.kv_cache_quant_algo = QuantAlgo(
hf_quant_config.get("kv_cache_quant_method").upper()
) if hf_quant_config.get("kv_cache_quant_method") else None
# set activation_scheme
quant_config.activation_scheme = ActivationScheme(
hf_quant_config.get("activation_scheme").upper()
) if hf_quant_config.get("activation_scheme") else None
# set exclude_modules
if quant_config.exclude_modules:
if hf_quant_config.get("ignored_modules"):
quant_config.exclude_modules += hf_quant_config.get("ignored_modules")
else:
quant_config.exclude_modules = hf_quant_config.get("ignored_modules")
# set exclude_quant_config
hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config")
if hf_ignored_quantization_config:
quant_config.exclude_quant_config = {
"quant_algo": QuantAlgo(
hf_ignored_quantization_config.get("quant_method").upper()
) if hf_ignored_quantization_config.get("quant_method") else None,
"kv_cache_quant_algo": QuantAlgo(
hf_ignored_quantization_config.get("kv_cache_quant_method").upper()
) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None,
"activation_scheme": ActivationScheme(
hf_ignored_quantization_config.get("activation_scheme").upper()
) if hf_ignored_quantization_config.get("activation_scheme") else None,
}
logger.info(
f"Detected quantization_config: {quant_config}."
)
return True

return False
6 changes: 5 additions & 1 deletion tensorrt_llm/models/modeling_utils.py
@@ -45,7 +45,7 @@
WeightOnlyQuantLinear,
WeightOnlyQuantRowLinear)
from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST,
W8A8_SQ_PLUGIN_LIST, QuantAlgo)
W8A8_SQ_PLUGIN_LIST, QuantAlgo, ActivationScheme)
from ..quantization.utils import fp4_utils
from ..top_model_mixin import TopModelMixin
from .convert_utils import weight_only_quantize_dict
@@ -140,6 +140,8 @@ class QuantConfig:
pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False.
exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None.
mamba_ssm_cache_dtype (str, optional): The data type for mamba SSM cache. Defaults to None.
exclude_quant_config (Dict, optional): Quantization settings applied to the modules matched by exclude_modules. Defaults to None.
activation_scheme (tensorrt_llm.quantization.mode.ActivationScheme, optional): The activation quantization scheme (static or dynamic). Defaults to None.
"""
quant_algo: Optional[QuantAlgo] = None
kv_cache_quant_algo: Optional[QuantAlgo] = None
@@ -151,6 +153,8 @@ class QuantConfig:
pre_quant_scale: bool = False
exclude_modules: Optional[List[str]] = None
mamba_ssm_cache_dtype: Optional[str] = None
exclude_quant_config: Optional[Dict] = None
activation_scheme: Optional[ActivationScheme] = None

@cached_property
def quant_mode(self) -> QuantModeWrapper:
5 changes: 5 additions & 0 deletions tensorrt_llm/quantization/mode.py
@@ -458,3 +458,8 @@ class GroupwiseQuantAlgo:
PRE_QUANT_SCALE = 4
W4A8_ALPHA = 8
INT8_WEIGHT = 16


class ActivationScheme(StrEnum, metaclass=BaseEnumMeta):
STATIC = auto()
DYNAMIC = auto()
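A quick sanity-check sketch of how the loaders above construct the new enum from config strings (assuming auto() yields the uppercase member names, matching the other enums in this module):

# Sketch only: mirrors the .upper() + constructor pattern used by the config loaders.
from tensorrt_llm.quantization.mode import ActivationScheme

raw = {"activation_scheme": "dynamic"}  # hypothetical entry from a quant config file
scheme = ActivationScheme(raw["activation_scheme"].upper()) if raw.get("activation_scheme") else None
assert scheme is ActivationScheme.DYNAMIC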