1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -14,6 +14,7 @@ Model Optimizer Changelog (Linux)
- Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
- Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
- Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
- Add KL Divergence loss based auto_quantize method. See `auto_quantize API docs <https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.auto_quantize>`_ for more details.
- Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
- Add support for PyTorch Geometric quantization.

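The new changelog entry refers to the `method` argument of `mtq.auto_quantize` that this PR wires through the llm_eval examples. Below is a minimal sketch of a KL-divergence-based search, assuming an already-loaded Hugging Face causal LM `model`, a calibration `data_loader` of tokenized batches, and illustrative constraint/format values; none of these names are part of this diff.

```python
import modelopt.torch.quantization as mtq

# Minimal sketch: with method="kl_div" the scoring compares original vs. quantized
# logits, so batches need no labels and no custom loss function is supplied.
model, search_state = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 4.8},  # illustrative target
    quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG],  # illustrative formats
    data_loader=data_loader,
    forward_step=lambda m, batch: m(**batch).logits,
    loss_func=None,  # not needed for KL divergence scoring
    num_calib_steps=len(data_loader),
    num_score_steps=min(len(data_loader), 128),
    verbose=True,
    method="kl_div",
)
```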
14 changes: 14 additions & 0 deletions examples/llm_eval/lm_eval_hf.py
@@ -53,6 +53,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |

quant_cfg = arg_dict.pop("quant_cfg", None)
auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient")
calib_batch_size = arg_dict.pop("calib_batch_size", None)
calib_size = arg_dict.pop("calib_size", 512)
compress = arg_dict.pop("compress", False)
@@ -81,6 +82,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
batch_size=calib_batch_size,
calib_size=calib_size,
auto_quantize_bits=auto_quantize_bits,
auto_quantize_method=auto_quantize_method,
test_generated=False,
compress=compress,
)
@@ -109,6 +111,17 @@ def setup_parser_with_modelopt_args():
"regular quantization will be applied."
),
)
parser.add_argument(
"--auto_quantize_method",
type=str,
default="gradient",
choices=["gradient", "kl_div"],
help=(
"Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
"(requires labels in dataset). 'kl_div' uses KL divergence between original and "
"quantized model outputs (no labels required). Default: 'gradient'"
),
)
parser.add_argument(
"--calib_batch_size", type=int, help="Batch size for quantization calibration"
)
@@ -139,6 +152,7 @@ def setup_parser_with_modelopt_args():
{
"quant_cfg": args.quant_cfg,
"auto_quantize_bits": args.auto_quantize_bits,
"auto_quantize_method": args.auto_quantize_method,
"calib_batch_size": args.calib_batch_size,
"calib_size": args.calib_size,
"compress": args.compress,
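The help text added above describes the score each method optimizes: the gradient method backpropagates through `output.loss` (hence labels are required), while `kl_div` compares the output distributions of the original and quantized model. The snippet below is only a conceptual illustration of that KL divergence quantity, not the library's internal implementation; the exact direction and reduction used by ModelOpt may differ.

```python
import torch.nn.functional as F

def kl_div_score(original_logits, quantized_logits):
    """Illustrative only: KL(P_original || P_quantized), averaged over the batch."""
    log_p = F.log_softmax(original_logits, dim=-1)   # reference distribution (log form)
    log_q = F.log_softmax(quantized_logits, dim=-1)  # candidate distribution (log form)
    # F.kl_div expects the candidate as log-probabilities and, with log_target=True,
    # the reference as log-probabilities as well.
    return F.kl_div(log_q, log_p, reduction="batchmean", log_target=True)
```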
2 changes: 2 additions & 0 deletions examples/llm_eval/mmlu.py
@@ -224,6 +224,7 @@ def main(
ntrain: int = 5,
quant_cfg: str | None = None,
auto_quantize_bits: float | None = None,
auto_quantize_method: str = "gradient",
batch_size: int = 0,
calib_size: int = 512,
dtype: str = "bfloat16",
@@ -281,6 +282,7 @@
batch_size=batch_size,
calib_size=calib_size,
auto_quantize_bits=auto_quantize_bits,
auto_quantize_method=auto_quantize_method,
)

for subject in tqdm(subjects):
44 changes: 37 additions & 7 deletions examples/llm_eval/quantization_utils.py
@@ -66,6 +66,7 @@ def _quantize_model_with_dataset(
quant_cfg: str | list[str],
calib_dataset,
auto_quantize_bits=None,
auto_quantize_method="gradient",
batch_size=1,
compress=False,
):
@@ -81,23 +82,41 @@
getattr(mtq, quant_fmt) for quant_fmt in quant_cfg if quant_fmt != "NONE"
]

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
# which contains the loss attribute.
return output.loss
# Configure forward_step and loss_func based on method
if auto_quantize_method == "gradient":
# For gradient-based method, return full output with loss
def forward_step(model, batch):
return model(**batch)

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
# which contains the loss attribute.
return output.loss
elif auto_quantize_method == "kl_div":
# For KL divergence method, return only logits
def forward_step(model, batch):
return model(**batch).logits

loss_func = None # KL divergence doesn't need a custom loss function
else:
raise ValueError(
f"Invalid auto_quantize_method: {auto_quantize_method}. "
"Must be 'gradient' or 'kl_div'"
)

net, _ = mtq.auto_quantize(
net,
constraints={"effective_bits": auto_quantize_bits},
quantization_formats=quant_cfg_for_search,
data_loader=calib_dataset,
forward_step=lambda model, batch: model(**batch),
forward_step=forward_step,
loss_func=loss_func,
num_calib_steps=len(calib_dataset),
num_score_steps=min(
len(calib_dataset), 128 // batch_size
), # Limit the number of score steps to avoid long calibration time
verbose=True,
method=auto_quantize_method,
)
else:
mtq_cfg = CUSTOM_CONFIG.get(quant_cfg) # type: ignore [arg-type]
@@ -142,6 +161,7 @@ def quantize_model(
batch_size,
calib_size,
auto_quantize_bits=None,
auto_quantize_method="gradient",
data="cnn_dailymail",
test_generated=True,
compress=False,
@@ -156,6 +176,7 @@
batch_size: the calibration batch size for each calibration inference run.
calib_size: the total calibration dataset size.
auto_quantize_bits: The effective bits constraint for auto_quantize.
auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
data: the name of the calibration dataset.
test_generated: If ``True``, test the generated text before and after quantization.
compress: If ``True``, compress the model after quantization.
@@ -180,21 +201,30 @@
batch_size = get_max_batch_size(net)
print(f"Update calib batch {batch_size}")

# Labels are only needed for gradient-based auto_quantize
include_labels = auto_quantize_bits is not None and auto_quantize_method == "gradient"

calib_dataloader = get_dataset_dataloader(
dataset_name=data,
tokenizer=tokenizer,
batch_size=batch_size,
num_samples=calib_size,
device=device,
include_labels=auto_quantize_bits is not None,
include_labels=include_labels,
)

if test_generated:
input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0])
generated_str_before_ptq = model.run(input_str)

_quantize_model_with_dataset(
model, quant_cfg, calib_dataloader, auto_quantize_bits, batch_size, compress
model,
quant_cfg,
calib_dataloader,
auto_quantize_bits,
auto_quantize_method,
batch_size,
compress,
)

if test_generated:
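Taken together, the changes in this file mean the calibration pipeline differs between the two methods in only three places: whether the dataloader includes labels, what `forward_step` returns, and whether a custom `loss_func` is passed. The helper below is a condensed restatement of that dispatch (illustrative only; it mirrors the diff above rather than introducing new behavior, and `make_auto_quantize_args` is a hypothetical name).

```python
def make_auto_quantize_args(method: str):
    """Return (forward_step, loss_func, include_labels) for mtq.auto_quantize."""
    if method == "gradient":
        # Gradient scoring backpropagates through output.loss, so calibration
        # batches must include labels.
        def forward_step(model, batch):
            return model(**batch)

        def loss_func(output, data):
            return output.loss

        return forward_step, loss_func, True
    if method == "kl_div":
        # KL divergence scoring only compares logits of the original and the
        # quantized model, so no labels and no custom loss function are needed.
        def forward_step(model, batch):
            return model(**batch).logits

        return forward_step, None, False
    raise ValueError(f"Invalid auto_quantize_method: {method}. Must be 'gradient' or 'kl_div'")
```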