
Commit 85d51b2

mgoin authored and shanjiaz committed
Support DeepSeekV3-style block FP8 quantization
Signed-off-by: mgoin <[email protected]>
1 parent 0851638 commit 85d51b2

5 files changed: +107 -8 lines changed


docs/guides/compression_schemes.md

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@ PTQ is performed to reduce the precision of quantizable weights (e.g., linear la
 - Useful for speed ups in high QPS regimes or offline serving on vLLM.
 - Recommended for NVIDIA GPUs with compute capability >=9.0 (Hopper and Blackwell).
 
+### [W8A8-FP8_BLOCK](../examples/quantization_w8a8_fp8/fp8_block_example.py)
+- Uses block-wise quantization to compress weights to FP8 (commonly in 128×128 tiles) and dynamic per-token-group (128) quantization for activations. Does not require a calibration dataset. Activation quantization is carried out during inference on vLLM.
+
 ## Sparsification
 Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include:
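
For intuition on the scheme added above, here is a small standalone sketch (not part of this commit) of how one scale per 128×128 weight tile can be computed when the weight dimensions divide evenly by the tile size; it assumes a PyTorch build that provides the `torch.float8_e4m3fn` dtype.

```python
import torch

# Toy weight whose dimensions are exact multiples of the 128x128 tile size.
rows, cols, block = 512, 384, 128
weight = torch.randn(rows, cols)

# Reshape into (row_tiles, 128, col_tiles, 128) and take the absolute max
# over each tile, yielding one scale per block.
tiles = weight.reshape(rows // block, block, cols // block, block)
amax = tiles.abs().amax(dim=(1, 3))                  # shape (4, 3)
scale = amax / torch.finfo(torch.float8_e4m3fn).max  # one FP8 scale per tile

print(scale.shape)  # torch.Size([4, 3])
```

Activations, by contrast, get one dynamic scale per group of 128 values computed at inference time by vLLM, so no activation statistics need to be stored in the checkpoint.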

examples/quantization_w8a8_fp8/fp8_block_example.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+MODEL_ID = "Qwen/Qwen3-0.6B"
+
+# Load model.
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, device_map="auto", torch_dtype="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Configure the quantization algorithm and scheme.
+# In this case, we:
+# * quantize the weights to fp8 with a block-wise strategy (128x128 tiles) via ptq
+# * quantize the activations to fp8 with dynamic per-token-group (128) quantization
+recipe = QuantizationModifier(
+    targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]
+)
+
+# Apply quantization.
+oneshot(model=model, recipe=recipe)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=20)
+print(tokenizer.decode(output[0]))
+print("==========================================")
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-BLOCK"
+model.save_pretrained(SAVE_DIR)
+tokenizer.save_pretrained(SAVE_DIR)
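
As a quick sanity check after `oneshot`, the block layout can be inspected on any quantized Linear layer. The snippet below is a hedged sketch: it assumes the scale is registered as `weight_scale` (i.e., `{base_name}_scale` with `base_name="weight"`, see the calibration change below) and a Qwen/Llama-style submodule path; adjust both for other models.

```python
import math

# Hypothetical inspection step, not part of the example above.
layer = model.model.layers[0].self_attn.q_proj
out_features, in_features = layer.weight.shape

print(layer.weight_scale.shape)
# Expected: one scale per 128x128 block of the weight matrix, i.e.
print((math.ceil(out_features / 128), math.ceil(in_features / 128)))
```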

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 19 additions & 2 deletions
@@ -124,8 +124,25 @@ def call_observer(
     updated_scale, updated_zero_point = observer(
         value, g_idx=g_idx, global_scale=global_scale
     )
-    update_parameter_data(module, updated_scale, f"{base_name}_scale")
-    update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+    # register or update scale & zero_point parameters (supports block shapes)
+    scale_name = f"{base_name}_scale"
+    zp_name = f"{base_name}_zero_point"
+    if not hasattr(module, scale_name) or getattr(module, scale_name).shape != updated_scale.shape:
+        if hasattr(module, scale_name):
+            delattr(module, scale_name)
+        module.register_parameter(
+            scale_name, torch.nn.Parameter(updated_scale.clone())
+        )
+    else:
+        update_parameter_data(module, updated_scale, scale_name)
+    if not hasattr(module, zp_name) or getattr(module, zp_name).shape != updated_zero_point.shape:
+        if hasattr(module, zp_name):
+            delattr(module, zp_name)
+        module.register_parameter(
+            zp_name, torch.nn.Parameter(updated_zero_point.clone())
+        )
+    else:
+        update_parameter_data(module, updated_zero_point, zp_name)
 
 
 def update_weight_global_scale(module: Module):
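
The register-or-update branch exists because a block scheme changes the quantization-parameter shapes during calibration: a scale initialized as a scalar (or per-channel vector) cannot be copied in place once the observer returns a `[num_row_blocks, num_col_blocks]` grid. A condensed, standalone sketch of the same pattern (the `set_or_update_param` helper below is illustrative, not part of the library):

```python
import torch


def set_or_update_param(module: torch.nn.Module, name: str, value: torch.Tensor) -> None:
    """Copy in place when shapes match; otherwise re-register with the new shape."""
    if not hasattr(module, name) or getattr(module, name).shape != value.shape:
        if hasattr(module, name):
            delattr(module, name)  # drop the old, differently shaped parameter
        module.register_parameter(
            name, torch.nn.Parameter(value.clone(), requires_grad=False)
        )
    else:
        getattr(module, name).data.copy_(value)


linear = torch.nn.Linear(256, 256)
set_or_update_param(linear, "weight_scale", torch.ones(1))     # per-tensor scale
set_or_update_param(linear, "weight_scale", torch.ones(2, 2))  # 2x2 grid of 128x128 block scales
print(linear.weight_scale.shape)  # torch.Size([2, 2])
```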

src/llmcompressor/observers/base.py

Lines changed: 24 additions & 6 deletions
@@ -193,12 +193,30 @@ def get_qparams(
             )
 
         elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # TODO (#1475) add support for block-wise quantization
-            raise NotImplementedError(
-                "Block-wise quantization is not yet supported, "
-                "consider group-wise quantization instead. More info at "
-                "https://github.com/vllm-project/llm-compressor/issues/1475"
-            )
+            # Block-wise quantization: one scale/zero_point per block of shape [block_rows, block_cols]
+            rows, cols = observed.shape[:2]
+            bs = self.quantization_args.block_structure
+            if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(x, int) for x in bs)):
+                raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].")
+            block_rows, block_cols = bs
+            num_br = int(ceil(rows / block_rows))
+            num_bc = int(ceil(cols / block_cols))
+            # allocate per-block scale and zero_point
+            self._scale = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device)
+            self._zero_point = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device)
+            # compute qparams for each block
+            for i in range(num_br):
+                r0 = i * block_rows
+                r1 = min((i + 1) * block_rows, rows)
+                for j in range(num_bc):
+                    c0 = j * block_cols
+                    c1 = min((j + 1) * block_cols, cols)
+                    # reduce across both dims to get one scale and zp per block
+                    scale_bp, zp_bp = self.calculate_qparams(
+                        observed[r0:r1, c0:c1], reduce_dims=(0, 1)
+                    )
+                    self._scale[i, j] = scale_bp
+                    self._zero_point[i, j] = zp_bp
 
         return self._scale, self._zero_point
 
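
The `ceil`/`min` bookkeeping above handles ragged edge blocks when the weight dimensions are not multiples of the block size. A small illustrative sketch (independent of the observer class) of the resulting per-block grid:

```python
import math
import torch

rows, cols = 300, 200                  # deliberately not multiples of 128
block_rows, block_cols = 128, 128
weight = torch.randn(rows, cols)

num_br = math.ceil(rows / block_rows)  # 3
num_bc = math.ceil(cols / block_cols)  # 2
amax = torch.empty(num_br, num_bc)
for i in range(num_br):
    r0, r1 = i * block_rows, min((i + 1) * block_rows, rows)
    for j in range(num_bc):
        c0, c1 = j * block_cols, min((j + 1) * block_cols, cols)
        # edge blocks are smaller, e.g. the bottom-right tile is 44x72
        amax[i, j] = weight[r0:r1, c0:c1].abs().amax()

print(amax.shape)  # torch.Size([3, 2])
```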

tests/llmcompressor/modifiers/quantization/test_base.py

Lines changed: 26 additions & 0 deletions
@@ -34,6 +34,32 @@ def q_config_kwargs(config_0, config_1):
         )
     )
 
+@pytest.fixture
+def block_q_config_kwargs():
+    return dict(
+        config_groups=dict(
+            group_block=dict(
+                targets=["Linear"],
+                input_activations=dict(
+                    num_bits=8, symmetric=True, strategy="group", group_size=128
+                ),
+                weights=dict(
+                    num_bits=8,
+                    symmetric=True,
+                    strategy="block",
+                    block_structure=[128, 128],
+                ),
+            ),
+        )
+    )
+
+def test_block_strategy_parsing(block_q_config_kwargs):
+    modifier = GPTQModifier(**block_q_config_kwargs)
+    resolved = modifier.resolve_quantization_config()
+    w_scheme = resolved.config_groups["group_block"].weights
+    assert w_scheme.strategy == "block"
+    assert w_scheme.block_structure == [128, 128]
+
 
 @pytest.mark.parametrize(
     "has_actorder,actorder,config_0,config_1,expected_0,expected_1",
