import unittest

import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
from torch.testing._internal.common_utils import (
    TestCase,
+     instantiate_parametrized_tests,
+     parametrize,
    run_tests,
)

- from torchao.quantization import (
-     Int4WeightOnlyConfig,
-     quantize_,
- )
+ from torchao.prototype.moe_quant.utils import MoEQuantConfig
+ from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error
- from torchao.utils import (
-     TORCH_VERSION_AT_LEAST_2_8,
-     is_sm_at_least_90,
- )
+ from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90
+
+
+ class Experts(nn.Module):
+     def __init__(
+         self,
+         num_local_experts: int,
+         dim: int,
+         hidden_dim: int,
+         dtype: torch.dtype,
+         device: torch.device,
+     ) -> None:
+         super().__init__()
+
+         self.num_local_experts = num_local_experts
+         self.dim = dim
+
+         self.w1: nn.Parameter = nn.Parameter(
+             torch.randn(
+                 num_local_experts,
+                 dim,
+                 hidden_dim,
+                 dtype=dtype,
+                 device=device,
+             )
+         )
+
+         self.w2: nn.Parameter = nn.Parameter(
+             torch.randn(
+                 num_local_experts,
+                 hidden_dim,
+                 dim,
+                 dtype=dtype,
+                 device=device,
+             )
+         )
+
+         self.w3: nn.Parameter = nn.Parameter(
+             torch.randn(
+                 num_local_experts,
+                 dim,
+                 hidden_dim,
+                 dtype=dtype,
+                 device=device,
+             )
+         )
+
+     def forward(
+         self,
+         routed_in_egD: torch.Tensor,  # noqa: N803
+     ) -> torch.Tensor:
+         e = self.num_local_experts
+         D = self.dim
+
+         x_egD = routed_in_egD.view(e, -1, D)
+
+         middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
+         out_egD = torch.bmm(middle_out_egF, self.w2)
+         out_egD = out_egD.view(-1, D)
+
+         return out_egD
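+
+ # Shape-suffix convention used above (a reading aid, not part of any API):
+ # e = num_local_experts, g = tokens routed per expert, D = dim, F = hidden_dim.
+ # forward() views the routed (e * g, D) input as (e, g, D), applies a
+ # per-expert SwiGLU-style FFN (silu(x @ w1) * (x @ w3), then @ w2) via
+ # torch.bmm, and flattens the result back to (e * g, D).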


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -61,9 +120,9 @@ def test_slice(self):
        quantize_(dummy, self.config)
        weight1 = dummy.weight.narrow(0, 0, 64)
        weight2 = dummy.weight.narrow(1, 0, 128)
-         self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
+         self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
-         self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
+         self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 64))
        self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))

        # check for sliced weight, before and after int4 quantization
@@ -80,31 +139,62 @@ def test_slice(self):
        res = dummy(input)
        assert compute_error(res, res_ref) > 15

-     def test_slice_and_copy_(self):
+     def test_slice_preserves_aliasing(self):
+         config = self.config
        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
        l.weight = torch.nn.Parameter(
            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
        )
-         quantize_(l, self.config)
+         quantize_(l, config)
        param = l.weight
        param_data = param.data
        param_data = param_data.narrow(0, 0, 512)
-         assert param.data._data.data_ptr() == param_data._data.data_ptr()
+         # making sure aliasing is preserved in the sliced quantized Tensor
+         assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
-         assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-         orig_value = param.data._data[0][0].item()

-         # dummy_l has random input (shouldn't be 0)
+     def test_slice_and_copy_similar_to_vllm(self):
+         # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
+         # the test is similar to the linked code, but with some hardcoded arguments
+         # and does not use tensor parallelism
+
+         dtype = torch.bfloat16
+         device = "cuda"
+         config = self.config
+         l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
+         quantize_(l, config)
+
+         # at a high level, we narrow both param.data and the loaded_weight,
+         # then use in-place copy_ to copy the loaded_weight shard into param.data
+
+         # simulate loaded_weight
        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-         quantize_(dummy_l, self.config)
-         quantized = dummy_l.weight
-         quantized = quantized.narrow(0, 0, 512)
+         # making the weight different
+         dummy_l.weight = torch.nn.Parameter(
+             dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+             requires_grad=False,
+         )
+         quantize_(dummy_l, config)

-         param_data.copy_(quantized)
+         output_dim = 0
+         shard_size = 512
+         for tp_rank in [0, 1]:
+             start_idx = tp_rank * shard_size
+             param = l.weight
+             param_data = param.data
+             param_data = param_data.narrow(output_dim, start_idx, shard_size)
+             orig_value = param_data.qdata[0][0].item()
+             loaded_weight = dummy_l.weight
+             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

-         # making sure param.data is updated
-         assert param.data._data[0][0] != orig_value
+             # making sure param_data.qdata[0][0] differs from loaded_weight.qdata[0][0]
+             assert orig_value != loaded_weight.qdata[0][0]
+             param_data.copy_(loaded_weight)
+             # making sure param.data is updated to loaded_weight
+             assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
+             assert torch.equal(param_data.scale, loaded_weight.scale)
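+
+             # as the asserts above check, copy_ updates both components of
+             # the quantized tensor: qdata in place, and scale to match the
+             # source shard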
+     @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
    def test_bmm(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
@@ -126,20 +216,213 @@ def forward(self, x):
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

-     def test_to_device(self):
+     @parametrize(
+         "sizes",
+         [
+             ((128,), 256, 128),
+             ((32, 128), 64, 256),
+             ((2, 32, 128), 64, 256),
+         ],
+     )
+     def test_to_device(self, sizes):
+         config = self.config
+         M, N, K = sizes
+         dtype = torch.bfloat16
        for device in self.GPU_DEVICES:
-             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-             quantize_(linear, self.config)
+             input_tensor = torch.randn(*M, K, dtype=dtype, device=device)
+             linear = torch.nn.Linear(K, N, dtype=dtype)
+             quantize_(linear, config)
            linear.to(device)
+             linear(input_tensor)

-             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-             quantize_(linear, self.config)
+             linear = torch.nn.Linear(K, N, dtype=dtype)
+             quantize_(linear, config)
            linear.to(device=device)
+             linear(input_tensor)

-             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-             quantize_(linear, self.config)
+             linear = torch.nn.Linear(K, N, dtype=dtype)
+             quantize_(linear, config)
            linear.to(device)
+             linear(input_tensor)
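+
+     # a note on the three blocks above: they exercise `.to` with positional
+     # and keyword device arguments, and the forward call after each move
+     # checks that the relocated quantized weight still dispatches correctly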
+
+     @parametrize(
+         "sizes",
+         [
+             ((128,), 256, 128),
+             ((32, 128), 64, 256),
+             ((2, 32, 128), 64, 256),
+         ],
+     )
+     def test_cat(self, sizes):
+         config = self.config
+         dtype = torch.bfloat16
+         device = "cuda"
+         M, N, K = sizes
+         linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+         linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+         input_cat1 = torch.randn(*M, K, dtype=dtype, device=device)
+
+         cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+         dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
+
+         dummy_linear1.weight = torch.nn.Parameter(cat_weight1)
+         quantize_(dummy_linear1, config)
+
+         quantize_(linear1, config)
+         quantize_(linear2, config)
+
+         cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+         self.assertEqual(cat_qweight1.shape, (2 * N, K))
+         self.assertEqual(
+             dummy_linear1.weight.qdata,
+             cat_qweight1.qdata,
+         )
+         self.assertEqual(
+             dummy_linear1.weight.scale,
+             cat_qweight1.scale,
+         )
+         self.assertEqual(
+             dummy_linear1.weight.zero_point,
+             cat_qweight1.zero_point,
+         )
+
+         # making sure cat_qweight1 can be used for inference
+         dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False)
+         dummy_linear1(input_cat1)
+
+         # align the scale and zero_point before concatenating along dim 1
+         linear2.weight.scale = linear1.weight.scale
+         linear2.weight.zero_point = linear1.weight.zero_point
+         cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+         self.assertEqual(cat_qweight2.shape, (N, 2 * K))
+         ref_data = torch.cat(
+             [
+                 linear1.weight.qdata,
+                 linear2.weight.qdata,
+             ],
+             dim=1,
+         )
+         ref_scale = linear1.weight.scale
+         self.assertEqual(cat_qweight2.qdata, ref_data)
+         self.assertEqual(cat_qweight2.scale, ref_scale)
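+
+         # the dim=1 cat only makes sense once both operands share the same
+         # per-row quantization parameters, hence the alignment above: row
+         # scales are attached to output rows, so the concatenated tensor
+         # reuses linear1's scales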
+
+     def test_moe_weight_reshape_ops(self):
+         """This is testing the op call sequence in saving and loading quantization
+         checkpoints in llama-models for llama4
+         (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
+         """
+         # only per row quantization is supported for bmm
+         dtype = torch.bfloat16
+         device = "cuda"
+
+         bmm_config = self.config
+         moe_config = MoEQuantConfig(bmm_config)
+
+         batch_size = 4
+         num_experts = 2
+         input_dim = 64
+         dim = 128
+         hidden_dim = 256
+
+         moe1 = Experts(num_experts, dim, hidden_dim, dtype, device)
+         moe2 = Experts(num_experts, dim, hidden_dim, dtype, device)
+         moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device)
+         input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device)
+
+         moes = [moe1, moe2]
+
+         for moe in moes:
+             moe(input)
+
+             def filter_fn(module, fqn):
+                 return isinstance(module, Experts)
+
+             # need to transpose before quantizing
+             moe.w1 = torch.nn.Parameter(
+                 moe.w1.transpose(1, 2).contiguous(), requires_grad=False
+             )
+             moe.w2 = torch.nn.Parameter(
+                 moe.w2.transpose(1, 2).contiguous(), requires_grad=False
+             )
+             moe.w3 = torch.nn.Parameter(
+                 moe.w3.transpose(1, 2).contiguous(), requires_grad=False
+             )
+
+             quantize_(moe, moe_config, filter_fn=filter_fn)
+
+             before = moe(input)
+
+             # transposing for resharding support since only 2D resharding is supported
+             new_last_dim = moe.w1.shape[-2]
+             moe.w1 = torch.nn.Parameter(
+                 moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+             )
+             new_last_dim = moe.w2.shape[-2]
+             moe.w2 = torch.nn.Parameter(
+                 moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+             )
+             new_last_dim = moe.w3.shape[-2]
+             moe.w3 = torch.nn.Parameter(
+                 moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+             )
+
+             moe.w1 = torch.nn.Parameter(
+                 moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                 requires_grad=False,
+             )
+             moe.w2 = torch.nn.Parameter(
+                 moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                 requires_grad=False,
+             )
+             moe.w3 = torch.nn.Parameter(
+                 moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                 requires_grad=False,
+             )
+
+             # transpose again to recover the original weights
+             moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False)
+             moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False)
+             moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False)
+
+             after = moe(input)
+             self.assertEqual(before, after)
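+
+             # the checkpoint-style op sequence above (transpose -> 2D reshape
+             # -> unflatten -> transpose back) should be a lossless round trip
+             # on the quantized weights, which the assertEqual verifies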
+
+         state_dicts = [moe1.state_dict(), moe2.state_dict()]
+         # align the scale parameters so the weights can be concatenated
+         for key in ["w1", "w2", "w3"]:
+             weights = [st[key] for st in state_dicts]
+             for i in range(1, len(weights)):
+                 weights[i].scale = weights[0].scale
+                 weights[i].zero_point = weights[0].zero_point
+
+         def process_key(key: str) -> torch.Tensor:
+             tensors = [s[key] for s in state_dicts]
+             # Note: the user codebase has a hacky implementation of cat,
+             # since cat was not implemented correctly before
+             if key == "w2":
+                 return torch.cat(tensors, dim=-1)
+             else:
+                 return torch.cat(tensors, dim=-2)
+
+         new_state_dict = {}
+         for key in ["w1", "w2", "w3"]:
+             new_state_dict[key] = process_key(key)
+
+         moe_combined.w1 = torch.nn.Parameter(
+             moe_combined.w1.transpose(1, 2), requires_grad=False
+         )
+         moe_combined.w2 = torch.nn.Parameter(
+             moe_combined.w2.transpose(1, 2), requires_grad=False
+         )
+         moe_combined.w3 = torch.nn.Parameter(
+             moe_combined.w3.transpose(1, 2), requires_grad=False
+         )
+         moe_combined.load_state_dict(new_state_dict, assign=True)
+         # make sure it runs
+         moe_combined(input)
+

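+ # expands each @parametrize-decorated test in TestInt4Tensor into one
+ # concrete test case per parameter set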
+ instantiate_parametrized_tests(TestInt4Tensor)

if __name__ == "__main__":
    run_tests()