Commit 8747bae
[Transform] Spinquant with R1 and R2 (#1615)
## Purpose ##
* Enable offline SpinQuant-style transforms
## Prerequisites ##
* vllm-project/compressed-tensors#370
* vllm-project/compressed-tensors#412
* vllm-project/compressed-tensors#414
## Changes ##
* Added `spinquant_example.py` to the examples folder
* Added `SpinQuantModifier`, which handles construction of a SpinQuant-style transform config (see the usage sketch below)
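
For orientation, here is a minimal sketch of how the modifier slots into a `oneshot` recipe, mirroring the `get_lc_model` helper in the evaluation script further down. The checkpoint and output directory name are just illustrative choices taken from that script:

```python
import torch
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16
)

recipe = [
    # Fuse SpinQuant R1/R2 Hadamard rotations into the weights (offline)
    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
    # Then quantize Linear weights to W4A16, leaving lm_head unquantized
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# The recipe needs no calibration data, so the data-free pipeline is used
oneshot(model=model, recipe=recipe, pipeline="datafree")
model.save_pretrained("Llama-3.2-1B-Instruct-spinquant-w4a16", save_compressed=True)
```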
## Testing ##
* Added modifier serialization and correctness tests
## Evaluation ##
Using this branch and [the original SpinQuant
code](https://github.com/facebookresearch/SpinQuant), we see very
similar results for `meta-llama/Llama-3.2-1B-Instruct` with W4A16
quantization. Results are equivalent in HF (in-memory vs. serialized and
re-loaded) and very similar in vLLM. The symmetric-scales calculation in
`llm-compressor` differs slightly from the one in the original SpinQuant
paper, which uses the original GPTQ implementation. When that calculation
is swapped in, results are consistent, with the Hadamard rotation improving
results on `gsm8k_llama` and `arc_challenge_llama`:
| Scheme | Impl | gsm8k | gsm8k_llama | arc_challenge_llama |
| --- | --- | --- | --- | --- |
| Hadamard+W4A16 | LC | 0.2403 | 0.2835 | 0.5262 |
| W4A16 | LC | 0.1964 | 0.1933 | 0.4781 |
| Hadamard+W4A16 | LC+SQscales | 0.1721 | 0.2183 | 0.485 |
| W4A16 | LC+SQscales | 0.207 | 0.1706 | 0.4498 |
| Hadamard+W4A16 | SQ | 0.1736 | 0.2282 | 0.4807 |
| W4A16 | SQ | 0.1986 | 0.1774 | 0.4489 |
To run the LC+SQscales configuration, change [this line in
compressed-tensors](https://github.com/neuralmagic/compressed-tensors/blob/b2df366797b00330ec765f5891dde14e4cc74c9d/src/compressed_tensors/quantization/utils/helpers.py#L111)
from
```python
scales = max_val_pos / (float(bit_range) / 2)
```
to
```python
scales = max_val_pos / (float(bit_max))
```
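
For intuition, the difference is only in the divisor. A minimal numeric sketch, assuming the usual symmetric int4 range of [-8, 7] (so `bit_max = 7` and `bit_range = 15`):

```python
# Illustration only: assumes symmetric int4 with bit_max = 7, bit_range = 15
max_val_pos = 1.0  # example absolute-max value within a quantization group
bit_max, bit_range = 7, 15

lc_scale = max_val_pos / (float(bit_range) / 2)  # 1 / 7.5 ≈ 0.1333 (llm-compressor)
sq_scale = max_val_pos / float(bit_max)          # 1 / 7.0 ≈ 0.1429 (SpinQuant/GPTQ-style)
```

The gap is small, but as the table above shows it is enough to shift downstream scores.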
<details>
<summary>The following Python script was used to generate these results</summary>

Clone the SpinQuant repo and paste this script into its top-level directory:
```python
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from typing import Literal
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
from torch import nn
import lm_eval
from transformers import LlamaForCausalLM, AutoTokenizer
import transformers
from train_utils.main import prepare_model
from train_utils.modeling_llama_quant import LlamaForCausalLM as LlamaForCausalLMQuant
from utils.hadamard_utils import random_hadamard_matrix, hadamard_matrix
from utils.process_args import process_args_ptq
# model_id = "meta-llama/Llama-3.1-8B-Instruct"
# model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "meta-llama/Llama-3.2-1B-Instruct"
dtype = torch.bfloat16
class RotateModule(nn.Module):
    def __init__(self, R_init):
        super(RotateModule, self).__init__()
        self.weight = nn.Parameter(R_init.to(torch.float32).to(torch.device("cuda")))

    def forward(self, x, transpose=False):
        if transpose:
            return x @ self.weight
        else:
            return self.weight @ x


def get_sq_model(
    r1r2: Literal["eye", "random-hadamard", "hadamard"],
    w_bits: Literal[4, 16],
    w_clip: bool = False,
) -> LlamaForCausalLMQuant:
    model_args, training_args, ptq_args = process_args_ptq()
    model_args.input_model = model_id
    if w_bits == 4:
        ptq_args.w_bits = 4
        ptq_args.w_groupsize = 128
        ptq_args.w_rtn = True  # if False, GPTQ is used
        ptq_args.w_clip = w_clip
    ptq_args.a_bits = 16
    ptq_args.k_bits = 16
    ptq_args.v_bits = 16
    print("=======ARGS=======", ptq_args)
    config = transformers.AutoConfig.from_pretrained(model_args.input_model)
    # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings,
    # clone lm_head from embed_tokens
    process_word_embeddings = False
    if config.tie_word_embeddings:
        config.tie_word_embeddings = False
        process_word_embeddings = True
    model = LlamaForCausalLMQuant.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        config=config,
        torch_dtype=dtype,
        device_map="cuda",
    )
    if process_word_embeddings:
        model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
    model = prepare_model(ptq_args, model)
    for param in model.parameters():
        param.requires_grad = False
    match r1r2:
        case "eye":
            R1 = torch.eye(model.config.hidden_size, device="cuda")
        case "random-hadamard":
            R1 = random_hadamard_matrix(model.config.hidden_size, "cuda")
        case _:
            R1 = hadamard_matrix(model.config.hidden_size, "cuda")
    model.R1 = RotateModule(R1)
    for i in range(model.config.num_hidden_layers):
        # Each head dim = 128 for Llama model
        match r1r2:
            case "eye":
                R2 = torch.eye(
                    model.config.hidden_size // model.config.num_attention_heads,
                    device="cuda",
                )
            case "random-hadamard":
                R2 = random_hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads, "cuda"
                )
            case _:
                R2 = hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads, "cuda"
                )
        model.model.layers[i].self_attn.R2 = RotateModule(R2)
    model.config.use_cache = False
    return model


def get_lc_model(
    r1r2: Literal["eye", "random-hadamard", "hadamard"],
    w_bits: Literal[4, 16],
) -> LlamaForCausalLM:
    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.modifiers.transform import SpinQuantModifier

    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_id,
        torch_dtype=dtype,
        device_map="cuda",
    )
    recipe = [
        SpinQuantModifier(
            rotations=[] if r1r2 == "eye" else ["R1", "R2"],
            transform_type="hadamard",
        )
    ]
    if w_bits == 4:
        recipe.append(
            QuantizationModifier(
                targets="Linear",
                scheme="W4A16",
                ignore=["lm_head"],
            )
        )
    oneshot(
        model=model,
        recipe=recipe,
        pipeline="datafree",
        log_dir=None,
    )
    return model


if __name__ == "__main__":
    for scales_impl in ["sq_min_hack", "lc_min_hack"]:
        for r1r2 in ["eye", "hadamard"]:
            for sq_lc in ["sq", "lc"]:
                w_bits = 4
                # Selects which symmetric-scales implementation the (locally
                # patched) compressed-tensors uses; see the LC+SQscales note above
                os.environ["SCALES_IMPL"] = scales_impl
                model = (
                    get_sq_model(r1r2=r1r2, w_bits=w_bits)
                    if sq_lc == "sq"
                    else get_lc_model(r1r2=r1r2, w_bits=w_bits)
                ).to("cuda")
                SAVE_DIR = model_id.split("/")[1] + f"-{scales_impl}-{r1r2}-w4a16"
                model.save_pretrained(SAVE_DIR, save_compressed=True)
                tokenizer = AutoTokenizer.from_pretrained(
                    model_id, trust_remote_code=True
                )
                tokenizer.save_pretrained(SAVE_DIR)

                del model
                del tokenizer
                torch.cuda.empty_cache()

                results = lm_eval.simple_evaluate(
                    # 1) hf in-memory
                    # model=lm_eval.models.huggingface.HFLM(
                    #     pretrained=model,
                    #     batch_size=32,
                    #     add_bos_token=False,
                    # ),
                    # 1/)
                    # 2) vllm serialized
                    model="vllm",
                    model_args={
                        "pretrained": SAVE_DIR,
                        "add_bos_token": False,
                        "dtype": "auto",
                        "max_model_len": 4096,
                        "gpu_memory_utilization": 0.5,
                        "enable_chunked_prefill": True,
                    },
                    # 2/)
                    # 3) hf serialized
                    # model="hf",
                    # model_args={
                    #     "pretrained": SAVE_DIR,
                    #     "add_bos_token": False,
                    #     "dtype": "auto",
                    # },
                    # device="cuda",
                    # 3/)
                    tasks=["gsm8k_llama", "gsm8k", "arc_challenge_llama"],
                    num_fewshot=8,
                    batch_size=32,
                    apply_chat_template=True,
                    fewshot_as_multiturn=True,
                )
                print(
                    f"RESULTS, {model_id} {sq_lc} R1R2 {r1r2} W_BITS {w_bits} SCALEIMPL {scales_impl}"
                )
                print(lm_eval.utils.make_table(results))
```
</details>
## Follow Ups ##
* Infer the data-free pipeline automatically when a transform modifier is included
* Add support for rotations R3 and R4
* Modify the example to use GPTQ once basic evaluation has been performed
---------
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
File tree: 10 files changed, +497 −0 lines changed
- examples/transform
- src/llmcompressor
  - modeling
  - modifiers/transform
    - spinquant
  - pipelines/data_free
- tests/llmcompressor/modifiers/transform