Commit 48b0423

Asma Kuriparambil Thekkumpate (realAsma) authored and committed
[2/N] Added KDLoss based AutoQuantize
Signed-off-by: Asma Kuriparambil Thekkumpate <[email protected]>

minor

Signed-off-by: Asma Kuriparambil Thekkumpate <[email protected]>

cherry-picked final PR changes, changelog updates

Signed-off-by: realAsma <[email protected]>
1 parent b7bd107 commit 48b0423

File tree

8 files changed: +328 −21 lines changed


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Model Optimizer Changelog (Linux)
 - Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
 - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
+- Add KL Divergence loss based auto_quantize method. See `auto_quantize API docs <https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.auto_quantize>`_ for more details.
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 - Add support for PyTorch Geometric quantization.
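For reference, below is a minimal sketch of invoking the new KL-divergence mode directly through the auto_quantize API. The effective-bits budget, the quantization format list, and the calib_dataloader name are illustrative assumptions; the method / forward_step / loss_func wiring mirrors the quantization_utils.py changes in this commit.

import modelopt.torch.quantization as mtq

def forward_step(model, batch):
    # KL-divergence mode: return only the logits so auto_quantize can compare
    # the original and quantized output distributions (no labels required).
    return model(**batch).logits

# calib_dataloader, the 4.5-bit budget, and the format list below are assumed for illustration.
model, search_state = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 4.5},
    quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG],
    data_loader=calib_dataloader,
    forward_step=forward_step,
    loss_func=None,  # the KL-divergence method does not use a custom loss function
    num_calib_steps=len(calib_dataloader),
    num_score_steps=min(len(calib_dataloader), 128),
    verbose=True,
    method="kl_div",  # new argument introduced by this commit
)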

examples/llm_eval/lm_eval_hf.py

Lines changed: 14 additions & 0 deletions
@@ -53,6 +53,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |

         quant_cfg = arg_dict.pop("quant_cfg", None)
         auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
+        auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient")
         calib_batch_size = arg_dict.pop("calib_batch_size", None)
         calib_size = arg_dict.pop("calib_size", 512)
         compress = arg_dict.pop("compress", False)
@@ -81,6 +82,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
             batch_size=calib_batch_size,
             calib_size=calib_size,
             auto_quantize_bits=auto_quantize_bits,
+            auto_quantize_method=auto_quantize_method,
             test_generated=False,
             compress=compress,
         )
@@ -109,6 +111,17 @@ def setup_parser_with_modelopt_args():
             "regular quantization will be applied."
         ),
     )
+    parser.add_argument(
+        "--auto_quantize_method",
+        type=str,
+        default="gradient",
+        choices=["gradient", "kl_div"],
+        help=(
+            "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
+            "(requires labels in dataset). 'kl_div' uses KL divergence between original and "
+            "quantized model outputs (no labels required). Default: 'gradient'"
+        ),
+    )
     parser.add_argument(
         "--calib_batch_size", type=int, help="Batch size for quantization calibration"
    )
@@ -139,6 +152,7 @@ def setup_parser_with_modelopt_args():
        {
            "quant_cfg": args.quant_cfg,
            "auto_quantize_bits": args.auto_quantize_bits,
+            "auto_quantize_method": args.auto_quantize_method,
            "calib_batch_size": args.calib_batch_size,
            "calib_size": args.calib_size,
            "compress": args.compress,

examples/llm_eval/mmlu.py

Lines changed: 2 additions & 0 deletions
@@ -224,6 +224,7 @@ def main(
     ntrain: int = 5,
     quant_cfg: str | None = None,
     auto_quantize_bits: float | None = None,
+    auto_quantize_method: str = "gradient",
     batch_size: int = 0,
     calib_size: int = 512,
     dtype: str = "bfloat16",
@@ -281,6 +282,7 @@
             batch_size=batch_size,
             calib_size=calib_size,
             auto_quantize_bits=auto_quantize_bits,
+            auto_quantize_method=auto_quantize_method,
         )

     for subject in tqdm(subjects):

examples/llm_eval/quantization_utils.py

Lines changed: 37 additions & 7 deletions
@@ -66,6 +66,7 @@ def _quantize_model_with_dataset(
     quant_cfg: str | list[str],
     calib_dataset,
     auto_quantize_bits=None,
+    auto_quantize_method="gradient",
     batch_size=1,
     compress=False,
 ):
@@ -81,23 +82,41 @@
             getattr(mtq, quant_fmt) for quant_fmt in quant_cfg if quant_fmt != "NONE"
         ]

-        def loss_func(output, data):
-            # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
-            # which contains the loss attribute.
-            return output.loss
+        # Configure forward_step and loss_func based on method
+        if auto_quantize_method == "gradient":
+            # For gradient-based method, return full output with loss
+            def forward_step(model, batch):
+                return model(**batch)
+
+            def loss_func(output, data):
+                # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
+                # which contains the loss attribute.
+                return output.loss
+        elif auto_quantize_method == "kl_div":
+            # For KL divergence method, return only logits
+            def forward_step(model, batch):
+                return model(**batch).logits
+
+            loss_func = None  # KL divergence doesn't need a custom loss function
+        else:
+            raise ValueError(
+                f"Invalid auto_quantize_method: {auto_quantize_method}. "
+                "Must be 'gradient' or 'kl_div'"
+            )

         net, _ = mtq.auto_quantize(
             net,
             constraints={"effective_bits": auto_quantize_bits},
             quantization_formats=quant_cfg_for_search,
             data_loader=calib_dataset,
-            forward_step=lambda model, batch: model(**batch),
+            forward_step=forward_step,
             loss_func=loss_func,
             num_calib_steps=len(calib_dataset),
             num_score_steps=min(
                 len(calib_dataset), 128 // batch_size
             ),  # Limit the number of score steps to avoid long calibration time
             verbose=True,
+            method=auto_quantize_method,
         )
     else:
         mtq_cfg = CUSTOM_CONFIG.get(quant_cfg)  # type: ignore [arg-type]
@@ -142,6 +161,7 @@ def quantize_model(
     batch_size,
     calib_size,
     auto_quantize_bits=None,
+    auto_quantize_method="gradient",
     data="cnn_dailymail",
     test_generated=True,
     compress=False,
@@ -156,6 +176,7 @@
         batch_size: the calibration batch size for each calibration inference run.
         calib_size: the total calibration dataset size.
         auto_quantize_bits: The effective bits constraint for auto_quantize.
+        auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
         data: the name of the calibration dataset.
         test_generated: If ``True``, test the generated text before and after quantization.
         compress: If ``True``, compress the model after quantization.
@@ -180,21 +201,30 @@
         batch_size = get_max_batch_size(net)
         print(f"Update calib batch {batch_size}")

+    # Labels are only needed for gradient-based auto_quantize
+    include_labels = auto_quantize_bits is not None and auto_quantize_method == "gradient"
+
     calib_dataloader = get_dataset_dataloader(
         dataset_name=data,
         tokenizer=tokenizer,
         batch_size=batch_size,
         num_samples=calib_size,
         device=device,
-        include_labels=auto_quantize_bits is not None,
+        include_labels=include_labels,
     )

     if test_generated:
         input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0])
         generated_str_before_ptq = model.run(input_str)

     _quantize_model_with_dataset(
-        model, quant_cfg, calib_dataloader, auto_quantize_bits, batch_size, compress
+        model,
+        quant_cfg,
+        calib_dataloader,
+        auto_quantize_bits,
+        auto_quantize_method,
+        batch_size,
+        compress,
     )

     if test_generated:
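Conceptually, the label-free kl_div mode scores each quantization choice by how far the quantized model's output distribution drifts from the original model's. The sketch below is illustrative only, not ModelOpt's internal implementation (which runs automatically when method="kl_div" is passed); it simply shows the kind of divergence being measured.

import torch
import torch.nn.functional as F

def kl_div_score(original_logits: torch.Tensor, quantized_logits: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch, not ModelOpt internals: KL(P_original || P_quantized),
    # summed over tokens and vocabulary and normalized by the batch dimension.
    log_p = F.log_softmax(original_logits, dim=-1)   # reference (unquantized) distribution
    log_q = F.log_softmax(quantized_logits, dim=-1)  # candidate quantized distribution
    return F.kl_div(log_q, log_p, log_target=True, reduction="batchmean")

Because the score is computed purely from model outputs, no labels are needed, which is why the include_labels flag in quantize_model above is only set for the gradient-based method.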
