-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Support W4A8 method of AngleSlim tool #6857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,7 +15,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm.logger import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm.mapping import Mapping | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm.models.modeling_utils import QuantConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm.quantization.mode import QuantAlgo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm.quantization.mode import QuantAlgo, ActivationScheme | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -238,6 +238,47 @@ def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return quant_config, layer_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def load_angelslim_quant_config(quant_config_file, model_dir, moe_backend): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config = QuantConfig() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
layer_quant_config = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with open(quant_config_file) as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config_dict = json.load(f) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
json_quant_configs = quant_config_dict['quantization'] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.quant_algo = QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
json_quant_configs.get('quant_algo', None).upper()) if json_quant_configs.get("quant_algo") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if quant_config.quant_algo == "fp8_pb_wo": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.kv_cache_quant_algo = QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
json_quant_configs.get("kv_cache_quant_algo").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) if json_quant_configs.get("kv_cache_quant_algo") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.group_size = json_quant_configs.get('group_size', None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.exclude_modules = json_quant_configs.get( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'exclude_modules', None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.activation_scheme = ActivationScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
json_quant_configs.get('activation_scheme', None).upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) if json_quant_configs.get("activation_scheme") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
json_exclude_quant_configs = json_quant_configs.get('exclude_quantization', None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if json_exclude_quant_configs: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.exclude_quant_config = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
json_exclude_quant_configs.get('quant_algo', None).upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) if json_exclude_quant_configs.get("quant_algo") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"kv_cache_quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
json_exclude_quant_configs.get("kv_cache_quant_algo").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) if json_exclude_quant_configs.get("kv_cache_quant_algo") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"activation_scheme": ActivationScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
json_exclude_quant_configs.get('activation_scheme', None).upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) if json_exclude_quant_configs.get("activation_scheme") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return quant_config, layer_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+241
to
+281
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling for malformed JSON config files. The new Consider adding validation like this: def load_angelslim_quant_config(quant_config_file, model_dir, moe_backend):
quant_config = QuantConfig()
layer_quant_config = None
- with open(quant_config_file) as f:
- quant_config_dict = json.load(f)
+ try:
+ with open(quant_config_file) as f:
+ quant_config_dict = json.load(f)
+ except (json.JSONDecodeError, IOError) as e:
+ raise ValueError(f"Failed to load angelslim config from {quant_config_file}: {e}")
- json_quant_configs = quant_config_dict['quantization']
+ json_quant_configs = quant_config_dict.get('quantization', {})
+ if not json_quant_configs:
+ raise ValueError(f"Missing 'quantization' section in {quant_config_file}") 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_algo = ModelConfig.override_quant_algo() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -282,6 +323,40 @@ def load_hf_quant_config(hf_quant_config, moe_backend): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv', | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
'embedding', 'unembedding' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif hf_quant_config.get("quant_method") == "w4a8_awq": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.quant_algo = QuantAlgo.W4A8_AWQ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError(f"Unsupported quantization_config: {hf_quant_config}.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# set kv_cache_quant_algo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if hf_quant_config.get("kv_cache_quant_method") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# set activation_scheme | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if hf_quant_config.get("activation_scheme") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# set exclude_modules | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if quant_config.exclude_modules: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if hf_quant_config.get("ignored_modules"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.exclude_modules += hf_quant_config.get("ignored_modules") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.exclude_modules = hf_quant_config.get("ignored_modules") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# set exclude_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if hf_ignored_quantization_config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config.exclude_quant_config = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
hf_ignored_quantization_config.get("quant_method").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) if hf_ignored_quantization_config.get("quant_method") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"kv_cache_quant_algo": QuantAlgo( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
hf_ignored_quantization_config.get("kv_cache_quant_method").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"activation_scheme": ActivationScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
hf_ignored_quantization_config.get("activation_scheme").upper() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) if hf_ignored_quantization_config.get("activation_scheme") else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.info(f"Load quantization config from pretrained config, quant_config: {quant_config}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return quant_config, layer_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -360,6 +435,9 @@ def from_pretrained(cls, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (quant_config_file := model_dir / 'hf_quant_config.json').exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config, layer_quant_config = cls.load_modelopt_quant_config( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config_file, model_dir, moe_backend) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif (quant_config_file := model_dir / 'angelslim_hf_quant_config.json').exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config, layer_quant_config = cls.load_angelslim_quant_config( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_config_file, model_dir, moe_backend) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+438
to
+440
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainVerify the consistency of the naming convention for config files. The code now supports both 🏁 Script executed: #!/bin/bash
# Check for any existing references to these config file names in the codebase
echo "Searching for references to angelslim_hf_quant_config.json:"
rg "angelslim_hf_quant_config" --type py -A 2 -B 2
echo -e "\nSearching for references to hf_quant_config.json:"
rg "hf_quant_config\.json" --type py -A 2 -B 2
echo -e "\nChecking for any documentation about these config files:"
fd -e md -e rst | xargs rg -l "angelslim_hf_quant_config|hf_quant_config" Length of output: 8264 Update documentation to include the new AngelsLim config filename The code now recognizes both • docs/source/performance/perf-benchmarking.md – Add a note explaining that if an 🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# quantized ckpt in other formats | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif hasattr(pretrained_config, "quantization_config"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
hf_quant_config = pretrained_config.quantization_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -206,17 +206,18 @@ def load_expert_weights_to_dst( | |
load_expert_ids: List[int], dst_w3_w1_weights_tensor: torch.Tensor, | ||
dst_w2_weights_tensor: torch.Tensor, | ||
dst_w3_w1_bias_tensor: Optional[torch.Tensor], | ||
dst_w2_bias_tensor: Optional[torch.Tensor]): | ||
dst_w2_bias_tensor: Optional[torch.Tensor], | ||
weight_name: str = "weight"): | ||
# Multithread weight load is superseded by prefetch_files() in model_engine.py | ||
# Also, threading adds overhead in order to protect shuffle index cache with critical section. | ||
for local_slot_id, expert_id in enumerate(load_expert_ids): | ||
# expert_idx is the local slot index of current rank | ||
expert_idx = local_slot_id | ||
|
||
if weight_loading_mode == MoEWeightLoadingMode.VANILLA: | ||
w1_weight = weights[f"{expert_id}.w1.weight"] | ||
w3_weight = weights[f"{expert_id}.w3.weight"] | ||
w2_weight = weights[f"{expert_id}.w2.weight"] | ||
w1_weight = weights[f"{expert_id}.w1.{weight_name}"] | ||
w3_weight = weights[f"{expert_id}.w3.{weight_name}"] | ||
w2_weight = weights[f"{expert_id}.w2.{weight_name}"] | ||
if module.bias: | ||
w1_bias = weights[f"{expert_id}.w1.bias"] | ||
w3_bias = weights[f"{expert_id}.w3.bias"] | ||
|
@@ -251,14 +252,16 @@ def load_expert_weights_to_dst( | |
dst_w2_bias_tensor.data[expert_idx]) | ||
|
||
def load_weights(self, module: torch.nn.Module, weights: List[Dict], | ||
weight_loading_mode: MoEWeightLoadingMode): | ||
weight_loading_mode: MoEWeightLoadingMode, | ||
weight_name: str = "weight"): | ||
|
||
self.load_expert_weights_to_dst( | ||
module, weights, weight_loading_mode, | ||
module.initial_local_expert_ids, module.w3_w1_weight.data, | ||
module.w2_weight.data, | ||
module.w3_w1_bias.data if module.bias else None, | ||
module.w2_bias.data if module.bias else None) | ||
module.w2_bias.data if module.bias else None, | ||
weight_name) | ||
|
||
self.load_quant_scales(module, weights) | ||
# Re-setup quant scales after loading weights as the tensors may have been modified. | ||
|
@@ -953,6 +956,11 @@ def load_expert_w2_weight(self, module: torch.nn.Module, | |
dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), | ||
non_blocking=True) | ||
|
||
def load_weights(self, module: torch.nn.Module, weights: List[Dict], | ||
weight_loading_mode: MoEWeightLoadingMode, | ||
weight_name: str = "qweight"): | ||
super().load_weights(module, weights, weight_loading_mode, weight_name) | ||
|
||
def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | ||
assert self.device.type == "cuda" | ||
|
||
|
@@ -974,7 +982,13 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
module.fc31_act_scale.data.copy_( | ||
torch.ones_like(module.fc31_act_scale) * | ||
(1 / all_w3_w1_input_scales_max)) | ||
module.fc31_alpha.data.copy_((torch.ones_like(module.fc31_alpha) * | ||
all_w3_w1_scales_fp8_max = [] | ||
for expert_id in module.initial_local_expert_ids: | ||
w1_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w1.weight_scale"]) | ||
w3_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w3.weight_scale"]) | ||
all_w3_w1_scales_fp8_max.append(torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8)) | ||
all_w3_w1_scales_fp8_max = torch.stack(all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape) | ||
module.fc31_alpha.data.copy_((all_w3_w1_scales_fp8_max * | ||
all_w3_w1_input_scales_max).float()) | ||
|
||
all_w3_scales = [ | ||
|
@@ -985,17 +999,19 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
device=self.device) | ||
for expert_id in module.initial_local_expert_ids | ||
] | ||
all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) | ||
all_w1_scales = [ | ||
load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"], | ||
load_weight_shard(weights[f"{expert_id}.w1.weight_scale.int4"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inconsistent weight scale key naming: weight_scale.int4 vs weight_scale_inv. Line 1004 uses Please verify if this naming difference is intentional: #!/bin/bash
# Check for both naming patterns in the codebase
echo "Searching for weight_scale.int4 pattern:"
rg "weight_scale\.int4" --type py -B 2 -A 2
echo -e "\nSearching for weight_scale_inv pattern:"
rg "weight_scale_inv" --type py -B 2 -A 2 🤖 Prompt for AI Agents
|
||
module.tp_size, | ||
module.tp_rank, | ||
TensorParallelMode.COLUMN, | ||
device=self.device) | ||
for expert_id in module.initial_local_expert_ids | ||
] | ||
all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) | ||
all_w3_w1_scales = torch.cat( | ||
[torch.stack(all_w3_scales), | ||
torch.stack(all_w1_scales)], dim=-2) | ||
[all_w3_scales, | ||
all_w1_scales], dim=-2) | ||
Comment on lines
+1002
to
+1014
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential division by zero when scales are zero. The code divides by Add validation for zero scales: all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
+if torch.any(all_w3_w1_scales_fp8_max == 0):
+ raise ValueError("Found zero FP8 weight scales, which would cause division by zero")
all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2)
+if torch.any(all_w2_scales_fp8 == 0):
+ raise ValueError("Found zero FP8 weight scales for w2, which would cause division by zero")
🤖 Prompt for AI Agents
|
||
if module.sm_version == 89: | ||
w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype) | ||
else: | ||
|
@@ -1023,22 +1039,28 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): | |
module.fc2_act_scale.data.copy_( | ||
torch.ones_like(module.fc2_act_scale) * | ||
(1 / all_w2_input_scales_max)) | ||
module.fc2_alpha.data.copy_((torch.ones_like(module.fc2_alpha) * | ||
all_w2_scales_fp8 = [ | ||
load_weight_shard(weights[f"{expert_id}.w2.weight_scale"]) | ||
for expert_id in module.initial_local_expert_ids | ||
] | ||
all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(module.fc2_alpha.shape) | ||
module.fc2_alpha.data.copy_((all_w2_scales_fp8 * | ||
all_w2_input_scales_max).float()) | ||
|
||
all_w2_scales = [ | ||
load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"], | ||
load_weight_shard(weights[f"{expert_id}.w2.weight_scale.int4"], | ||
module.tp_size, | ||
module.tp_rank, | ||
TensorParallelMode.ROW, | ||
device=self.device) | ||
for expert_id in module.initial_local_expert_ids | ||
] | ||
all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2) | ||
if module.sm_version == 89: | ||
w2_scales = torch.stack(all_w2_scales).to(torch.float16).view( | ||
w2_scales = all_w2_scales.to(torch.float16).view( | ||
module.dtype) | ||
else: | ||
w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view( | ||
w2_scales = all_w2_scales.to(torch.bfloat16).view( | ||
module.dtype) | ||
w2_s_shape = w2_scales.shape | ||
w2_scales_interleaved = w2_scales.reshape( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Logical issue: string comparison against enum value.
Line 254 compares
quant_config.quant_algo
(which is now aQuantAlgo
enum) against the string"fp8_pb_wo"
. This will always fail because you're comparing an enum to a string.Fix the comparison:
📝 Committable suggestion
🤖 Prompt for AI Agents