9595
9696
9797def auto_quantize (
98- model , qformat , auto_quantize_bits , calib_dataloader , calibrate_loop , batch_size = 1
98+ model ,
99+ qformat ,
100+ calib_dataloader ,
101+ calibrate_loop ,
102+ auto_quantize_bits ,
103+ batch_size = 1 ,
104+ auto_quantize_method = "gradient" ,
105+ auto_quantize_score_size = 128 ,
106+ auto_quantize_checkpoint = None ,
99107):
100108 qformat_list = qformat .split ("," )
101109 assert qformat_list , "No quantization formats provided"
@@ -122,18 +130,33 @@ def loss_func(output, data):
122130 # which contains the loss attribute.
123131 return output .loss
124132
133+ if auto_quantize_method == "gradient" :
134+ # For gradient-based method, return full output with loss
135+ def forward_step (model , batch ):
136+ return model (** batch )
137+ elif auto_quantize_method == "kl_div" :
138+ # For KL divergence method, return only logits
139+ def forward_step (model , batch ):
140+ return model (** batch ).logits
141+ else :
142+ raise ValueError (
143+ f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
144+ )
145+
125146 model , _ = mtq .auto_quantize (
126147 model ,
127148 constraints = {"effective_bits" : auto_quantize_bits },
128149 data_loader = calib_dataloader ,
129- forward_step = lambda model , batch : model ( ** batch ) ,
130- loss_func = loss_func ,
150+ forward_step = forward_step ,
151+ loss_func = loss_func , # Only used for gradient-based method
131152 # TRTLLM only supports one quantization format or None (do not quantize, internally supported)
132153 quantization_formats = [QUANT_CFG_CHOICES [format ] for format in qformat_list ],
133154 num_calib_steps = len (calib_dataloader ),
134- num_score_steps = len (calib_dataloader ),
155+ num_score_steps = min ( len (calib_dataloader ), max ( auto_quantize_score_size // batch_size , 1 ) ),
135156 verbose = True ,
136157 disabled_layers = ["*lm_head*" ],
158+ method = auto_quantize_method ,
159+ checkpoint = auto_quantize_checkpoint ,
137160 )
138161
139162 # We need to explicitly calibrate for kv cache quantization
@@ -191,10 +214,13 @@ def quantize_model(model, quant_cfg, args, calib_dataloader=None, calibration_on
191214 model = auto_quantize (
192215 model ,
193216 args .qformat ,
194- args .auto_quantize_bits ,
195217 calib_dataloader ,
196218 calibrate_loop ,
219+ args .auto_quantize_bits ,
197220 args .batch_size ,
221+ args .auto_quantize_method ,
222+ args .auto_quantize_score_size ,
223+ args .auto_quantize_checkpoint ,
198224 )
199225 elif calibration_only :
200226 model = mtq .calibrate (model , quant_cfg ["algorithm" ], forward_loop = calibrate_loop )
@@ -444,13 +470,17 @@ def main(args):
444470 assert tokenizer is not None and isinstance (
445471 tokenizer , (PreTrainedTokenizer , PreTrainedTokenizerFast )
446472 ), "The PreTrainedTokenizer must be set"
473+ # Labels are only needed for gradient-based auto_quantize
474+ include_labels = (
475+ args .auto_quantize_bits is not None and args .auto_quantize_method == "gradient"
476+ )
447477 calib_dataloader = get_dataset_dataloader (
448478 dataset_name = args .dataset ,
449479 tokenizer = tokenizer ,
450480 batch_size = args .batch_size ,
451481 num_samples = args .calib_size ,
452482 device = device ,
453- include_labels = args . auto_quantize_bits is not None ,
483+ include_labels = include_labels ,
454484 )
455485
456486 quant_cfg = build_quant_cfg (
@@ -803,6 +833,35 @@ def output_decode(generated_ids, input_shape):
803833 default = None ,
804834 type = str ,
805835 )
836+ parser .add_argument (
837+ "--auto_quantize_method" ,
838+ type = str ,
839+ default = "gradient" ,
840+ choices = ["gradient" , "kl_div" ],
841+ help = (
842+ "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
843+ "(requires labels in dataset). 'kl_div' uses KL divergence between original and "
844+ "quantized model outputs (no labels required). Default: 'gradient'"
845+ ),
846+ )
847+ parser .add_argument (
848+ "--auto_quantize_score_size" ,
849+ type = int ,
850+ default = 128 ,
851+ help = (
852+ "Number of samples to use for scoring in auto_quantize. Default: 128. "
853+ "Higher values improve accuracy but increase time."
854+ ),
855+ )
856+ parser .add_argument (
857+ "--auto_quantize_checkpoint" ,
858+ type = str ,
859+ default = None ,
860+ help = (
861+ "Path to checkpoint file for saving/restoring auto_quantize search state "
862+ "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
863+ ),
864+ )
806865
807866 args = parser .parse_args ()
808867
0 commit comments