@@ -10,7 +10,12 @@
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import align_module_device, update_parameter_data
+from compressed_tensors.utils import (
+    align_module_device,
+    delete_offload_parameter,
+    register_offload_parameter,
+    update_offload_parameter,
+)
 from loguru import logger
 from torch.nn import Module
 
@@ -116,7 +121,7 @@ def call_observer(
                 value,
                 should_calculate_gparam=True,
             )
-            update_parameter_data(module, global_scale, f"{base_name}_global_scale")
+            update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
         else:
             global_scale = getattr(module, f"{base_name}_global_scale", None)
 
@@ -127,22 +132,21 @@ def call_observer(
             # register or update scale & zero_point parameters (supports block shapes)
             scale_name = f"{base_name}_scale"
             zp_name = f"{base_name}_zero_point"
-            if not hasattr(module, scale_name) or getattr(module, scale_name).shape != updated_scale.shape:
-                if hasattr(module, scale_name):
-                    delattr(module, scale_name)
-                module.register_parameter(
-                    scale_name, torch.nn.Parameter(updated_scale.clone())
-                )
-            else:
-                update_parameter_data(module, updated_scale, scale_name)
-            if not hasattr(module, zp_name) or getattr(module, zp_name).shape != updated_zero_point.shape:
-                if hasattr(module, zp_name):
-                    delattr(module, zp_name)
-                module.register_parameter(
-                    zp_name, torch.nn.Parameter(updated_zero_point.clone())
-                )
-            else:
-                update_parameter_data(module, updated_zero_point, zp_name)
+            for name, value in [
+                (scale_name, updated_scale),
+                (zp_name, updated_zero_point),
+            ]:
+                if (
+                    not hasattr(module, name)
+                    or getattr(module, name).shape != value.shape
+                ):
+                    if hasattr(module, name):
+                        delete_offload_parameter(module, name)
+                    register_offload_parameter(
+                        module, name, torch.nn.Parameter(value.clone(), requires_grad=False)
+                    )
+                else:
+                    update_offload_parameter(module, name, value)
 
 
 def update_weight_global_scale(module: Module):
@@ -273,8 +277,8 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
     kv_cache = getattr(module, "kv_cache")
     k_scale = kv_cache.k_scales[module.layer_idx]
     v_scale = kv_cache.v_scales[module.layer_idx]
-    update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
-    update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
+    update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
+    update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)
 
 
 def initialize_quantized_kv_cache(module: Module):
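
The loop in the second hunk implements a register-or-update pattern: when a scale or zero point keeps its shape it is updated in place, but when calibration produces a differently shaped tensor (for example, switching to block-wise scales) the old parameter is deleted and a fresh one is registered so any offloading hooks track the new tensor. Below is a minimal sketch of that flow, assuming compressed-tensors exposes `delete_offload_parameter`, `register_offload_parameter`, and `update_offload_parameter` as free functions taking the module and parameter name; the toy `Linear` module, the `weight_scale` name, and the scale shape are illustrative only.

```python
import torch
from compressed_tensors.utils import (
    delete_offload_parameter,
    register_offload_parameter,
    update_offload_parameter,
)

module = torch.nn.Linear(16, 16)  # stand-in for a layer under calibration
new_scale = torch.ones(4, 4)      # e.g. freshly observed block-wise scales

name = "weight_scale"             # illustrative parameter name
if not hasattr(module, name) or getattr(module, name).shape != new_scale.shape:
    # First calibration step, or the scale shape changed: re-register the
    # parameter so offloading hooks track the new tensor.
    if hasattr(module, name):
        delete_offload_parameter(module, name)
    register_offload_parameter(
        module, name, torch.nn.Parameter(new_scale.clone(), requires_grad=False)
    )
else:
    # Same shape: update in place, keeping any offloaded copy in sync.
    update_offload_parameter(module, name, new_scale)
```

Using the offload-aware helpers instead of plain `register_parameter`/`delattr` is meant to keep the offloaded copy of the parameter consistent when the module is dispatched with accelerate hooks.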