Skip to content

Commit c96d919

Browse files
Asma Kuriparambil Thekkumpate (realAsma)
authored and committed
[3/N] Added autoquantize search state save/restore support
Some improvements for KLDiv.
Signed-off-by: realAsma <[email protected]>
1 parent c95e550 commit c96d919

File tree

10 files changed

+373
-88
lines changed

10 files changed

+373
-88
lines changed

examples/llm_eval/gen_model_answer.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,11 @@ def get_model_answers(
201201
tokenizer,
202202
args.calib_batch_size,
203203
args.calib_size,
204-
args.auto_quantize_bits,
205204
test_generated=False,
205+
auto_quantize_bits=args.auto_quantize_bits,
206+
auto_quantize_method=args.auto_quantize_method,
207+
auto_quantize_score_size=args.auto_quantize_score_size,
208+
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
206209
)
207210

208211
for question in tqdm(questions):
@@ -450,6 +453,35 @@ def reorg_answer_file(answer_file):
450453
"regular quantization without auto_quantize search will be applied."
451454
),
452455
)
456+
parser.add_argument(
457+
"--auto_quantize_method",
458+
type=str,
459+
default="gradient",
460+
choices=["gradient", "kl_div"],
461+
help=(
462+
"Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
463+
"(requires labels in dataset). 'kl_div' uses KL divergence between original and "
464+
"quantized model outputs (no labels required). Default: 'gradient'"
465+
),
466+
)
467+
parser.add_argument(
468+
"--auto_quantize_score_size",
469+
type=int,
470+
default=128,
471+
help=(
472+
"Number of samples to use for scoring in auto_quantize. Default: 128. "
473+
"Higher values improve accuracy but increase time."
474+
),
475+
)
476+
parser.add_argument(
477+
"--auto_quantize_checkpoint",
478+
type=str,
479+
default=None,
480+
help=(
481+
"Path to checkpoint file for saving/restoring auto_quantize search state "
482+
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
483+
),
484+
)
453485
parser.add_argument(
454486
"--trust_remote_code",
455487
help="Set trust_remote_code for Huggingface models and tokenizers",

examples/llm_eval/lm_eval_hf.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
5454
quant_cfg = arg_dict.pop("quant_cfg", None)
5555
auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
5656
auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient")
57+
auto_quantize_score_size = arg_dict.pop("auto_quantize_score_size", 128)
58+
auto_quantize_checkpoint = arg_dict.pop("auto_quantize_checkpoint", None)
5759
calib_batch_size = arg_dict.pop("calib_batch_size", None)
5860
calib_size = arg_dict.pop("calib_size", 512)
5961
compress = arg_dict.pop("compress", False)
@@ -83,8 +85,10 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
8385
calib_size=calib_size,
8486
auto_quantize_bits=auto_quantize_bits,
8587
auto_quantize_method=auto_quantize_method,
88+
auto_quantize_score_size=auto_quantize_score_size,
8689
test_generated=False,
8790
compress=compress,
91+
auto_quantize_checkpoint=auto_quantize_checkpoint,
8892
)
8993

9094
return model_obj
@@ -103,6 +107,12 @@ def setup_parser_with_modelopt_args():
103107
"comma-separated list of quantization quantization formats that will be searched by `auto_quantize`"
104108
),
105109
)
110+
parser.add_argument(
111+
"--calib_batch_size", type=int, help="Batch size for quantization calibration"
112+
)
113+
parser.add_argument(
114+
"--calib_size", type=int, help="Calibration size for quantization", default=512
115+
)
106116
parser.add_argument(
107117
"--auto_quantize_bits",
108118
type=float,
@@ -123,10 +133,18 @@ def setup_parser_with_modelopt_args():
123133
),
124134
)
125135
parser.add_argument(
126-
"--calib_batch_size", type=int, help="Batch size for quantization calibration"
136+
"--auto_quantize_score_size",
137+
type=int,
138+
default=128,
139+
help=(
140+
"Number of samples to use for scoring in auto_quantize. Default: 128. "
141+
"Higher values improve accuracy but increase time."
142+
),
127143
)
128144
parser.add_argument(
129-
"--calib_size", type=int, help="Calibration size for quantization", default=512
145+
"--auto_quantize_checkpoint",
146+
type=str,
147+
help=("Path to checkpoint file for saving/restoring auto_quantize search state. "),
130148
)
131149
parser.add_argument(
132150
"--compress",
@@ -153,6 +171,8 @@ def setup_parser_with_modelopt_args():
153171
"quant_cfg": args.quant_cfg,
154172
"auto_quantize_bits": args.auto_quantize_bits,
155173
"auto_quantize_method": args.auto_quantize_method,
174+
"auto_quantize_score_size": args.auto_quantize_score_size,
175+
"auto_quantize_checkpoint": args.auto_quantize_checkpoint,
156176
"calib_batch_size": args.calib_batch_size,
157177
"calib_size": args.calib_size,
158178
"compress": args.compress,

examples/llm_eval/mmlu.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,12 @@ def main(
224224
ntrain: int = 5,
225225
quant_cfg: str | None = None,
226226
auto_quantize_bits: float | None = None,
227-
auto_quantize_method: str = "gradient",
228227
batch_size: int = 0,
229228
calib_size: int = 512,
230229
dtype: str = "bfloat16",
230+
auto_quantize_method: str = "gradient",
231+
auto_quantize_score_size: int = 128,
232+
auto_quantize_checkpoint: str | None = None,
231233
**kwargs,
232234
):
233235
random.seed(RAND_SEED)
@@ -283,6 +285,8 @@ def main(
283285
calib_size=calib_size,
284286
auto_quantize_bits=auto_quantize_bits,
285287
auto_quantize_method=auto_quantize_method,
288+
auto_quantize_score_size=auto_quantize_score_size,
289+
auto_quantize_checkpoint=auto_quantize_checkpoint,
286290
)
287291

288292
for subject in tqdm(subjects):

examples/llm_eval/quantization_utils.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ def _quantize_model_with_dataset(
6767
calib_dataset,
6868
auto_quantize_bits=None,
6969
auto_quantize_method="gradient",
70+
auto_quantize_score_size=128,
7071
batch_size=1,
7172
compress=False,
73+
auto_quantize_checkpoint=None,
7274
):
7375
if hasattr(lm, "gpt2"):
7476
net = lm.gpt2
@@ -112,11 +114,11 @@ def forward_step(model, batch):
112114
forward_step=forward_step,
113115
loss_func=loss_func,
114116
num_calib_steps=len(calib_dataset),
115-
num_score_steps=min(
116-
len(calib_dataset), 128 // batch_size
117-
), # Limit the number of score steps to avoid long calibration time
117+
num_score_steps=min(len(calib_dataset), max(auto_quantize_score_size // batch_size, 1)),
118118
verbose=True,
119119
method=auto_quantize_method,
120+
# disabled_layers=["*lm_head*", "*mlp.gate.*"],
121+
checkpoint=auto_quantize_checkpoint,
120122
)
121123
else:
122124
mtq_cfg = CUSTOM_CONFIG.get(quant_cfg) # type: ignore [arg-type]
@@ -160,11 +162,13 @@ def quantize_model(
160162
tokenizer,
161163
batch_size,
162164
calib_size,
163-
auto_quantize_bits=None,
164-
auto_quantize_method="gradient",
165165
data="cnn_dailymail",
166166
test_generated=True,
167167
compress=False,
168+
auto_quantize_bits=None,
169+
auto_quantize_method="gradient",
170+
auto_quantize_score_size=128,
171+
auto_quantize_checkpoint=None,
168172
):
169173
"""Quantizes the model with the provided calibration dataset.
170174
@@ -175,11 +179,14 @@ def quantize_model(
175179
tokenizer: the tokenizer.
176180
batch_size: the calibration batch size for each calibration inference run.
177181
calib_size: the total calibration dataset size.
178-
auto_quantize_bits: The effective bits constraint for auto_quantize.
179-
auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
180182
data: the name of the calibration dataset.
181183
test_generated: If ``True``, test the generated text before and after quantization.
182184
compress: If ``True``, compress the model after quantization.
185+
auto_quantize_bits: The effective bits constraint for auto_quantize.
186+
auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
187+
auto_quantize_score_size: Number of samples to use for scoring in auto_quantize. Default: 128.
188+
auto_quantize_checkpoint: Path to checkpoint file for saving/restoring auto_quantize search state
189+
(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified.
183190
"""
184191
if "AWQ" in quant_cfg:
185192
print(
@@ -191,8 +198,10 @@ def quantize_model(
191198
if hasattr(model, "model"):
192199
device = model.model.device
193200

201+
is_gradient_based = auto_quantize_bits is not None and auto_quantize_method == "gradient"
202+
194203
if batch_size == 0:
195-
if auto_quantize_bits is not None or torch.distributed.is_initialized():
204+
if is_gradient_based or torch.distributed.is_initialized():
196205
raise ValueError("We dont support automatic batch size inference for this case.")
197206

198207
net = model.gpt2 if hasattr(model, "gpt2") else model.model
@@ -201,16 +210,13 @@ def quantize_model(
201210
batch_size = get_max_batch_size(net)
202211
print(f"Update calib batch {batch_size}")
203212

204-
# Labels are only needed for gradient-based auto_quantize
205-
include_labels = auto_quantize_bits is not None and auto_quantize_method == "gradient"
206-
207213
calib_dataloader = get_dataset_dataloader(
208214
dataset_name=data,
209215
tokenizer=tokenizer,
210216
batch_size=batch_size,
211217
num_samples=calib_size,
212218
device=device,
213-
include_labels=include_labels,
219+
include_labels=is_gradient_based,
214220
)
215221

216222
if test_generated:
@@ -223,8 +229,10 @@ def quantize_model(
223229
calib_dataloader,
224230
auto_quantize_bits,
225231
auto_quantize_method,
232+
auto_quantize_score_size,
226233
batch_size,
227234
compress,
235+
auto_quantize_checkpoint,
228236
)
229237

230238
if test_generated:

examples/llm_ptq/hf_ptq.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,15 @@
9595

9696

9797
def auto_quantize(
98-
model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1
98+
model,
99+
qformat,
100+
calib_dataloader,
101+
calibrate_loop,
102+
auto_quantize_bits,
103+
batch_size=1,
104+
auto_quantize_method="gradient",
105+
auto_quantize_score_size=128,
106+
auto_quantize_checkpoint=None,
99107
):
100108
qformat_list = qformat.split(",")
101109
assert qformat_list, "No quantization formats provided"
@@ -122,18 +130,33 @@ def loss_func(output, data):
122130
# which contains the loss attribute.
123131
return output.loss
124132

133+
if auto_quantize_method == "gradient":
134+
# For gradient-based method, return full output with loss
135+
def forward_step(model, batch):
136+
return model(**batch)
137+
elif auto_quantize_method == "kl_div":
138+
# For KL divergence method, return only logits
139+
def forward_step(model, batch):
140+
return model(**batch).logits
141+
else:
142+
raise ValueError(
143+
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
144+
)
145+
125146
model, _ = mtq.auto_quantize(
126147
model,
127148
constraints={"effective_bits": auto_quantize_bits},
128149
data_loader=calib_dataloader,
129-
forward_step=lambda model, batch: model(**batch),
130-
loss_func=loss_func,
150+
forward_step=forward_step,
151+
loss_func=loss_func, # Only used for gradient-based method
131152
# TRTLLM only support one quantization format or None (do not quantize, internally supported)
132153
quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list],
133154
num_calib_steps=len(calib_dataloader),
134-
num_score_steps=len(calib_dataloader),
155+
num_score_steps=min(len(calib_dataloader), max(auto_quantize_score_size // batch_size, 1)),
135156
verbose=True,
136157
disabled_layers=["*lm_head*"],
158+
method=auto_quantize_method,
159+
checkpoint=auto_quantize_checkpoint,
137160
)
138161

139162
# We need to explicitly calibrate for kv cache quantization
@@ -191,10 +214,13 @@ def quantize_model(model, quant_cfg, args, calib_dataloader=None, calibration_on
191214
model = auto_quantize(
192215
model,
193216
args.qformat,
194-
args.auto_quantize_bits,
195217
calib_dataloader,
196218
calibrate_loop,
219+
args.auto_quantize_bits,
197220
args.batch_size,
221+
args.auto_quantize_method,
222+
args.auto_quantize_score_size,
223+
args.auto_quantize_checkpoint,
198224
)
199225
elif calibration_only:
200226
model = mtq.calibrate(model, quant_cfg["algorithm"], forward_loop=calibrate_loop)
@@ -444,13 +470,17 @@ def main(args):
444470
assert tokenizer is not None and isinstance(
445471
tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
446472
), "The PreTrainedTokenizer must be set"
473+
# Labels are only needed for gradient-based auto_quantize
474+
include_labels = (
475+
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
476+
)
447477
calib_dataloader = get_dataset_dataloader(
448478
dataset_name=args.dataset,
449479
tokenizer=tokenizer,
450480
batch_size=args.batch_size,
451481
num_samples=args.calib_size,
452482
device=device,
453-
include_labels=args.auto_quantize_bits is not None,
483+
include_labels=include_labels,
454484
)
455485

456486
quant_cfg = build_quant_cfg(
@@ -803,6 +833,35 @@ def output_decode(generated_ids, input_shape):
803833
default=None,
804834
type=str,
805835
)
836+
parser.add_argument(
837+
"--auto_quantize_method",
838+
type=str,
839+
default="gradient",
840+
choices=["gradient", "kl_div"],
841+
help=(
842+
"Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
843+
"(requires labels in dataset). 'kl_div' uses KL divergence between original and "
844+
"quantized model outputs (no labels required). Default: 'gradient'"
845+
),
846+
)
847+
parser.add_argument(
848+
"--auto_quantize_score_size",
849+
type=int,
850+
default=128,
851+
help=(
852+
"Number of samples to use for scoring in auto_quantize. Default: 128. "
853+
"Higher values improve accuracy but increase time."
854+
),
855+
)
856+
parser.add_argument(
857+
"--auto_quantize_checkpoint",
858+
type=str,
859+
default=None,
860+
help=(
861+
"Path to checkpoint file for saving/restoring auto_quantize search state "
862+
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
863+
),
864+
)
806865

807866
args = parser.parse_args()
808867

examples/llm_ptq/scripts/huggingface_example.sh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,28 @@ fi
9393
if [ -n "$AUTO_QUANTIZE_BITS" ]; then
9494
PTQ_ARGS+=" --auto_quantize_bits=$AUTO_QUANTIZE_BITS "
9595
fi
96+
97+
if [ -n "$AUTO_QUANTIZE_METHOD" ]; then
98+
PTQ_ARGS+=" --auto_quantize_method=$AUTO_QUANTIZE_METHOD "
99+
fi
100+
101+
if [ -n "$AUTO_QUANTIZE_SCORE_SIZE" ]; then
102+
PTQ_ARGS+=" --auto_quantize_score_size=$AUTO_QUANTIZE_SCORE_SIZE "
103+
fi
104+
105+
# Automatically generate auto_quantize checkpoint path if not provided
106+
if [ -n "$AUTO_QUANTIZE_BITS" ] && [ -z "$AUTO_QUANTIZE_CHECKPOINT" ]; then
107+
# Create a descriptive checkpoint name based on model and quantization settings
108+
AQ_METHOD=${AUTO_QUANTIZE_METHOD:-gradient}
109+
AUTO_QUANTIZE_CHECKPOINT="${ROOT_SAVE_PATH}/auto_quantize_checkpoints/${MODEL_NAME}_${AQ_METHOD}.pth"
110+
mkdir -p $(dirname $AUTO_QUANTIZE_CHECKPOINT)
111+
echo "Auto-generated auto_quantize checkpoint path: $AUTO_QUANTIZE_CHECKPOINT"
112+
fi
113+
114+
if [ -n "$AUTO_QUANTIZE_BITS" ]; then
115+
PTQ_ARGS+=" --auto_quantize_checkpoint=$AUTO_QUANTIZE_CHECKPOINT "
116+
fi
117+
96118
if [ -n "$CALIB_DATASET" ]; then
97119
PTQ_ARGS+=" --dataset=$CALIB_DATASET "
98120
fi

0 commit comments

Comments
 (0)