From 15fbe20379cca3db10c7e274d886c9cfaf72092e Mon Sep 17 00:00:00 2001
From: beipingpan
Date: Wed, 13 Aug 2025 16:17:36 +0800
Subject: [PATCH] Support the W4A8 quantization method of the AngelSlim tool

---
 tensorrt_llm/_torch/model_config.py          | 80 ++++++++++++++++++-
 .../_torch/models/modeling_deepseekv3.py     |  2 +-
 tensorrt_llm/_torch/models/modeling_utils.py |  9 ++-
 .../_torch/modules/fused_moe/quantization.py | 50 ++++++++----
 tensorrt_llm/llmapi/llm_utils.py             | 35 +++++++-
 tensorrt_llm/models/modeling_utils.py        |  6 +-
 tensorrt_llm/quantization/mode.py            |  5 ++
 7 files changed, 168 insertions(+), 19 deletions(-)

diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 5bc9e7870f..d1366ce3e9 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -15,7 +15,7 @@
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
-from tensorrt_llm.quantization.mode import QuantAlgo
+from tensorrt_llm.quantization.mode import QuantAlgo, ActivationScheme
 
 TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
 
@@ -238,6 +238,47 @@ def load_modelopt_quant_config(quant_config_file, model_dir, moe_backend):
         ]
         return quant_config, layer_quant_config
 
+    @staticmethod
+    def load_angelslim_quant_config(quant_config_file, model_dir, moe_backend):
+        quant_config = QuantConfig()
+        layer_quant_config = None
+
+        with open(quant_config_file) as f:
+            quant_config_dict = json.load(f)
+
+        json_quant_configs = quant_config_dict['quantization']
+
+        quant_algo = json_quant_configs.get('quant_algo', None)
+        # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
+        if quant_algo == "fp8_pb_wo":
+            quant_algo = "FP8_BLOCK_SCALES"
+        quant_config.quant_algo = QuantAlgo(quant_algo.upper()) if quant_algo else None
+
+        quant_config.kv_cache_quant_algo = QuantAlgo(
+            json_quant_configs.get("kv_cache_quant_algo").upper()
+        ) if json_quant_configs.get("kv_cache_quant_algo") else None
+        quant_config.group_size = json_quant_configs.get('group_size', None)
+        quant_config.exclude_modules = json_quant_configs.get(
+            'exclude_modules', None)
+        quant_config.activation_scheme = ActivationScheme(
+            json_quant_configs.get('activation_scheme', None).upper()
+        ) if json_quant_configs.get("activation_scheme") else None
+
+        json_exclude_quant_configs = json_quant_configs.get('exclude_quantization', None)
+        if json_exclude_quant_configs:
+            quant_config.exclude_quant_config = {
+                "quant_algo": QuantAlgo(
+                    json_exclude_quant_configs.get('quant_algo', None).upper()
+                ) if json_exclude_quant_configs.get("quant_algo") else None,
+                "kv_cache_quant_algo": QuantAlgo(
+                    json_exclude_quant_configs.get("kv_cache_quant_algo").upper()
+                ) if json_exclude_quant_configs.get("kv_cache_quant_algo") else None,
+                "activation_scheme": ActivationScheme(
+                    json_exclude_quant_configs.get('activation_scheme', None).upper()
+                ) if json_exclude_quant_configs.get("activation_scheme") else None,
+            }
+        return quant_config, layer_quant_config
+
     @staticmethod
     def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False):
         quant_algo = ModelConfig.override_quant_algo()
@@ -282,6 +323,40 @@ def load_hf_quant_config(hf_quant_config, moe_backend):
                 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                 'embedding', 'unembedding'
             ]
+        elif hf_quant_config.get("quant_method") == "w4a8_awq":
+            quant_config.quant_algo = QuantAlgo.W4A8_AWQ
+        else:
+            raise NotImplementedError(f"Unsupported 
quantization_config: {hf_quant_config}.") + + # set kv_cache_quant_algo + quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \ + if hf_quant_config.get("kv_cache_quant_method") else None + # set activation_scheme + quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \ + if hf_quant_config.get("activation_scheme") else None + # set exclude_modules + if quant_config.exclude_modules: + if hf_quant_config.get("ignored_modules"): + quant_config.exclude_modules += hf_quant_config.get("ignored_modules") + else: + quant_config.exclude_modules = hf_quant_config.get("ignored_modules") + + # set exclude_quant_config + hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config") + if hf_ignored_quantization_config: + quant_config.exclude_quant_config = { + "quant_algo": QuantAlgo( + hf_ignored_quantization_config.get("quant_method").upper() + ) if hf_ignored_quantization_config.get("quant_method") else None, + "kv_cache_quant_algo": QuantAlgo( + hf_ignored_quantization_config.get("kv_cache_quant_method").upper() + ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None, + "activation_scheme": ActivationScheme( + hf_ignored_quantization_config.get("activation_scheme").upper() + ) if hf_ignored_quantization_config.get("activation_scheme") else None, + } + + logger.info(f"Load quantization config from pretrained config, quant_config: {quant_config}") return quant_config, layer_quant_config @@ -360,6 +435,9 @@ def from_pretrained(cls, if (quant_config_file := model_dir / 'hf_quant_config.json').exists(): quant_config, layer_quant_config = cls.load_modelopt_quant_config( quant_config_file, model_dir, moe_backend) + elif (quant_config_file := model_dir / 'angelslim_hf_quant_config.json').exists(): + quant_config, layer_quant_config = cls.load_angelslim_quant_config( + quant_config_file, model_dir, moe_backend) # quantized ckpt in other formats elif hasattr(pretrained_config, "quantization_config"): hf_quant_config = pretrained_config.quantization_config diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index a5a61d9a7d..81e8bfe76e 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1360,7 +1360,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, if names[-1] == "kv_b_proj": # TODO: remove weight_dequant after enabling fp8_bmm dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization( - names[-1]) + names[-1]) and self.model_config.quant_config.exclude_quant_config is None if dequant_kv_b_proj: kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant( name) diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index 238ac97ffe..6d4765558c 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -477,7 +477,14 @@ def apply_quant_config_exclude_modules(self): kv_cache_quant_algo = None if quant_config: kv_cache_quant_algo = quant_config.kv_cache_quant_algo - new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo) + quant_algo = None + activation_scheme = None + exclude_quant_config = quant_config.exclude_quant_config + if exclude_quant_config: + quant_algo = exclude_quant_config.get("quant_algo", None) + activation_scheme = exclude_quant_config.get("activation_scheme", None) + new_config 
= QuantConfig( + quant_algo=quant_algo, kv_cache_quant_algo=kv_cache_quant_algo, activation_scheme=activation_scheme) if quant_config is not None: if quant_config.exclude_modules is not None: diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 4e0952d7d5..4ac8038030 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -206,7 +206,8 @@ def load_expert_weights_to_dst( load_expert_ids: List[int], dst_w3_w1_weights_tensor: torch.Tensor, dst_w2_weights_tensor: torch.Tensor, dst_w3_w1_bias_tensor: Optional[torch.Tensor], - dst_w2_bias_tensor: Optional[torch.Tensor]): + dst_w2_bias_tensor: Optional[torch.Tensor], + weight_name: str = "weight"): # Multithread weight load is superseded by prefetch_files() in model_engine.py # Also, threading adds overhead in order to protect shuffle index cache with critical section. for local_slot_id, expert_id in enumerate(load_expert_ids): @@ -214,9 +215,9 @@ def load_expert_weights_to_dst( expert_idx = local_slot_id if weight_loading_mode == MoEWeightLoadingMode.VANILLA: - w1_weight = weights[f"{expert_id}.w1.weight"] - w3_weight = weights[f"{expert_id}.w3.weight"] - w2_weight = weights[f"{expert_id}.w2.weight"] + w1_weight = weights[f"{expert_id}.w1.{weight_name}"] + w3_weight = weights[f"{expert_id}.w3.{weight_name}"] + w2_weight = weights[f"{expert_id}.w2.{weight_name}"] if module.bias: w1_bias = weights[f"{expert_id}.w1.bias"] w3_bias = weights[f"{expert_id}.w3.bias"] @@ -251,14 +252,16 @@ def load_expert_weights_to_dst( dst_w2_bias_tensor.data[expert_idx]) def load_weights(self, module: torch.nn.Module, weights: List[Dict], - weight_loading_mode: MoEWeightLoadingMode): + weight_loading_mode: MoEWeightLoadingMode, + weight_name: str = "weight"): self.load_expert_weights_to_dst( module, weights, weight_loading_mode, module.initial_local_expert_ids, module.w3_w1_weight.data, module.w2_weight.data, module.w3_w1_bias.data if module.bias else None, - module.w2_bias.data if module.bias else None) + module.w2_bias.data if module.bias else None, + weight_name) self.load_quant_scales(module, weights) # Re-setup quant scales after loading weights as the tensors may have been modified. 
@@ -953,6 +956,11 @@ def load_expert_w2_weight(self, module: torch.nn.Module, dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), non_blocking=True) + def load_weights(self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode, + weight_name: str = "qweight"): + super().load_weights(module, weights, weight_loading_mode, weight_name) + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): assert self.device.type == "cuda" @@ -974,7 +982,13 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): module.fc31_act_scale.data.copy_( torch.ones_like(module.fc31_act_scale) * (1 / all_w3_w1_input_scales_max)) - module.fc31_alpha.data.copy_((torch.ones_like(module.fc31_alpha) * + all_w3_w1_scales_fp8_max = [] + for expert_id in module.initial_local_expert_ids: + w1_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w1.weight_scale"]) + w3_weight_scale_fp8 = load_weight_shard(weights[f"{expert_id}.w3.weight_scale"]) + all_w3_w1_scales_fp8_max.append(torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8)) + all_w3_w1_scales_fp8_max = torch.stack(all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape) + module.fc31_alpha.data.copy_((all_w3_w1_scales_fp8_max * all_w3_w1_input_scales_max).float()) all_w3_scales = [ @@ -985,17 +999,19 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): device=self.device) for expert_id in module.initial_local_expert_ids ] + all_w3_scales = torch.stack(all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) all_w1_scales = [ - load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"], + load_weight_shard(weights[f"{expert_id}.w1.weight_scale.int4"], module.tp_size, module.tp_rank, TensorParallelMode.COLUMN, device=self.device) for expert_id in module.initial_local_expert_ids ] + all_w1_scales = torch.stack(all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2) all_w3_w1_scales = torch.cat( - [torch.stack(all_w3_scales), - torch.stack(all_w1_scales)], dim=-2) + [all_w3_scales, + all_w1_scales], dim=-2) if module.sm_version == 89: w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype) else: @@ -1023,22 +1039,28 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): module.fc2_act_scale.data.copy_( torch.ones_like(module.fc2_act_scale) * (1 / all_w2_input_scales_max)) - module.fc2_alpha.data.copy_((torch.ones_like(module.fc2_alpha) * + all_w2_scales_fp8 = [ + load_weight_shard(weights[f"{expert_id}.w2.weight_scale"]) + for expert_id in module.initial_local_expert_ids + ] + all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(module.fc2_alpha.shape) + module.fc2_alpha.data.copy_((all_w2_scales_fp8 * all_w2_input_scales_max).float()) all_w2_scales = [ - load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"], + load_weight_shard(weights[f"{expert_id}.w2.weight_scale.int4"], module.tp_size, module.tp_rank, TensorParallelMode.ROW, device=self.device) for expert_id in module.initial_local_expert_ids ] + all_w2_scales = torch.stack(all_w2_scales) / all_w2_scales_fp8.unsqueeze(2) if module.sm_version == 89: - w2_scales = torch.stack(all_w2_scales).to(torch.float16).view( + w2_scales = all_w2_scales.to(torch.float16).view( module.dtype) else: - w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view( + w2_scales = all_w2_scales.to(torch.bfloat16).view( module.dtype) w2_s_shape = w2_scales.shape w2_scales_interleaved = w2_scales.reshape( diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index 
b2145ac793..928cc12d05 100644
--- a/tensorrt_llm/llmapi/llm_utils.py
+++ b/tensorrt_llm/llmapi/llm_utils.py
@@ -25,7 +25,7 @@
 from ..logger import logger
 from ..mapping import Mapping
 from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM
-from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig
+from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig, ActivationScheme
 from ..module import Module
 from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
                           get_build_cache_config_from_env)
@@ -435,10 +435,43 @@ def _update_from_hf_quant_config(self) -> bool:
                     'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                     'embedding', 'unembedding'
                 ]
+            elif hf_quant_config.get("quant_method") == "w4a8_awq":
+                quant_config.quant_algo = QuantAlgo.W4A8_AWQ
             else:
                 raise NotImplementedError(
                     f"Unsupported quantization_config: {hf_quant_config}.")
+            # set kv_cache_quant_algo
+            quant_config.kv_cache_quant_algo = QuantAlgo(
+                hf_quant_config.get("kv_cache_quant_method").upper()
+            ) if hf_quant_config.get("kv_cache_quant_method") else None
+            # set activation_scheme
+            quant_config.activation_scheme = ActivationScheme(
+                hf_quant_config.get("activation_scheme").upper()
+            ) if hf_quant_config.get("activation_scheme") else None
+            # set exclude_modules
+            if quant_config.exclude_modules:
+                if hf_quant_config.get("ignored_modules"):
+                    quant_config.exclude_modules += hf_quant_config.get("ignored_modules")
+            else:
+                quant_config.exclude_modules = hf_quant_config.get("ignored_modules")
+            # set exclude_quant_config
+            hf_ignored_quantization_config = hf_quant_config.get("ignored_quantization_config")
+            if hf_ignored_quantization_config:
+                quant_config.exclude_quant_config = {
+                    "quant_algo": QuantAlgo(
+                        hf_ignored_quantization_config.get("quant_method").upper()
+                    ) if hf_ignored_quantization_config.get("quant_method") else None,
+                    "kv_cache_quant_algo": QuantAlgo(
+                        hf_ignored_quantization_config.get("kv_cache_quant_method").upper()
+                    ) if hf_ignored_quantization_config.get("kv_cache_quant_method") else None,
+                    "activation_scheme": ActivationScheme(
+                        hf_ignored_quantization_config.get("activation_scheme").upper()
+                    ) if hf_ignored_quantization_config.get("activation_scheme") else None,
+                }
+            logger.info(
+                f"Detected quantization_config: {quant_config}."
+            )
             return True
 
         return False
 
diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py
index dcc375320e..d6211cae50 100644
--- a/tensorrt_llm/models/modeling_utils.py
+++ b/tensorrt_llm/models/modeling_utils.py
@@ -45,7 +45,7 @@
                                     WeightOnlyQuantLinear, WeightOnlyQuantRowLinear)
 from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST,
-                                 W8A8_SQ_PLUGIN_LIST, QuantAlgo)
+                                 W8A8_SQ_PLUGIN_LIST, QuantAlgo, ActivationScheme)
 from ..quantization.utils import fp4_utils
 from ..top_model_mixin import TopModelMixin
 from .convert_utils import weight_only_quantize_dict
 
@@ -140,6 +140,8 @@ class QuantConfig:
         pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False.
         exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None.
         mamba_ssm_cache_dtype (str, optional): The data type for mamba SSM cache. Defaults to None.
+        exclude_quant_config (Dict, optional): The quantization config applied to the modules matched by exclude_modules. Defaults to None.
+        activation_scheme (tensorrt_llm.quantization.mode.ActivationScheme, optional): The activation quantization scheme (static or dynamic). Defaults to None.
""" quant_algo: Optional[QuantAlgo] = None kv_cache_quant_algo: Optional[QuantAlgo] = None @@ -151,6 +153,8 @@ class QuantConfig: pre_quant_scale: bool = False exclude_modules: Optional[List[str]] = None mamba_ssm_cache_dtype: Optional[str] = None + exclude_quant_config: Optional[Dict] = None + activation_scheme: Optional[ActivationScheme] = None @cached_property def quant_mode(self) -> QuantModeWrapper: diff --git a/tensorrt_llm/quantization/mode.py b/tensorrt_llm/quantization/mode.py index a8b38d885f..ab7ee56c95 100644 --- a/tensorrt_llm/quantization/mode.py +++ b/tensorrt_llm/quantization/mode.py @@ -458,3 +458,8 @@ class GroupwiseQuantAlgo: PRE_QUANT_SCALE = 4 W4A8_ALPHA = 8 INT8_WEIGHT = 16 + + +class ActivationScheme(StrEnum, metaclass=BaseEnumMeta): + STATIC = auto() + DYNAMIC = auto()