11 changes: 8 additions & 3 deletions .github/scripts/torchao_model_releases/README.md
@@ -8,8 +8,7 @@ By default, we release FP8, INT4, INT8-INT4 checkpoints, with model card pre-fil

Examples:
```
# Note: first login with `huggingface-cli login`, the quantized model will be uploaded to
# the logged in user
# Note: first login with `hf auth login`; the quantized model will be uploaded to the logged-in user

# release with default quant options (FP8, INT4, INT8-INT4)
./release.sh --model_id Qwen/Qwen3-8B
@@ -20,8 +19,11 @@ Examples:

Note: for the initial release, please include `--populate_model_card_template` to populate the model card template.

### SMOOTHQUANT-INT8-INT8
[SmoothQuant](https://arxiv.org/abs/2211.10438) smooths activation outliers by migrating quantization difficulty from activations to weights through a mathematically equivalent per-channel scaling transformation. This means SmoothQuant needs to observe the activation distribution (during calibration) before applying quantization.
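
To make the scaling idea concrete, here is a minimal standalone sketch of the equivalent per-channel transformation. This is not torchao's `SmoothQuantConfig` implementation; the `alpha` value, tensor shapes, and helper name are illustrative assumptions.

```
import torch

def smoothing_scales(act_absmax, weight, alpha=0.5):
    # s_j = max|X_j|^alpha / max|W_j|^(1-alpha), computed per input channel j
    w_absmax = weight.abs().amax(dim=0)  # weight is [out_features, in_features]
    return act_absmax.pow(alpha) / w_absmax.pow(1 - alpha).clamp(min=1e-5)

# Toy calibration data with one outlier activation channel
X = torch.randn(16, 8) * torch.tensor([1.0] * 7 + [50.0])
W = torch.randn(4, 8)
act_absmax = X.abs().amax(dim=0)  # collected by observers during calibration

s = smoothing_scales(act_absmax, W)
X_smooth, W_smooth = X / s, W * s  # activations become tamer, weights absorb the difficulty
# The transformation is mathematically equivalent: (X / s) @ (W * s).T == X @ W.T
assert torch.allclose(X @ W.T, X_smooth @ W_smooth.T, atol=1e-3)
```

In the released flow, these activation statistics are gathered by the observers inserted during the `prepare` step and the scales are applied during the `convert` step (see `quantize_and_upload.py` below).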

### AWQ-INT4
[AWQ](https://arxiv.org/abs/2306.00978) is a technique to improve accuracy for weight only quantization. It improves accuracy by preserving "salient" weight channels that has high impact on the accuracy of output, through multiplying the weight channel by a scale, and do the reverse for the correspnoding activation, since activation is not quantized, there is no additional loss from activation, while the quantization loss from weight can be reduced.
Similar to SmoothQuant, [AWQ](https://arxiv.org/abs/2306.00978) improves accuracy by preserving "salient" weight channels that have a high impact on the accuracy of the output. The notable difference is that AWQ uses the activation distribution, not the weight distribution, to find salient weights: it multiplies those weight channels by a scale and applies the inverse scale to the corresponding activations. Since activations are not quantized, this introduces no additional activation error, while the quantization loss from weights is reduced.
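
As a rough illustration of the mechanism, the sketch below uses toy shapes, a hand-picked scale, and a simple symmetric per-channel fake quantizer; real AWQ (and torchao's `AWQConfig`) instead searches the scales per group from calibration data, so treat this only as a sketch of the idea.

```
import torch

torch.manual_seed(0)

def fake_quant_weight(w, n_bits=4):
    # Simulated symmetric per-output-channel weight-only quantization
    qmax = 2 ** (n_bits - 1) - 1
    step = w.abs().amax(dim=1, keepdim=True) / qmax
    return torch.round(w / step).clamp(-qmax - 1, qmax) * step

X = torch.randn(256, 8)
X[:, 3] *= 40.0   # input channel 3 is "salient": very large activations
W = torch.randn(64, 8)
W[:, 3] *= 0.2    # keep its weights modest so scaling does not change the quant step

s = torch.ones(8)
s[3] = 2.0        # scale chosen from activation statistics (AWQ searches this per group)

Y_ref = X @ W.T
Y_plain = X @ fake_quant_weight(W).T
# Quantize the scaled weights; compensate by dividing the (unquantized) activations
Y_awq = (X / s) @ fake_quant_weight(W * s).T

print("plain weight-only error:", (Y_ref - Y_plain).abs().mean().item())
print("awq-style error:", (Y_ref - Y_awq).abs().mean().item())
```

Because the activations are never quantized, dividing them by `s` costs nothing in precision, while the effective quantization error on the salient weight channel typically shrinks by roughly the chosen scale.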

After eval for the INT4 checkpoint is done, we might find that some tasks have a large accuracy drop compared to the high-precision baseline. In that case we can run calibration for those tasks with a few samples; tasks are selected from [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md). You can follow the [new task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md) to add new tasks to lm-eval.

@@ -30,6 +32,9 @@ Examples:
# release AWQ-INT4 model, calibrated with a specific task
# with some calibration_limit (number of samples)
python quantize_and_upload.py --model_id Qwen/Qwen3-8B --quant AWQ-INT4 --push_to_hub --task bbh --calibration_limit 2

# release SMOOTHQUANT-INT8-INT8 model, calibrated with a specific task
python quantize_and_upload.py --model_id Qwen/Qwen3-8B --quant SMOOTHQUANT-INT8-INT8 --push_to_hub --task bbh --populate_model_card_template
```
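
Once a checkpoint has been released, it can typically be loaded back directly with `transformers`, since the release flow records the torchao quantization config on the model before pushing. The repo id below is a hypothetical example of what the release flow produces; a recent `transformers` and `torchao` install is assumed.

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical repo id produced by the release flow above
model_id = "pytorch/Qwen3-8B-AWQ-INT4"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```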

### Update checkpoints for a different user_id (e.g. pytorch)
81 changes: 75 additions & 6 deletions .github/scripts/torchao_model_releases/quantize_and_upload.py
@@ -18,6 +18,7 @@
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    ModuleFqnToConfig,
@@ -26,6 +27,7 @@
    PerRow,
    quantize_,
)
from torchao.prototype.smoothquant import SmoothQuantConfig


def _get_username():
@@ -242,6 +244,42 @@ def _untie_weights_and_save_locally(model_id):
tokenizer = AutoTokenizer.from_pretrained(model_id)
"""


_smoothquant_int8_int8_quant_code = """
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
from torchao.prototype.smoothquant import SmoothQuantConfig

from torchao._models._eval import TransformerEvalWrapper
model = AutoModelForCausalLM.from_pretrained(
    model_to_quantize,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# "prepare" step: insert observers to collect activation statistics
base_config = Int8DynamicActivationInt8WeightConfig()
quant_config = SmoothQuantConfig(base_config, step="prepare")
quantize_(
    model,
    quant_config,
)
# Calibrate by running evaluation on the selected tasks
TransformerEvalWrapper(
    model=model,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
).run_eval(
    tasks=tasks,
    limit=calibration_limit,
)
# "convert" step: apply the smoothing scales and quantize the weights
quant_config = SmoothQuantConfig(base_config, step="convert")
quantize_(model, quant_config)

quantized_model = model
# Record the config so the checkpoint can be re-loaded as a quantized model
quant_config = SmoothQuantConfig(base_config, step="prepare_for_loading")
quantized_model.config.quantization_config = TorchAoConfig(quant_config)
"""


_awq_int4_quant_code = """
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.prototype.awq import (
@@ -592,7 +630,7 @@ def _untie_weights_and_save_locally(model_id):
python -m executorch.examples.models.qwen3.convert_weights $(hf download {quantized_model}) pytorch_model_converted.bin
```

Once we have the checkpoint, we export it to ExecuTorch with a max_seq_length/max_context_length of 1024 to the XNNPACK backend as follows.
Once we have the checkpoint, we export it to ExecuTorch with a max_seq_length/max_context_length of 1024 to the XNNPACK backend as follows.

[TODO: fix config path in note where necessary]
(Note: the ExecuTorch LLM export script requires config.json to have certain key names. The correct config to use for the LLM export script is located at examples/models/qwen3/config/4b_config.json within the ExecuTorch repo.)
@@ -651,13 +689,15 @@ def quantize_and_upload(
"model.embed_tokens": _int8_int4_embedding_config,
}
),
"SMOOTHQUANT-INT8-INT8": Int8DynamicActivationInt8WeightConfig(),
}

    quant_to_quant_code = {
        "FP8": _fp8_quant_code,
        "INT4": _int4_quant_code,
        "INT8-INT4": _int8_int4_quant_code,
        "AWQ-INT4": _awq_int4_quant_code,
        "SMOOTHQUANT-INT8-INT8": _smoothquant_int8_int8_quant_code,
    }

# preparation
@@ -697,6 +737,35 @@ def quantize_and_upload(
        quantized_model = model
        quant_config = AWQConfig(base_config, step="prepare_for_loading")
        quantized_model.config.quantization_config = TorchAoConfig(quant_config)
    elif quant == "SMOOTHQUANT-INT8-INT8":
        model = AutoModelForCausalLM.from_pretrained(
            model_to_quantize,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # SmoothQuant "prepare" step: insert observers to collect activation statistics
        base_config = Int8DynamicActivationInt8WeightConfig()
        quant_config = SmoothQuantConfig(base_config, step="prepare")
        quantize_(
            model,
            quant_config,
        )
        # Calibrate by running evaluation on the selected tasks
        TransformerEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
        ).run_eval(
            tasks=tasks,
            limit=calibration_limit,
        )
        # "convert" step: apply the smoothing scales and quantize the weights
        quant_config = SmoothQuantConfig(base_config, step="convert")
        quantize_(model, quant_config)

        quantized_model = model

        # Record the config used to re-load the checkpoint as a quantized model
        load_config = SmoothQuantConfig(base_config, step="prepare_for_loading")
        quantized_model.config.quantization_config = TorchAoConfig(load_config)
    else:
        # other quantization methods are integrated with `from_pretrained` in huggingface transformers
        assert quant in quant_to_config, f"Unsupported quant option: {quant}"
@@ -812,7 +881,7 @@ def quantize_and_upload(
    parser.add_argument(
        "--quant",
        type=str,
        help="Quantization method. Options are FP8, INT4, INT8-INT4, AWQ-INT4",
        help="Quantization method. Options are FP8, INT4, INT8-INT4, AWQ-INT4, SMOOTHQUANT-INT8-INT8",
    )
    parser.add_argument(
        "--tasks",
@@ -824,14 +893,14 @@
    parser.add_argument(
        "--calibration_limit",
        type=int,
        default=10,
        help="Number of samples to use for calibration. Default is 10.",
        default=128,
        help="Number of samples to use for calibration. Default is 128.",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=2048,
        help="Maximum sequence length of examples to calibrate and evaluate model on. Default is 2048",
        default=1024,
        help="Maximum sequence length of examples to calibrate and evaluate model on. Default is 1024",
    )
    parser.add_argument(
        "--push_to_hub",