From 1b52477a4527cffb634455823ab7d78e1bc07214 Mon Sep 17 00:00:00 2001 From: Asma Kuriparambil Thekkumpate Date: Mon, 17 Nov 2025 17:16:55 -0800 Subject: [PATCH 1/2] [2/N] Added KDLoss based AutoQuantize Signed-off-by: Asma Kuriparambil Thekkumpate minor Signed-off-by: Asma Kuriparambil Thekkumpate cheery-picked final PR changes changelog updates Signed-off-by: realAsma minor Signed-off-by: realAsma KL Div formula fix Signed-off-by: realAsma --- CHANGELOG.rst | 1 + examples/llm_eval/lm_eval_hf.py | 14 ++ examples/llm_eval/mmlu.py | 2 + examples/llm_eval/quantization_utils.py | 44 +++- modelopt/torch/quantization/algorithms.py | 229 +++++++++++++++++- modelopt/torch/quantization/model_quant.py | 25 +- .../quantization/plugins/test_huggingface.py | 28 ++- .../unit/torch/quantization/test_autoquant.py | 7 +- 8 files changed, 329 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 899b14009..9bebc387c 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,7 @@ Model Optimizer Changelog (Linux) - Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``). - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md `_ for more details. - Add FP8/NVFP4 KV cache quantization support for Megatron Core models. +- Add KL Divergence loss based auto_quantize method. See `auto_quantize API docs `_ for more details. - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow. - Add support for PyTorch Geometric quantization. - Add per tensor and per channel MSE calibrator support. diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index e980a376e..04c71aecf 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -53,6 +53,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | quant_cfg = arg_dict.pop("quant_cfg", None) auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None) + auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient") calib_batch_size = arg_dict.pop("calib_batch_size", None) calib_size = arg_dict.pop("calib_size", 512) compress = arg_dict.pop("compress", False) @@ -81,6 +82,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | batch_size=calib_batch_size, calib_size=calib_size, auto_quantize_bits=auto_quantize_bits, + auto_quantize_method=auto_quantize_method, test_generated=False, compress=compress, ) @@ -109,6 +111,17 @@ def setup_parser_with_modelopt_args(): "regular quantization will be applied." ), ) + parser.add_argument( + "--auto_quantize_method", + type=str, + default="gradient", + choices=["gradient", "kl_div"], + help=( + "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method " + "(requires labels in dataset). 'kl_div' uses KL divergence between original and " + "quantized model outputs (no labels required). 
Default: 'gradient'" + ), + ) parser.add_argument( "--calib_batch_size", type=int, help="Batch size for quantization calibration" ) @@ -139,6 +152,7 @@ def setup_parser_with_modelopt_args(): { "quant_cfg": args.quant_cfg, "auto_quantize_bits": args.auto_quantize_bits, + "auto_quantize_method": args.auto_quantize_method, "calib_batch_size": args.calib_batch_size, "calib_size": args.calib_size, "compress": args.compress, diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index 6a2f70ce4..a702a1da4 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -224,6 +224,7 @@ def main( ntrain: int = 5, quant_cfg: str | None = None, auto_quantize_bits: float | None = None, + auto_quantize_method: str = "gradient", batch_size: int = 0, calib_size: int = 512, dtype: str = "bfloat16", @@ -281,6 +282,7 @@ def main( batch_size=batch_size, calib_size=calib_size, auto_quantize_bits=auto_quantize_bits, + auto_quantize_method=auto_quantize_method, ) for subject in tqdm(subjects): diff --git a/examples/llm_eval/quantization_utils.py b/examples/llm_eval/quantization_utils.py index 2f43c93e0..cd222ed35 100644 --- a/examples/llm_eval/quantization_utils.py +++ b/examples/llm_eval/quantization_utils.py @@ -66,6 +66,7 @@ def _quantize_model_with_dataset( quant_cfg: str | list[str], calib_dataset, auto_quantize_bits=None, + auto_quantize_method="gradient", batch_size=1, compress=False, ): @@ -81,23 +82,41 @@ def _quantize_model_with_dataset( getattr(mtq, quant_fmt) for quant_fmt in quant_cfg if quant_fmt != "NONE" ] - def loss_func(output, data): - # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` - # which contains the loss attribute. - return output.loss + # Configure forward_step and loss_func based on method + if auto_quantize_method == "gradient": + # For gradient-based method, return full output with loss + def forward_step(model, batch): + return model(**batch) + + def loss_func(output, data): + # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` + # which contains the loss attribute. + return output.loss + elif auto_quantize_method == "kl_div": + # For KL divergence method, return only logits + def forward_step(model, batch): + return model(**batch).logits + + loss_func = None # KL divergence doesn't need a custom loss function + else: + raise ValueError( + f"Invalid auto_quantize_method: {auto_quantize_method}. " + "Must be 'gradient' or 'kl_div'" + ) net, _ = mtq.auto_quantize( net, constraints={"effective_bits": auto_quantize_bits}, quantization_formats=quant_cfg_for_search, data_loader=calib_dataset, - forward_step=lambda model, batch: model(**batch), + forward_step=forward_step, loss_func=loss_func, num_calib_steps=len(calib_dataset), num_score_steps=min( len(calib_dataset), 128 // batch_size ), # Limit the number of score steps to avoid long calibration time verbose=True, + method=auto_quantize_method, ) else: mtq_cfg = CUSTOM_CONFIG.get(quant_cfg) # type: ignore [arg-type] @@ -142,6 +161,7 @@ def quantize_model( batch_size, calib_size, auto_quantize_bits=None, + auto_quantize_method="gradient", data="cnn_dailymail", test_generated=True, compress=False, @@ -156,6 +176,7 @@ def quantize_model( batch_size: the calibration batch size for each calibration inference run. calib_size: the total calibration dataset size. auto_quantize_bits: The effective bits constraint for auto_quantize. + auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div'). 
data: the name of the calibration dataset. test_generated: If ``True``, test the generated text before and after quantization. compress: If ``True``, compress the model after quantization. @@ -180,13 +201,16 @@ def quantize_model( batch_size = get_max_batch_size(net) print(f"Update calib batch {batch_size}") + # Labels are only needed for gradient-based auto_quantize + include_labels = auto_quantize_bits is not None and auto_quantize_method == "gradient" + calib_dataloader = get_dataset_dataloader( dataset_name=data, tokenizer=tokenizer, batch_size=batch_size, num_samples=calib_size, device=device, - include_labels=auto_quantize_bits is not None, + include_labels=include_labels, ) if test_generated: @@ -194,7 +218,13 @@ def quantize_model( generated_str_before_ptq = model.run(input_str) _quantize_model_with_dataset( - model, quant_cfg, calib_dataloader, auto_quantize_bits, batch_size, compress + model, + quant_cfg, + calib_dataloader, + auto_quantize_bits, + auto_quantize_method, + batch_size, + compress, ) if test_generated: diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index b573c77a0..aa68fbfec 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -35,7 +35,7 @@ from modelopt.torch.opt.searcher import LPS, BaseSearcher, SearchConfig, SearchStateDict from modelopt.torch.opt.utils import get_hparam, named_hparams from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory -from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master +from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master from . import config as mtq_config from . import model_calib @@ -1010,8 +1010,231 @@ def run_search_with_stats(self, max_weight_size, verbose=False): return best_recipes, is_satisfied -class AutoQuantizeLossSearcher(_AutoQuantizeBaseSearcher): - """A searcher for AutoQuantize algorithm that uses loss based score estimation.""" +# TODO: Enable torch compile for this function +# Currently modelopt.onnx is breaking this +def _get_softmax_dist( + logits: torch.Tensor, tp_group, return_log_prob: bool = False +) -> torch.Tensor: + # TODO: test this + dtype = logits.dtype + max_logits = torch.amax(logits, dim=-1, keepdim=True) + torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=tp_group) + logits = (logits - max_logits).float() + sum_exp_logits = torch.exp(torch.logsumexp(logits, dim=-1, keepdim=True)) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group) + logits = logits - torch.log(sum_exp_logits) + if return_log_prob: + return logits.to(dtype) + else: + return torch.exp(logits).to(dtype) + + +def _get_softmax(logits: torch.Tensor, return_log_prob: bool = False) -> torch.Tensor: + # TODO: do we need to do log_softmax in float32? 
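An aside on the helper above (illustrative only, not part of the patch): the tensor-parallel `_get_softmax_dist` is the standard max-shift plus logsumexp stabilization, with the per-rank max and summed exponentials all-reduced across the TP group. A minimal single-rank sanity check, assuming no sharding, that this recipe matches `torch.log_softmax`:

```python
import torch


def reference_log_softmax(logits: torch.Tensor) -> torch.Tensor:
    # Same stabilization steps as _get_softmax_dist, minus the two all_reduce calls:
    # shift by the (global) max, then normalize by the logsumexp of the shifted logits.
    max_logits = torch.amax(logits, dim=-1, keepdim=True)
    shifted = (logits - max_logits).float()
    return shifted - torch.logsumexp(shifted, dim=-1, keepdim=True)


logits = 50.0 * torch.randn(2, 4, 32)  # large magnitudes to stress numerical stability
assert torch.allclose(
    reference_log_softmax(logits), torch.log_softmax(logits.float(), dim=-1), atol=1e-5
)
```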
+ # log_softmax is supposed to be numerically stable implementation + log_prob = torch.log_softmax(logits.float(), dim=-1) + if return_log_prob: + return log_prob + else: + return torch.exp(log_prob) + + +def _get_p_log_q(p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor: + return torch.sum(p * log_q).float() + + +def _get_prob_from_logits( + logits: torch.Tensor, return_log_prob: bool = False, lm_head: nn.Module = None +) -> torch.Tensor: + parallel_state: ParallelState | None = ( + getattr(lm_head, "parallel_state", None) if lm_head is not None else None + ) + if parallel_state is not None and parallel_state.tensor_parallel_group.is_initialized(): + return _get_softmax_dist( + logits, parallel_state.tensor_parallel_group.group, return_log_prob + ) + return _get_softmax(logits, return_log_prob) + + +def _get_kl_div_loss( + prob_unquant: torch.Tensor, logits_quant: torch.Tensor, lm_head: nn.Module = None +) -> torch.Tensor: + log_prob_quant = _get_prob_from_logits(logits_quant, return_log_prob=True, lm_head=lm_head) + # We dont need to calculate the full kl div loss here, just get - p*log_q + return -_get_p_log_q(prob_unquant, log_prob_quant) + + +def _get_lm_head(model: nn.Module) -> nn.Module: + # HF models do allgather of logits to at lm_head + # Hence lm_head outputs are not TP sharded - so we dont need to return the lm_head for TP KLDiv + # Loss + for name, module in model.named_modules(): + if name.endswith("output_layer"): # Megatron models + return module + return None + + +class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher): + """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation.""" + + @property + def default_search_config(self): + """Get the default config for the searcher.""" + config = super().default_search_config + config.update( + { + "forward_step": None, + } + ) + return config + + def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: + """Sanitize the search config dict.""" + config = config or {} + for ignored_key in ["score_func", "loss_func", "forward_backward_step"]: + if ignored_key in config: + if config[ignored_key] is not None: + warnings.warn( + f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`." + ) + config.pop(ignored_key) + config = super().sanitize_search_config(config) + assert config["forward_step"] is not None, ( + "`forward_step` must be provided for KL-Divergence loss based `auto_quantize`. " + "`forward_step(model, data)` should return model logits." + ) + return config + + @torch.inference_mode() + def estimate_sensitivity_scores(self): + """Estimate the sensitivity scores for the model. + + Higher score means more sensitive to quantization. 
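A note on the score itself (illustrative only, not part of the patch): for a fixed calibration batch, KL(p_unquant || q_quant) = sum(p * log p) - sum(p * log q), and the entropy term sum(p * log p) is identical for every candidate quantization format. Accumulating only the -sum(p * log q) value returned by `_get_kl_div_loss` therefore preserves the ranking of formats. A quick check of that identity:

```python
import torch
import torch.nn.functional as F

p_logits, q_logits = torch.randn(4, 16), torch.randn(4, 16)
p = torch.softmax(p_logits, dim=-1)          # unquantized reference distribution
log_q = torch.log_softmax(q_logits, dim=-1)  # quantized candidate, as log-probabilities

full_kl = F.kl_div(log_q, p, reduction="sum")  # sum of p * (log p - log q)
score = -(p * log_q).sum()                     # the partial term accumulated as the score
entropy = -(p * torch.log(p)).sum()            # constant w.r.t. the quantization format

assert torch.allclose(full_kl, score - entropy, atol=1e-5)
```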
+ """ + + def set_to_unquantized(): + for name, hparam in named_hparams(self.model, unique=True): + if not isinstance(hparam, QuantRecipeHparam): + continue + if hparam.is_configurable: + hparam.active = QuantRecipe(quant_cfg=None) + + self.model.eval() + num_iters = self.config["num_score_steps"] + for _, data in tqdm( + zip(range(num_iters), self.config["data_loader"]), + desc="Estimating KLDivergence loss", + total=num_iters, + ): + set_to_unquantized() + logits_unquant = self.config["forward_step"](self.model, data) + prob_unquant = _get_prob_from_logits( + logits_unquant, + return_log_prob=False, + lm_head=_get_lm_head(self.model), + ) + + for name, hparam in tqdm( + list(named_hparams(self.model, configurable=True)), desc="Evaluating hparams" + ): + if not isinstance(hparam, QuantRecipeHparam): + continue + for recipe in hparam.choices: + if recipe == QuantRecipe(quant_cfg=None): + continue + hparam.active = recipe + logits_quant = self.config["forward_step"](self.model, data) + score = _get_kl_div_loss(prob_unquant, logits_quant, _get_lm_head(self.model)) + if hparam._importance_dict[recipe][hparam.score_modules[0]] is None: + hparam._importance_dict[recipe][hparam.score_modules[0]] = score + else: + hparam._importance_dict[recipe][hparam.score_modules[0]] += score + hparam.active = QuantRecipe(quant_cfg=None) + + def run_search_with_stats(self, max_weight_size, verbose=False): + """Run threshold-based binary search for KLDivergence loss based auto_quantize. + + We use binary search to minimize the max(per-layer score) while meeting the constraint. + """ + # Collect all sensitivity scores to determine initial threshold bounds + all_scores = [ + score for name in self.candidate_stats for score in self.candidate_stats[name]["scores"] + ] + + if not all_scores: + warnings.warn("No scores available for threshold-based search!") + is_satisfied = False + return {}, is_satisfied + + # Initialize binary search bounds + min_score = min(all_scores) + max_score = max(all_scores) + threshold = (min_score + max_score) / 2.0 + lower_bound = min_score + upper_bound = max_score + + # Run for fixed number of iterations + max_iterations = 100 + + if verbose: + print_rank_0("AutoQuantize: Starting threshold-based binary search") + print_rank_0(f" Score range: [{min_score:.6e}, {max_score:.6e}]") + print_rank_0(f" Target weight size: {max_weight_size:.2f}") + + for iteration in range(max_iterations): + # Select recipes based on current threshold + best_recipes = {} + total_weight_size = 0.0 + + for name in self.candidate_stats: + formats = self.candidate_stats[name]["formats"] + scores = self.candidate_stats[name]["scores"] + costs = self.candidate_stats[name]["costs"] + + selected_idx = 0 + for idx in range(len(formats)): + if scores[idx] <= threshold: + selected_idx = idx + break + + best_recipes[name] = { + "format": formats[selected_idx], + "costs": costs[selected_idx], + "scores": scores[selected_idx], + } + total_weight_size += costs[selected_idx] + + # Check if we meet the constraint + meets_constraint = total_weight_size <= max_weight_size + + if verbose: + print_rank_0( + f" Iteration {iteration + 1}: threshold={threshold:.6e}, " + f"weight_size={total_weight_size:.2f}, " + f"meets_constraint={meets_constraint}" + ) + + # Update binary search bounds + if meets_constraint: + upper_bound = threshold # Threshold was too aggressive, relax it + else: + lower_bound = threshold # Threshold was too lax, tighten it + + # Update threshold for next iteration + threshold = (lower_bound + upper_bound) / 
2.0 + + # Final check if constraint is satisfied + is_satisfied = total_weight_size <= max_weight_size + + if verbose: + print_rank_0( + f"AutoQuantize: Search complete. " + f"Final weight size: {total_weight_size:.2f} " + f"(target: {max_weight_size:.2f}), " + f"constraint satisfied: {is_satisfied}" + ) + + return best_recipes, is_satisfied # Backward compatibility alias (defaults to gradient-based searcher) diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 8cac16cb9..f9c3f9820 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -31,7 +31,7 @@ from modelopt.torch.quantization.config import QuantizeConfig from modelopt.torch.quantization.conversion import set_quantizer_by_cfg -from .algorithms import AutoQuantizeSearcher, QuantRecipe +from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe from .config import QuantizeAlgoCfgType from .conversion import set_quantizer_attribute from .mode import QuantizeModeRegistry, get_modelike_from_algo_cfg @@ -252,11 +252,13 @@ def auto_quantize( num_calib_steps: int = 512, num_score_steps: int = 128, verbose: bool = False, + method: str = "gradient", ): r"""Perform optimal per-layer quantization by searching for the best quantization formats per-layer. - ``auto_quantize`` uses a gradient based sensitivity score to rank the per-layer quantization formats and search - for the best quantization formats per-layer. + ``auto_quantize`` uses sensitivity scores to rank the per-layer quantization formats and search + for the best quantization formats per-layer. The sensitivity score can be computed using gradient-based + methods (default) or KL divergence loss, controlled by the ``method`` parameter. Args: model: A pytorch model with quantizer modules. @@ -379,6 +381,13 @@ def forward_backward_step(model, batch) -> None: num_score_steps: Number of batches to use for estimating ``auto_quantize`` scores. Suggested value is 128. A higher value could increase the time taken for performing ``auto_quantize``. verbose: If True, prints the search progress/intermediate results. + method: Method to use for estimating sensitivity loss. Higher loss indicates greater sensitivity + to quantization. Options are: + - ``"gradient"``: (Default) Uses gradient-based loss estimation and linear programming for + search. Requires ``loss_func`` or ``forward_backward_step`` to be provided. + - ``"kl_div"``: Uses KL divergence loss between unquantized and quantized model outputs. Uses + threshold-based binary search. Only requires ``forward_step`` (no loss_func needed). + The ``forward_step`` should return model logits for this method. Returns: A tuple (model, state_dict) where ``model`` is the searched and quantized model and ``state_dict`` contains the history and detailed stats of the search procedure. @@ -441,12 +450,20 @@ def forward_backward_step(model, batch) -> None: processed_quantization_formats.append((quant_cfg, name)) assert len(processed_quantization_formats) > 0, "`quantization_formats` should not be empty" + + # Select the appropriate searcher based on method + if method == "gradient": + searcher = AutoQuantizeGradientSearcher() + elif method == "kl_div": + searcher = AutoQuantizeKLDivSearcher() + else: + raise ValueError(f"Invalid method: {method}. 
Valid options are 'gradient' or 'kl_div'.") + model = apply_mode( model, mode="auto_quantize", registry=QuantizeModeRegistry, ) - searcher = AutoQuantizeSearcher() search_config = { "quantization_formats": processed_quantization_formats, "data_loader": data_loader, diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index ab59e663e..a3e72ffa7 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -15,6 +15,7 @@ import os import warnings +from contextlib import nullcontext import pytest import torch @@ -136,28 +137,43 @@ def test_dbrx(): assert torch.allclose(out_1[0], out_2[0]) -def test_autoquantize_huggingface(): +@pytest.mark.parametrize( + "method", + ["gradient", "kl_div"], +) +def test_autoquantize_huggingface(method): model = get_tiny_llama() input_ids = model.dummy_inputs["input_ids"] + def forward_step(model, batch): + return model(**batch) if method == "gradient" else model(**batch).logits + warnings.filterwarnings( "error", message="AutoQuantize: Error enabling gradient checkpointing for huggingface model" ) - with pytest.warns( - UserWarning, - match="AutoQuantize: Huggingface model detected - Enabling gradient checkpointing. ", - ): + # Gradient checkpointing warning should only appear for gradient-based method + context = ( + pytest.warns( + UserWarning, + match="AutoQuantize: Huggingface model detected - Enabling gradient checkpointing. ", + ) + if method == "gradient" + else nullcontext() + ) + + with context: best_model, search_history = mtq.auto_quantize( model, constraints={"effective_bits": 11.0}, quantization_formats=[mtq.INT8_DEFAULT_CFG], data_loader=[{"input_ids": input_ids, "labels": input_ids} for _ in range(2)], - forward_step=lambda model, batch: model(**batch), + forward_step=forward_step, loss_func=lambda output, data: output.loss, num_calib_steps=2, num_score_steps=2, verbose=True, + method=method, ) diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 304402aec..0c41917b4 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -133,7 +133,11 @@ def test_quant_recipe_hparam(): ([None, mtq.INT8_SMOOTHQUANT_CFG], 8.0, 11.0), ], ) -def test_auto_quantize(model_cls, search_formats, min_bits, search_bits): +@pytest.mark.parametrize( + "method", + ["gradient", "kl_div"], +) +def test_auto_quantize(model_cls, search_formats, min_bits, search_bits, method): model = model_cls() def loss_func(output): @@ -149,6 +153,7 @@ def loss_func(output): num_calib_steps=2, num_score_steps=2, verbose=True, + method=method, ) assert isinstance(search_history, dict) assert search_history["best"]["is_satisfied"] From 0aada4ed058b5929270b9fb168ec5207c8e773f0 Mon Sep 17 00:00:00 2001 From: Asma Kuriparambil Thekkumpate Date: Wed, 19 Nov 2025 06:39:36 -0800 Subject: [PATCH 2/2] [3/N] Added autoquantize search state save/restore support Some improvements for KLDiv Signed-off-by: realAsma changelog update Signed-off-by: realAsma minor Signed-off-by: realAsma doc updates Signed-off-by: realAsma --- CHANGELOG.rst | 1 + examples/llm_eval/gen_model_answer.py | 35 ++++++++- examples/llm_eval/lm_eval_hf.py | 25 ++++++- examples/llm_eval/mmlu.py | 6 +- examples/llm_eval/quantization_utils.py | 33 ++++++--- examples/llm_ptq/hf_ptq.py | 73 +++++++++++++++++-- .../llm_ptq/scripts/huggingface_example.sh | 22 ++++++ 
examples/llm_ptq/scripts/parser.sh | 8 +- modelopt/torch/quantization/algorithms.py | 64 +++++++++------- modelopt/torch/quantization/model_quant.py | 20 ++++- .../unit/torch/quantization/test_autoquant.py | 57 ++++++++++++++- 11 files changed, 292 insertions(+), 52 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9bebc387c..beb01abf0 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,7 @@ Model Optimizer Changelog (Linux) - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md `_ for more details. - Add FP8/NVFP4 KV cache quantization support for Megatron Core models. - Add KL Divergence loss based auto_quantize method. See `auto_quantize API docs `_ for more details. +- Add support for saving and resuming auto_quantize search state. This speeds up the auto_quantize process by skipping the score estimation step if the search state is provided. - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow. - Add support for PyTorch Geometric quantization. - Add per tensor and per channel MSE calibrator support. diff --git a/examples/llm_eval/gen_model_answer.py b/examples/llm_eval/gen_model_answer.py index 86504db62..42a7eaac9 100644 --- a/examples/llm_eval/gen_model_answer.py +++ b/examples/llm_eval/gen_model_answer.py @@ -201,8 +201,11 @@ def get_model_answers( tokenizer, args.calib_batch_size, args.calib_size, - args.auto_quantize_bits, test_generated=False, + auto_quantize_bits=args.auto_quantize_bits, + auto_quantize_method=args.auto_quantize_method, + auto_quantize_score_size=args.auto_quantize_score_size, + auto_quantize_checkpoint=args.auto_quantize_checkpoint, ) for question in tqdm(questions): @@ -450,6 +453,36 @@ def reorg_answer_file(answer_file): "regular quantization without auto_quantize search will be applied." ), ) + parser.add_argument( + "--auto_quantize_method", + type=str, + default="gradient", + choices=["gradient", "kl_div"], + help=( + "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method " + "(requires labels in dataset). 'kl_div' uses KL divergence between original and " + "quantized model outputs (no labels required). Default: 'gradient'" + ), + ) + parser.add_argument( + "--auto_quantize_score_size", + type=int, + default=128, + help=( + "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on " + "sensitivity score estimation, so reducing this speeds it up while only minimally affecting " + "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)." + ), + ) + parser.add_argument( + "--auto_quantize_checkpoint", + type=str, + default=None, + help=( + "Path to checkpoint file for saving/restoring auto_quantize search state " + "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." 
+ ), + ) parser.add_argument( "--trust_remote_code", help="Set trust_remote_code for Huggingface models and tokenizers", diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 04c71aecf..31103ff86 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -54,6 +54,8 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | quant_cfg = arg_dict.pop("quant_cfg", None) auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None) auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient") + auto_quantize_score_size = arg_dict.pop("auto_quantize_score_size", 128) + auto_quantize_checkpoint = arg_dict.pop("auto_quantize_checkpoint", None) calib_batch_size = arg_dict.pop("calib_batch_size", None) calib_size = arg_dict.pop("calib_size", 512) compress = arg_dict.pop("compress", False) @@ -83,8 +85,10 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | calib_size=calib_size, auto_quantize_bits=auto_quantize_bits, auto_quantize_method=auto_quantize_method, + auto_quantize_score_size=auto_quantize_score_size, test_generated=False, compress=compress, + auto_quantize_checkpoint=auto_quantize_checkpoint, ) return model_obj @@ -103,6 +107,12 @@ def setup_parser_with_modelopt_args(): "comma-separated list of quantization quantization formats that will be searched by `auto_quantize`" ), ) + parser.add_argument( + "--calib_batch_size", type=int, help="Batch size for quantization calibration" + ) + parser.add_argument( + "--calib_size", type=int, help="Calibration size for quantization", default=512 + ) parser.add_argument( "--auto_quantize_bits", type=float, @@ -123,10 +133,19 @@ def setup_parser_with_modelopt_args(): ), ) parser.add_argument( - "--calib_batch_size", type=int, help="Batch size for quantization calibration" + "--auto_quantize_score_size", + type=int, + default=128, + help=( + "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on " + "sensitivity score estimation, so reducing this speeds it up while only minimally affecting " + "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)." + ), ) parser.add_argument( - "--calib_size", type=int, help="Calibration size for quantization", default=512 + "--auto_quantize_checkpoint", + type=str, + help=("Path to checkpoint file for saving/restoring auto_quantize search state. 
"), ) parser.add_argument( "--compress", @@ -153,6 +172,8 @@ def setup_parser_with_modelopt_args(): "quant_cfg": args.quant_cfg, "auto_quantize_bits": args.auto_quantize_bits, "auto_quantize_method": args.auto_quantize_method, + "auto_quantize_score_size": args.auto_quantize_score_size, + "auto_quantize_checkpoint": args.auto_quantize_checkpoint, "calib_batch_size": args.calib_batch_size, "calib_size": args.calib_size, "compress": args.compress, diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index a702a1da4..ca244052b 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -224,10 +224,12 @@ def main( ntrain: int = 5, quant_cfg: str | None = None, auto_quantize_bits: float | None = None, - auto_quantize_method: str = "gradient", batch_size: int = 0, calib_size: int = 512, dtype: str = "bfloat16", + auto_quantize_method: str = "gradient", + auto_quantize_score_size: int = 128, + auto_quantize_checkpoint: str | None = None, **kwargs, ): random.seed(RAND_SEED) @@ -283,6 +285,8 @@ def main( calib_size=calib_size, auto_quantize_bits=auto_quantize_bits, auto_quantize_method=auto_quantize_method, + auto_quantize_score_size=auto_quantize_score_size, + auto_quantize_checkpoint=auto_quantize_checkpoint, ) for subject in tqdm(subjects): diff --git a/examples/llm_eval/quantization_utils.py b/examples/llm_eval/quantization_utils.py index cd222ed35..3df44115a 100644 --- a/examples/llm_eval/quantization_utils.py +++ b/examples/llm_eval/quantization_utils.py @@ -67,8 +67,10 @@ def _quantize_model_with_dataset( calib_dataset, auto_quantize_bits=None, auto_quantize_method="gradient", + auto_quantize_score_size=128, batch_size=1, compress=False, + auto_quantize_checkpoint=None, ): if hasattr(lm, "gpt2"): net = lm.gpt2 @@ -112,11 +114,12 @@ def forward_step(model, batch): forward_step=forward_step, loss_func=loss_func, num_calib_steps=len(calib_dataset), - num_score_steps=min( - len(calib_dataset), 128 // batch_size - ), # Limit the number of score steps to avoid long calibration time + # Most time is spent on score estimation; fewer samples speed it up with little accuracy impact. + num_score_steps=min(len(calib_dataset), max(auto_quantize_score_size // batch_size, 1)), verbose=True, method=auto_quantize_method, + # disabled_layers=["*lm_head*", "*mlp.gate.*"], + checkpoint=auto_quantize_checkpoint, ) else: mtq_cfg = CUSTOM_CONFIG.get(quant_cfg) # type: ignore [arg-type] @@ -160,11 +163,13 @@ def quantize_model( tokenizer, batch_size, calib_size, - auto_quantize_bits=None, - auto_quantize_method="gradient", data="cnn_dailymail", test_generated=True, compress=False, + auto_quantize_bits=None, + auto_quantize_method="gradient", + auto_quantize_score_size=128, + auto_quantize_checkpoint=None, ): """Quantizes the model with the provided calibration dataset. @@ -175,11 +180,14 @@ def quantize_model( tokenizer: the tokenizer. batch_size: the calibration batch size for each calibration inference run. calib_size: the total calibration dataset size. - auto_quantize_bits: The effective bits constraint for auto_quantize. - auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div'). data: the name of the calibration dataset. test_generated: If ``True``, test the generated text before and after quantization. compress: If ``True``, compress the model after quantization. + auto_quantize_bits: The effective bits constraint for auto_quantize. + auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div'). 
+ auto_quantize_score_size: Number of samples used for auto_quantize scoring. + auto_quantize_checkpoint: Path to checkpoint file for saving/restoring auto_quantize search state + (sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified. """ if "AWQ" in quant_cfg: print( @@ -191,8 +199,10 @@ def quantize_model( if hasattr(model, "model"): device = model.model.device + is_gradient_based = auto_quantize_bits is not None and auto_quantize_method == "gradient" + if batch_size == 0: - if auto_quantize_bits is not None or torch.distributed.is_initialized(): + if is_gradient_based or torch.distributed.is_initialized(): raise ValueError("We dont support automatic batch size inference for this case.") net = model.gpt2 if hasattr(model, "gpt2") else model.model @@ -201,16 +211,13 @@ def quantize_model( batch_size = get_max_batch_size(net) print(f"Update calib batch {batch_size}") - # Labels are only needed for gradient-based auto_quantize - include_labels = auto_quantize_bits is not None and auto_quantize_method == "gradient" - calib_dataloader = get_dataset_dataloader( dataset_name=data, tokenizer=tokenizer, batch_size=batch_size, num_samples=calib_size, device=device, - include_labels=include_labels, + include_labels=is_gradient_based, ) if test_generated: @@ -223,8 +230,10 @@ def quantize_model( calib_dataloader, auto_quantize_bits, auto_quantize_method, + auto_quantize_score_size, batch_size, compress, + auto_quantize_checkpoint, ) if test_generated: diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 7bb8d0f28..f0bb56caa 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -95,7 +95,15 @@ def auto_quantize( - model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1 + model, + qformat, + calib_dataloader, + calibrate_loop, + auto_quantize_bits, + batch_size=1, + auto_quantize_method="gradient", + auto_quantize_score_size=128, + auto_quantize_checkpoint=None, ): qformat_list = qformat.split(",") assert qformat_list, "No quantization formats provided" @@ -122,18 +130,34 @@ def loss_func(output, data): # which contains the loss attribute. return output.loss + if auto_quantize_method == "gradient": + # For gradient-based method, return full output with loss + def forward_step(model, batch): + return model(**batch) + elif auto_quantize_method == "kl_div": + # For KL divergence method, return only logits + def forward_step(model, batch): + return model(**batch).logits + else: + raise ValueError( + f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" + ) + model, _ = mtq.auto_quantize( model, constraints={"effective_bits": auto_quantize_bits}, data_loader=calib_dataloader, - forward_step=lambda model, batch: model(**batch), - loss_func=loss_func, + forward_step=forward_step, + loss_func=loss_func, # Only used for gradient-based method # TRTLLM only support one quantization format or None (do not quantize, internally supported) quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list], num_calib_steps=len(calib_dataloader), - num_score_steps=len(calib_dataloader), + # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration. 
+ num_score_steps=min(len(calib_dataloader), max(auto_quantize_score_size // batch_size, 1)), verbose=True, disabled_layers=["*lm_head*"], + method=auto_quantize_method, + checkpoint=auto_quantize_checkpoint, ) # We need to explicitly calibrate for kv cache quantization @@ -191,10 +215,13 @@ def quantize_model(model, quant_cfg, args, calib_dataloader=None, calibration_on model = auto_quantize( model, args.qformat, - args.auto_quantize_bits, calib_dataloader, calibrate_loop, + args.auto_quantize_bits, args.batch_size, + args.auto_quantize_method, + args.auto_quantize_score_size, + args.auto_quantize_checkpoint, ) elif calibration_only: model = mtq.calibrate(model, quant_cfg["algorithm"], forward_loop=calibrate_loop) @@ -444,13 +471,17 @@ def main(args): assert tokenizer is not None and isinstance( tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) ), "The PreTrainedTokenizer must be set" + # Labels are only needed for gradient-based auto_quantize + include_labels = ( + args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" + ) calib_dataloader = get_dataset_dataloader( dataset_name=args.dataset, tokenizer=tokenizer, batch_size=args.batch_size, num_samples=args.calib_size, device=device, - include_labels=args.auto_quantize_bits is not None, + include_labels=include_labels, ) quant_cfg = build_quant_cfg( @@ -803,6 +834,36 @@ def output_decode(generated_ids, input_shape): default=None, type=str, ) + parser.add_argument( + "--auto_quantize_method", + type=str, + default="gradient", + choices=["gradient", "kl_div"], + help=( + "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method " + "(requires labels in dataset). 'kl_div' uses KL divergence between original and " + "quantized model outputs (no labels required). Default: 'gradient'" + ), + ) + parser.add_argument( + "--auto_quantize_score_size", + type=int, + default=128, + help=( + "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on " + "sensitivity score estimation, so reducing this speeds it up while only minimally affecting " + "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)." + ), + ) + parser.add_argument( + "--auto_quantize_checkpoint", + type=str, + default=None, + help=( + "Path to checkpoint file for saving/restoring auto_quantize search state " + "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." 
+ ), + ) args = parser.parse_args() diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 7b7d6910e..043b690e5 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -93,6 +93,28 @@ fi if [ -n "$AUTO_QUANTIZE_BITS" ]; then PTQ_ARGS+=" --auto_quantize_bits=$AUTO_QUANTIZE_BITS " fi + +if [ -n "$AUTO_QUANTIZE_METHOD" ]; then + PTQ_ARGS+=" --auto_quantize_method=$AUTO_QUANTIZE_METHOD " +fi + +if [ -n "$AUTO_QUANTIZE_SCORE_SIZE" ]; then + PTQ_ARGS+=" --auto_quantize_score_size=$AUTO_QUANTIZE_SCORE_SIZE " +fi + +# Automatically generate auto_quantize checkpoint path if not provided +if [ -n "$AUTO_QUANTIZE_BITS" ] && [ -z "$AUTO_QUANTIZE_CHECKPOINT" ]; then + # Create a descriptive checkpoint name based on model and quantization settings + AQ_METHOD=${AUTO_QUANTIZE_METHOD:-gradient} + AUTO_QUANTIZE_CHECKPOINT="${ROOT_SAVE_PATH}/auto_quantize_checkpoints/${MODEL_NAME}_${AQ_METHOD}.pth" + mkdir -p $(dirname $AUTO_QUANTIZE_CHECKPOINT) + echo "Auto-generated auto_quantize checkpoint path: $AUTO_QUANTIZE_CHECKPOINT" +fi + +if [ -n "$AUTO_QUANTIZE_BITS" ]; then + PTQ_ARGS+=" --auto_quantize_checkpoint=$AUTO_QUANTIZE_CHECKPOINT " +fi + if [ -n "$CALIB_DATASET" ]; then PTQ_ARGS+=" --dataset=$CALIB_DATASET " fi diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index 7df601327..8db2fe131 100644 --- a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -36,7 +36,7 @@ parse_options() { USE_SEQ_DEVICE_MAP=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:" -n "$0" -- "$@") eval set -- "$ARGS" while true; do @@ -65,6 +65,9 @@ parse_options() { --low_memory_mode ) LOW_MEMORY_MODE=true; shift;; --calib_dataset ) CALIB_DATASET="$2"; shift 2;; --calib_seq ) CALIB_SEQ="$2"; shift 2;; + --auto_quantize_method ) AUTO_QUANTIZE_METHOD="$2"; shift 2;; + --auto_quantize_score_size ) AUTO_QUANTIZE_SCORE_SIZE="$2"; shift 2;; + --auto_quantize_checkpoint ) AUTO_QUANTIZE_CHECKPOINT="$2"; shift 2;; -- ) shift; break ;; * ) break ;; esac @@ -150,5 +153,8 @@ parse_options() { echo "low_memory_mode: $LOW_MEMORY_MODE" echo "calib_dataset: $CALIB_DATASET" echo "calib_seq: $CALIB_SEQ" + echo "auto_quantize_method: $AUTO_QUANTIZE_METHOD" + echo "auto_quantize_score_size: $AUTO_QUANTIZE_SCORE_SIZE" + echo "auto_quantize_checkpoint: $AUTO_QUANTIZE_CHECKPOINT" echo "=================" } diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index aa68fbfec..16abc6ef4 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -364,7 +364,6 @@ def default_state_dict(self) -> SearchStateDict: 
return { "candidate_stats": defaultdict(dict), "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False}, - "constraints": {}, } def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: @@ -401,13 +400,14 @@ def _apply_quant_group_rule(self, name: str, rule) -> str | None: Args: name: Module name - rule: Either a regex pattern string or a callable that returns a unique key + rule: Either a regex pattern string or a callable that returns a unique key; + If callable, it should take the model and the name as input and return the unique key Returns: The group key if the rule matches, None otherwise """ if callable(rule): - return rule(name) + return rule(self.model, name) else: # Regex pattern pattern = re.compile(rule) @@ -421,13 +421,14 @@ def _apply_score_group_rule(self, name: str, rule) -> str | None: Args: name: Module name - rule: Either a regex pattern string or a callable that returns the score module name + rule: Either a regex pattern string or a callable that returns the score module name. + If callable, it should take the model and the name as input and return the score module name Returns: The score module name if the rule matches, None otherwise """ if callable(rule): - return rule(name) + return rule(self.model, name) else: # Regex pattern - return the matched name or full match pattern = re.compile(rule) @@ -546,6 +547,29 @@ def _verify_constraint(self, search_recipes): def estimate_sensitivity_scores(self) -> None: """Estimate sensitivity scores and track them with Hparam.""" + def initialize_candidate_stats(self): + """Initialize the candidate stats for the model.""" + for name, hparam in named_hparams(self.model, unique=True): + if not isinstance(hparam, QuantRecipeHparam): + continue + + formats, scores, costs = [], [], [] + prev_score = float("inf") + for recipe in hparam.choices: + formats.append(recipe) + + score = hparam.get_score(recipe) # type: ignore [arg-type] + cost = hparam.get_cost(recipe) # type: ignore [arg-type] + + score = min(score, prev_score) # TODO: Should we get rid of this? + scores.append(score) + costs.append(cost) + prev_score = score + + self.candidate_stats[name]["formats"] = formats + self.candidate_stats[name]["scores"] = scores + self.candidate_stats[name]["costs"] = costs + def _run_func(self, func, num_iters=1, desc=""): for i, data in tqdm( zip(range(num_iters), self.config["data_loader"]), @@ -604,7 +628,15 @@ def forward_loop(model): # TODO: This is a hack. We need to create a mode for auto_quantize to handle this in a clean way. 
ModeloptStateManager(self.model).state_dict().pop() + if self.candidate_stats: + if self.config["verbose"]: + print_rank_0("AutoQuantize: Restored from checkpoint, skipping scoring") + return + self.estimate_sensitivity_scores() + self.initialize_candidate_stats() + # Save checkpoint after successful score estimation + self.save_search_checkpoint(verbose=self.config["verbose"]) @staticmethod def _get_total_weight_size(modules): @@ -642,27 +674,6 @@ def run_search(self): total_weight_size = self._get_total_weight_size(self.model.modules()) max_weight_size = total_weight_size * compression - for name, hparam in named_hparams(self.model, unique=True): - if not isinstance(hparam, QuantRecipeHparam): - continue - - formats, scores, costs = [], [], [] - prev_score = float("inf") - for recipe in hparam.choices: - formats.append(recipe) - - score = hparam.get_score(recipe) # type: ignore [arg-type] - cost = hparam.get_cost(recipe) # type: ignore [arg-type] - - score = min(score, prev_score) # TODO: Should we get rid of this? - scores.append(score) - costs.append(cost) - prev_score = score - - self.candidate_stats[name]["formats"] = formats - self.candidate_stats[name]["scores"] = scores - self.candidate_stats[name]["costs"] = costs - # Run the search with stats to get the best recipe and whether the constraints are satisfied best_recipe_info, is_satisfied = self.run_search_with_stats(max_weight_size, verbose) self.best["is_satisfied"] = is_satisfied @@ -813,6 +824,7 @@ def forward_backward_step(model, data): raise RuntimeError( "AutoQuantize: Error while calling `backward()` on the loss returned by `loss_func`. " "Please fix this!" + f"error: {e}" ) from e return forward_backward_step diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index f9c3f9820..983d27ee8 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -253,6 +253,7 @@ def auto_quantize( num_score_steps: int = 128, verbose: bool = False, method: str = "gradient", + checkpoint: str | None = None, ): r"""Perform optimal per-layer quantization by searching for the best quantization formats per-layer. @@ -260,6 +261,15 @@ def auto_quantize( for the best quantization formats per-layer. The sensitivity score can be computed using gradient-based methods (default) or KL divergence loss, controlled by the ``method`` parameter. + Internally this API runs two main phases: + + #. Calibrate the quantized model exactly like :func:`quantize` would. + #. Estimate per-layer sensitivity scores to decide which format to keep. + + The sensitivity scoring phase typically dominates the runtime of ``auto_quantize``, so decreasing the number of + samples used for scoring (see ``num_score_steps``) is the recommended way for improving overall auto_quantize time + with minimal accuracy impact. + Args: model: A pytorch model with quantizer modules. constraints: Constraints for the search. Currently we support only ``effective_bits``. @@ -377,9 +387,11 @@ def forward_backward_step(model, batch) -> None: disabled_layers = "*lm_head*" disabled_layers = ["*lm_head*", "*mlp*"] - num_calib_steps: Number of batches to use for calibrating the quantized model. Suggested value is 512. + num_calib_steps: Number of batches to use for calibrating each candidate quantization format. Suggested value + is 512. num_score_steps: Number of batches to use for estimating ``auto_quantize`` scores. Suggested value is 128. 
- A higher value could increase the time taken for performing ``auto_quantize``. + A higher value could increase the time taken for performing ``auto_quantize``; reducing it speeds up the + sensitivity score estimation phase and typically affects accuracy less than lowering ``num_calib_steps``. verbose: If True, prints the search progress/intermediate results. method: Method to use for estimating sensitivity loss. Higher loss indicates greater sensitivity to quantization. Options are: @@ -388,6 +400,9 @@ def forward_backward_step(model, batch) -> None: - ``"kl_div"``: Uses KL divergence loss between unquantized and quantized model outputs. Uses threshold-based binary search. Only requires ``forward_step`` (no loss_func needed). The ``forward_step`` should return model logits for this method. + checkpoint: (Optional) Path to checkpoint file for saving/restoring auto_quantize search state. + If the checkpoint file exists, the search state will be restored from it, skipping the + expensive score estimation step. Returns: A tuple (model, state_dict) where ``model`` is the searched and quantized model and ``state_dict`` contains the history and detailed stats of the search procedure. @@ -474,6 +489,7 @@ def forward_backward_step(model, batch) -> None: "num_score_steps": num_score_steps, "disabled_layers": disabled_layers, "verbose": verbose, + "checkpoint": checkpoint, } # Disable all quantizers; AutoQuantize will enable the needed ones set_quantizer_by_cfg(model, {"*": {"enable": False}}) diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 0c41917b4..1a5cfee32 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -183,7 +183,7 @@ def loss_func(output): assert torch.allclose(output_ref, output_test) -def test_auto_quantize_disable(): +def test_auto_quantize_disable_layers(): model = TransformerBlock() def loss_func(output): @@ -342,3 +342,58 @@ def test_estimate_quant_compression(): fp8_affine_kv_cfg = mtq.config.QuantizeConfig(**mtq.FP8_AFFINE_KV_CFG) assert estimate_quant_compression(fp8_affine_kv_cfg) == 0.5 + + +@pytest.mark.parametrize("method", ["gradient", "kl_div"]) +def test_auto_quantize_checkpoint_resume(method, tmp_path, capsys): + """Test that checkpoint can be used to resume an interrupted search.""" + model = SimpleLinear() + checkpoint_path = str(tmp_path / "autoquant_resume_checkpoint.pth") + + # First run: save checkpoint + model_1, state_dict_1 = mtq.auto_quantize( + model, + constraints={"effective_bits": 6.0}, + quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], + data_loader=[model.get_input() for _ in range(2)], + forward_step=lambda model, batch: model(batch), + loss_func=lambda output, data: output.sum(), + num_calib_steps=2, + num_score_steps=2, + verbose=True, + method=method, + checkpoint=checkpoint_path, + ) + + # Clear captured output from first run + capsys.readouterr() + + # Second run: resume with same constraint should produce same results + model_2 = SimpleLinear() + model_2, state_dict_2 = mtq.auto_quantize( + model_2, + constraints={"effective_bits": 6.0}, # Same constraint + quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], + data_loader=[model_2.get_input() for _ in range(2)], + forward_step=lambda model, batch: model(batch), + loss_func=lambda output, data: output.sum(), + num_calib_steps=2, + num_score_steps=2, + verbose=True, + method=method, + 
checkpoint=checkpoint_path, + ) + + # Verify the restore message was printed on second run + captured = capsys.readouterr() + assert "Restored from checkpoint, skipping scoring" in captured.out, ( + "Expected restore message when resuming from checkpoint" + ) + + # Results should be identical when using same constraint + assert state_dict_1["candidate_stats"] == state_dict_2["candidate_stats"] + assert state_dict_1["best"]["recipe"] == state_dict_2["best"]["recipe"] + assert ( + pytest.approx(state_dict_1["best"]["constraints"]["effective_bits"]) + == state_dict_2["best"]["constraints"]["effective_bits"] + )
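To make the threshold search in `AutoQuantizeKLDivSearcher.run_search_with_stats` (added in the first patch above) easier to follow, here is a simplified, self-contained sketch of the same selection rule. The layer names, scores, and costs are invented, and each layer's candidates are assumed to be ordered with non-increasing scores and increasing costs (scores are monotonized that way by `initialize_candidate_stats`; the cost ordering is an assumption of this sketch):

```python
# Invented numbers for illustration; in the real searcher these come from
# candidate_stats[name]["scores"] / ["costs"] after sensitivity-score estimation.
candidate_stats = {
    "layer0": {"scores": [0.9, 0.2, 0.0], "costs": [0.25, 0.5, 1.0]},
    "layer1": {"scores": [0.4, 0.1, 0.0], "costs": [0.25, 0.5, 1.0]},
    "layer2": {"scores": [2.5, 0.8, 0.0], "costs": [0.25, 0.5, 1.0]},
}


def select_by_threshold(threshold):
    """Per layer, pick the first candidate whose score is <= threshold (index 0 as fallback)."""
    selection, total_cost = {}, 0.0
    for name, stats in candidate_stats.items():
        idx = 0
        for i, score in enumerate(stats["scores"]):
            if score <= threshold:
                idx = i
                break
        selection[name] = idx
        total_cost += stats["costs"][idx]
    return selection, total_cost


def threshold_binary_search(max_cost, iterations=100):
    all_scores = [s for stats in candidate_stats.values() for s in stats["scores"]]
    lower, upper = min(all_scores), max(all_scores)
    threshold = (lower + upper) / 2.0
    for _ in range(iterations):
        selection, total_cost = select_by_threshold(threshold)
        if total_cost <= max_cost:
            upper = threshold  # within budget: lower the threshold (less quantization, better accuracy)
        else:
            lower = threshold  # over budget: raise the threshold (more layers take cheaper formats)
        threshold = (lower + upper) / 2.0
    return selection, total_cost, total_cost <= max_cost


selection, total_cost, is_satisfied = threshold_binary_search(max_cost=1.75)
print(selection, total_cost, is_satisfied)
```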
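Finally, a hedged end-to-end sketch of how the new options compose from user code. The model and data loader are placeholders, the checkpoint path is invented, and the quantization format list simply reuses configs that already appear in the tests above; the keyword arguments themselves are the ones documented in this PR.

```python
import modelopt.torch.quantization as mtq

model = ...         # placeholder: e.g. a Hugging Face causal LM with quantizer modules
calib_loader = ...  # placeholder: iterable of {"input_ids": ...} batches (no labels needed for kl_div)

model, search_state = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 6.0},
    quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG],
    data_loader=calib_loader,
    # With method="kl_div", forward_step must return logits and no loss_func is required.
    forward_step=lambda model, batch: model(**batch).logits,
    num_calib_steps=512,
    num_score_steps=128,
    method="kl_div",
    # If this file already exists, candidate_stats are restored and the expensive
    # sensitivity-scoring phase is skipped; otherwise it is written after scoring.
    checkpoint="auto_quantize_state.pth",
    verbose=True,
)
```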