Skip to content

Commit 2ff4a63

Browse files
committed
address reviews
Signed-off-by: shanjiaz <[email protected]>
1 parent 8b49283 commit 2ff4a63

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from compressed_tensors.utils import (
1414
align_module_device,
1515
delete_offload_parameter,
16+
register_offload_parameter,
1617
update_offload_parameter,
1718
)
1819
from loguru import logger
@@ -131,21 +132,23 @@ def call_observer(
131132
# register or update scale & zero_point parameters (supports block shapes)
132133
scale_name = f"{base_name}_scale"
133134
zp_name = f"{base_name}_zero_point"
134-
for name, value in [
135+
for name, param_value in [
135136
(scale_name, updated_scale),
136137
(zp_name, updated_zero_point),
137138
]:
138139
if (
139140
not hasattr(module, name)
140-
or getattr(module, name).shape != value.shape
141+
or getattr(module, name).shape != param_value.shape
141142
):
142143
if hasattr(module, name):
143144
delete_offload_parameter(module, name)
144-
register_offload_parameter(module
145-
name, torch.nn.Parameter(value.clone(), requires_grad=False)
145+
register_offload_parameter(
146+
module,
147+
name,
148+
torch.nn.Parameter(param_value.clone(), requires_grad=False),
146149
)
147150
else:
148-
update_offload_parameter(module, name, value)
151+
update_offload_parameter(module, name, param_value)
149152

150153

151154
def update_weight_global_scale(module: Module):

src/llmcompressor/observers/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,22 @@ def get_qparams(
209209
block_rows, block_cols = bs
210210
num_br = int(ceil(rows / block_rows))
211211
num_bc = int(ceil(cols / block_cols))
212+
212213
# allocate per-block scale and zero_point
213214
self._scale = torch.empty(
214215
(num_br, num_bc), dtype=observed.dtype, device=observed.device
215216
)
217+
218+
# Use same dtype logic as GROUP strategy for zero_point
219+
if is_fp4(quantization_args=self.quantization_args):
220+
zp_dtype = FP8_E4M3_DATA.dtype
221+
else:
222+
zp_dtype = self.quantization_args.pytorch_dtype()
223+
216224
self._zero_point = torch.empty(
217-
(num_br, num_bc), dtype=observed.dtype, device=observed.device
225+
(num_br, num_bc), dtype=zp_dtype, device=observed.device
218226
)
227+
219228
# compute qparams for each block
220229
for i in range(num_br):
221230
r0 = i * block_rows

0 commit comments

Comments (0)