Skip to content

Commit b941367

Browse files
Asma Kuriparambil Thekkumpate (realAsma)
authored and committed
[2/N] Added KDLoss based AutoQuantize
Signed-off-by: Asma Kuriparambil Thekkumpate <[email protected]>

minor

Signed-off-by: Asma Kuriparambil Thekkumpate <[email protected]>
1 parent 25c41f7 commit b941367

File tree

7 files changed

+272
-20
lines changed

7 files changed

+272
-20
lines changed

examples/llm_eval/lm_eval_hf.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
5353

5454
quant_cfg = arg_dict.pop("quant_cfg", None)
5555
auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
56+
auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient")
5657
calib_batch_size = arg_dict.pop("calib_batch_size", None)
5758
calib_size = arg_dict.pop("calib_size", 512)
5859
compress = arg_dict.pop("compress", False)
@@ -81,6 +82,7 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
8182
batch_size=calib_batch_size,
8283
calib_size=calib_size,
8384
auto_quantize_bits=auto_quantize_bits,
85+
auto_quantize_method=auto_quantize_method,
8486
test_generated=False,
8587
compress=compress,
8688
)
@@ -109,6 +111,17 @@ def setup_parser_with_modelopt_args():
109111
"regular quantization will be applied."
110112
),
111113
)
114+
parser.add_argument(
115+
"--auto_quantize_method",
116+
type=str,
117+
default="gradient",
118+
choices=["gradient", "kl_div"],
119+
help=(
120+
"Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
121+
"(requires labels in dataset). 'kl_div' uses KL divergence between original and "
122+
"quantized model outputs (no labels required). Default: 'gradient'"
123+
),
124+
)
112125
parser.add_argument(
113126
"--calib_batch_size", type=int, help="Batch size for quantization calibration"
114127
)
@@ -139,6 +152,7 @@ def setup_parser_with_modelopt_args():
139152
{
140153
"quant_cfg": args.quant_cfg,
141154
"auto_quantize_bits": args.auto_quantize_bits,
155+
"auto_quantize_method": args.auto_quantize_method,
142156
"calib_batch_size": args.calib_batch_size,
143157
"calib_size": args.calib_size,
144158
"compress": args.compress,

examples/llm_eval/mmlu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def main(
224224
ntrain: int = 5,
225225
quant_cfg: str | None = None,
226226
auto_quantize_bits: float | None = None,
227+
auto_quantize_method: str = "gradient",
227228
batch_size: int = 0,
228229
calib_size: int = 512,
229230
dtype: str = "bfloat16",
@@ -281,6 +282,7 @@ def main(
281282
batch_size=batch_size,
282283
calib_size=calib_size,
283284
auto_quantize_bits=auto_quantize_bits,
285+
auto_quantize_method=auto_quantize_method,
284286
)
285287

286288
for subject in tqdm(subjects):

examples/llm_eval/quantization_utils.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def _quantize_model_with_dataset(
6666
quant_cfg: str | list[str],
6767
calib_dataset,
6868
auto_quantize_bits=None,
69+
auto_quantize_method="gradient",
6970
batch_size=1,
7071
compress=False,
7172
):
@@ -81,23 +82,41 @@ def _quantize_model_with_dataset(
8182
getattr(mtq, quant_fmt) for quant_fmt in quant_cfg if quant_fmt != "NONE"
8283
]
8384

84-
def loss_func(output, data):
85-
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
86-
# which contains the loss attribute.
87-
return output.loss
85+
# Configure forward_step and loss_func based on method
86+
if auto_quantize_method == "gradient":
87+
# For gradient-based method, return full output with loss
88+
def forward_step(model, batch):
89+
return model(**batch)
90+
91+
def loss_func(output, data):
92+
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
93+
# which contains the loss attribute.
94+
return output.loss
95+
elif auto_quantize_method == "kl_div":
96+
# For KL divergence method, return only logits
97+
def forward_step(model, batch):
98+
return model(**batch).logits
99+
100+
loss_func = None # KL divergence doesn't need a custom loss function
101+
else:
102+
raise ValueError(
103+
f"Invalid auto_quantize_method: {auto_quantize_method}. "
104+
"Must be 'gradient' or 'kl_div'"
105+
)
88106

89107
net, _ = mtq.auto_quantize(
90108
net,
91109
constraints={"effective_bits": auto_quantize_bits},
92110
quantization_formats=quant_cfg_for_search,
93111
data_loader=calib_dataset,
94-
forward_step=lambda model, batch: model(**batch),
112+
forward_step=forward_step,
95113
loss_func=loss_func,
96114
num_calib_steps=len(calib_dataset),
97115
num_score_steps=min(
98116
len(calib_dataset), 128 // batch_size
99117
), # Limit the number of score steps to avoid long calibration time
100118
verbose=True,
119+
method=auto_quantize_method,
101120
)
102121
else:
103122
mtq_cfg = CUSTOM_CONFIG.get(quant_cfg) # type: ignore [arg-type]
@@ -142,6 +161,7 @@ def quantize_model(
142161
batch_size,
143162
calib_size,
144163
auto_quantize_bits=None,
164+
auto_quantize_method="gradient",
145165
data="cnn_dailymail",
146166
test_generated=True,
147167
compress=False,
@@ -156,6 +176,7 @@ def quantize_model(
156176
batch_size: the calibration batch size for each calibration inference run.
157177
calib_size: the total calibration dataset size.
158178
auto_quantize_bits: The effective bits constraint for auto_quantize.
179+
auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
159180
data: the name of the calibration dataset.
160181
test_generated: If ``True``, test the generated text before and after quantization.
161182
compress: If ``True``, compress the model after quantization.
@@ -180,21 +201,30 @@ def quantize_model(
180201
batch_size = get_max_batch_size(net)
181202
print(f"Update calib batch {batch_size}")
182203

204+
# Labels are only needed for gradient-based auto_quantize
205+
include_labels = auto_quantize_bits is not None and auto_quantize_method == "gradient"
206+
183207
calib_dataloader = get_dataset_dataloader(
184208
dataset_name=data,
185209
tokenizer=tokenizer,
186210
batch_size=batch_size,
187211
num_samples=calib_size,
188212
device=device,
189-
include_labels=auto_quantize_bits is not None,
213+
include_labels=include_labels,
190214
)
191215

192216
if test_generated:
193217
input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0])
194218
generated_str_before_ptq = model.run(input_str)
195219

196220
_quantize_model_with_dataset(
197-
model, quant_cfg, calib_dataloader, auto_quantize_bits, batch_size, compress
221+
model,
222+
quant_cfg,
223+
calib_dataloader,
224+
auto_quantize_bits,
225+
auto_quantize_method,
226+
batch_size,
227+
compress,
198228
)
199229

200230
if test_generated:

modelopt/torch/quantization/algorithms.py

Lines changed: 170 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import regex as re
2929
import torch
3030
import torch.nn as nn
31+
import torch.nn.functional as F
3132
from tqdm import tqdm
3233

3334
from modelopt.torch.opt.conversion import ModeloptStateManager
@@ -943,8 +944,175 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
943944
return best_recipes, is_satisfied
944945

945946

946-
class AutoQuantizeLossSearcher(_AutoQuantizeBaseSearcher):
947-
"""A searcher for AutoQuantize algorithm that uses loss based score estimation."""
947+
@torch.compile
def _get_kl_div_loss(logits_unquant: torch.Tensor, logits_quant: torch.Tensor) -> torch.Tensor:
    """Return the summed KL divergence between two sets of logits.

    Both inputs are normalized over their last dimension: ``logits_unquant``
    through a softmax (it acts as the target distribution) and
    ``logits_quant`` through a log-softmax (it acts as the input
    distribution). The element-wise divergence is reduced with ``"sum"``, so
    the result is a scalar tensor that is not averaged over any dimension.

    Args:
        logits_unquant: Logits produced with quantization disabled.
        logits_quant: Logits produced with the candidate quantization active.

    Returns:
        A scalar tensor holding the summed KL divergence.
    """
    # TODO: Support TensorParallel
    target_probs = F.softmax(logits_unquant, dim=-1)
    input_log_probs = F.log_softmax(logits_quant, dim=-1)
    divergence = F.kl_div(input_log_probs, target_probs, reduction="sum", log_target=False)
    return divergence
953+
954+
955+
class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher):
    """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation.

    Sensitivity of a quantization recipe is measured as the KL divergence
    between the logits of the model with every configurable recipe disabled and
    the logits of the model with only that one recipe enabled. The recipe
    selection is then done with a threshold-based binary search over the
    collected scores.
    """

    score_module_rules: list[str | Callable] = [lambda name: ""]

    @property
    def default_search_config(self):
        """Get the default config for the searcher.

        Extends the base searcher defaults with a ``forward_step`` entry, which
        defaults to ``None`` and must be supplied by the caller.
        """
        config = super().default_search_config
        config.update(
            {
                "forward_step": None,
            }
        )
        return config

    def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
        """Sanitize the search config dict.

        Keys that only make sense for other ``auto_quantize`` methods
        (``score_func``, ``loss_func``, ``forward_backward_step``) are removed
        before delegating to the base class.

        Args:
            config: The user supplied search config, or ``None``.

        Returns:
            The sanitized search config.

        Raises:
            AssertionError: If ``forward_step`` is missing from the sanitized config.
        """
        config = config or {}
        for ignored_key in ["score_func", "loss_func", "forward_backward_step"]:
            if ignored_key in config:
                # Only warn when a real value is being dropped; an explicit ``None``
                # placeholder carries nothing that could be ignored.
                if config[ignored_key] is not None:
                    warnings.warn(
                        f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`."
                    )
                config.pop(ignored_key)
        config = super().sanitize_search_config(config)
        assert config["forward_step"] is not None, (
            "`forward_step` must be provided for KL-Divergence loss based `auto_quantize`. "
            "`forward_step(model, data)` should return model logits."
        )
        return config

    @torch.no_grad()
    def estimate_sensitivity_scores(self):
        """Estimate the sensitivity scores for the model.

        Higher score means more sensitive to quantization.

        For each of the first ``num_score_steps`` batches of ``data_loader`` the
        model is run once with all configurable recipes disabled, then once per
        (hparam, recipe) pair with only that recipe active. The KL divergence of
        each pair is summed over all batches and stored in the hparam's
        importance dict.
        """
        # Check if tensor parallelism is being used; warn once and stop scanning.
        for _, module in self.model.named_modules():
            parallel_state = getattr(module, "parallel_state", None)
            tp_group = getattr(parallel_state, "tensor_parallel_group", None)
            if tp_group is not None and tp_group.is_initialized():
                warnings.warn(
                    "Tensor Parallel is not supported for KL-Divergence based auto_quantize. "
                )
                break

        def set_to_unquantized():
            for _, hparam in named_hparams(self.model, unique=True):
                if not isinstance(hparam, QuantRecipeHparam):
                    continue
                if hparam.is_configurable:
                    hparam.active = QuantRecipe(quant_cfg=None)

        # Used for comparisons only; fresh instances are still assigned to ``active``.
        unquantized_recipe = QuantRecipe(quant_cfg=None)
        forward_step = self.config["forward_step"]

        self.model.eval()
        num_iters = self.config["num_score_steps"]
        for step, data in tqdm(
            zip(range(num_iters), self.config["data_loader"]),
            desc="Estimating KLDivergence loss",
            total=num_iters,
        ):
            set_to_unquantized()
            logits_unquant = forward_step(self.model, data)

            for _, hparam in named_hparams(self.model, configurable=True):
                if not isinstance(hparam, QuantRecipeHparam):
                    continue
                for recipe in hparam.choices:
                    if recipe == unquantized_recipe:
                        continue
                    hparam.active = recipe
                    logits_quant = forward_step(self.model, data)
                    score = _get_kl_div_loss(logits_unquant, logits_quant)

                    recipe_scores = hparam._importance_dict[recipe]
                    score_key = hparam.score_modules[0]
                    if step == 0:
                        # First batch: start from this score so stale values are discarded.
                        recipe_scores[score_key] = score
                    else:
                        # Later batches: accumulate, otherwise only the last batch
                        # would contribute to the sensitivity estimate.
                        previous = recipe_scores[score_key]
                        recipe_scores[score_key] = previous + score.to(previous.device)
                # Restore this hparam so the next one is scored in isolation.
                hparam.active = QuantRecipe(quant_cfg=None)

    def _select_recipes_by_threshold(self, threshold):
        """Pick, per candidate, the first format whose score is within ``threshold``.

        Falls back to the first listed format when no score is within the threshold.
        NOTE(review): presumably the formats of each candidate are ordered so that
        earlier entries are cheaper — confirm against the base searcher.

        Returns:
            A ``(recipes, total_weight_size)`` tuple, where ``recipes`` maps each
            candidate name to its selected ``format``, ``costs`` and ``scores``.
        """
        recipes = {}
        total_weight_size = 0.0
        for name, stats in self.candidate_stats.items():
            formats = stats["formats"]
            scores = stats["scores"]
            costs = stats["costs"]

            selected_idx = next(
                (idx for idx in range(len(formats)) if scores[idx] <= threshold), 0
            )

            recipes[name] = {
                "format": formats[selected_idx],
                "costs": costs[selected_idx],
                "scores": scores[selected_idx],
            }
            total_weight_size += costs[selected_idx]
        return recipes, total_weight_size

    def run_search_with_stats(self, max_weight_size, verbose=False):
        """Run threshold-based binary search for KLDivergence loss based auto_quantize.

        We use binary search to minimize the max(per-layer score) while meeting the constraint.

        Args:
            max_weight_size: Upper limit for the summed cost of the selected formats.
            verbose: If ``True``, print the progress of the search.

        Returns:
            A ``(best_recipes, is_satisfied)`` tuple. When at least one evaluated
            threshold met the constraint, ``best_recipes`` is the selection of the
            last such threshold; otherwise it is the selection of the last
            evaluated threshold and ``is_satisfied`` is ``False``.
        """
        # Collect all sensitivity scores to determine initial threshold bounds
        all_scores = [
            score for name in self.candidate_stats for score in self.candidate_stats[name]["scores"]
        ]

        if not all_scores:
            warnings.warn("No scores available for threshold-based search!")
            return {}, False

        # Initialize binary search bounds
        min_score = min(all_scores)
        max_score = max(all_scores)
        lower_bound = min_score
        upper_bound = max_score
        threshold = (lower_bound + upper_bound) / 2.0

        # Upper limit on the number of bisection steps
        max_iterations = 100

        if verbose:
            print_rank_0("AutoQuantize: Starting threshold-based binary search")
            print_rank_0(f"  Score range: [{min_score:.6e}, {max_score:.6e}]")
            print_rank_0(f"  Target weight size: {max_weight_size:.2f}")

        # Most recent selection that met the constraint, if any.
        feasible_recipes = None
        feasible_weight_size = None

        for iteration in range(max_iterations):
            # Select recipes based on current threshold
            best_recipes, total_weight_size = self._select_recipes_by_threshold(threshold)

            # Check if we meet the constraint
            meets_constraint = total_weight_size <= max_weight_size

            if verbose:
                print_rank_0(
                    f"  Iteration {iteration + 1}: threshold={threshold:.6e}, "
                    f"weight_size={total_weight_size:.2f}, "
                    f"meets_constraint={meets_constraint}"
                )

            # Update binary search bounds
            if meets_constraint:
                # Feasible: remember it and try a smaller score cap next.
                feasible_recipes = best_recipes
                feasible_weight_size = total_weight_size
                upper_bound = threshold
            else:
                # Infeasible: the score cap has to be raised.
                lower_bound = threshold

            # Update threshold for next iteration
            next_threshold = (lower_bound + upper_bound) / 2.0
            if next_threshold == threshold:
                # The bounds can no longer be bisected; further iterations would
                # re-evaluate the same threshold.
                break
            threshold = next_threshold

        # The last evaluated threshold may be infeasible even though an earlier one
        # met the constraint; prefer the feasible selection in that case.
        if feasible_recipes is not None:
            best_recipes = feasible_recipes
            total_weight_size = feasible_weight_size

        # Final check if constraint is satisfied
        is_satisfied = total_weight_size <= max_weight_size

        if verbose:
            print_rank_0(
                f"AutoQuantize: Search complete. "
                f"Final weight size: {total_weight_size:.2f} "
                f"(target: {max_weight_size:.2f}), "
                f"constraint satisfied: {is_satisfied}"
            )

        return best_recipes, is_satisfied
9481116

9491117

9501118
# Backward compatibility alias (defaults to gradient-based searcher)

0 commit comments

Comments
 (0)