from typing import Tuple

import torch
- import torch.nn as nn
- import torch.nn.functional as F
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
-     TestCase,
    run_tests,
)

- from torchao.prototype.moe_quant.utils import MoEQuantConfig
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    PerRow,
    PerTensor,
    quantize_,
)
from torchao.quantization.quantize_.common import KernelPreference
from torchao.quantization.utils import compute_error
+ from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    _is_fbgemm_genai_gpu_available,
    is_sm_at_least_89,
    is_sm_at_least_90,
)

torch._dynamo.config.cache_size_limit = 128


- class Experts(nn.Module):
-     def __init__(
-         self,
-         num_local_experts: int,
-         dim: int,
-         hidden_dim: int,
-         dtype: torch.dtype,
-         device: torch.device,
-     ) -> None:
-         super().__init__()
-
-         self.num_local_experts = num_local_experts
-         self.dim = dim
-
-         self.w1: nn.Parameter = nn.Parameter(
-             torch.randn(
-                 num_local_experts,
-                 dim,
-                 hidden_dim,
-                 dtype=dtype,
-                 device=device,
-             )
-         )
-
-         self.w2: nn.Parameter = nn.Parameter(
-             torch.randn(
-                 num_local_experts,
-                 hidden_dim,
-                 dim,
-                 dtype=dtype,
-                 device=device,
-             )
-         )
-
-         self.w3: nn.Parameter = nn.Parameter(
-             torch.randn(
-                 num_local_experts,
-                 dim,
-                 hidden_dim,
-                 dtype=dtype,
-                 device=device,
-             )
-         )
-
-     def forward(
-         self,
-         routed_in_egD: torch.Tensor,  # noqa: N803
-     ) -> torch.Tensor:
-         e = self.num_local_experts
-         D = self.dim
-
-         x_egD = routed_in_egD.view(e, -1, D)
-
-         middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
-         out_egD = torch.bmm(middle_out_egF, self.w2)
-         out_egD = out_egD.view(-1, D)
-
-         return out_egD
-
-
class ToyLinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
@@ -115,7 +52,7 @@ def forward(self, x):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
- class TestFloat8Tensor(TestCase):
+ class TestFloat8Tensor(TorchAOIntegrationTestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -338,45 +275,8 @@ def test_slice_preserves_aliasing(self, granularity):

    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
    def test_slice_and_copy_similar_to_vllm(self, granularity):
-         # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
-         # the test is similar to the linked code, but with some hardcoded arguments
-         # and does not use tensor parallelism
-
-         dtype = torch.bfloat16
-         device = "cuda"
        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-         l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
-         quantize_(l, config)
-
-         # high level, we do a narrow for both param.data and the loaded_weights
-         # and do inplace copy_ to copy from the loaded_weights into param.data
-
-         # simulate loaded_weight
-         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-         # making the weight different
-         dummy_l.weight = torch.nn.Parameter(
-             dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
-             requires_grad=False,
-         )
-         quantize_(dummy_l, config)
-
-         output_dim = 0
-         shard_size = 512
-         for tp_rank in [0, 1]:
-             start_idx = tp_rank * shard_size
-             param = l.weight
-             param_data = param.data
-             param_data = param_data.narrow(output_dim, start_idx, shard_size)
-             orig_value = param_data.qdata[0][0].item()
-             loaded_weight = dummy_l.weight
-             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-
-             # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-             assert orig_value != loaded_weight.qdata[0][0]
-             param_data.copy_(loaded_weight)
-             # making sure param.data is updated to loaded_weight
-             assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
-             assert param_data.scale[0] == loaded_weight.scale[0]
+         self._test_slice_and_copy_similar_to_vllm(config)

    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
    def test_bmm(self):
@@ -492,122 +392,9 @@ def test_cat(self, granularity, sizes):

    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
    def test_moe_weight_reshape_ops(self):
-         """This is testing the op call sequence in saving and loading quantization
-         checkpoints in llama-models for llama4
-         (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
-         """
-         # only per row quantization is supported for bmm
        granularity = PerRow()
-         dtype = torch.bfloat16
-         device = "cuda"
-
-         bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-         moe_config = MoEQuantConfig(bmm_config)
-
-         batch_size = 4
-         num_experts = 2
-         input_dim = 64
-         dim = 128
-         hidden_dim = 256
-
-         moe1 = Experts(num_experts, dim, hidden_dim, dtype, device)
-         moe2 = Experts(num_experts, dim, hidden_dim, dtype, device)
-         moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device)
-         input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device)
-
-         moes = [moe1, moe2]
-
-         for moe in moes:
-             moe(input)
-
-             def filter_fn(module, fqn):
-                 return isinstance(module, Experts)
-
-             # need to transpose before quantizing
-             moe.w1 = torch.nn.Parameter(
-                 moe.w1.transpose(1, 2).contiguous(), requires_grad=False
-             )
-             moe.w2 = torch.nn.Parameter(
-                 moe.w2.transpose(1, 2).contiguous(), requires_grad=False
-             )
-             moe.w3 = torch.nn.Parameter(
-                 moe.w3.transpose(1, 2).contiguous(), requires_grad=False
-             )
-
-             quantize_(moe, moe_config, filter_fn=filter_fn)
-
-             # make sure it runs
-             before = moe(input)
-
-             # transposing for resharding support since only 2D resharding is supported
-             new_last_dim = moe.w1.shape[-2]
-             moe.w1 = torch.nn.Parameter(
-                 moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-             )
-             new_last_dim = moe.w2.shape[-2]
-             moe.w2 = torch.nn.Parameter(
-                 moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-             )
-             new_last_dim = moe.w3.shape[-2]
-             moe.w3 = torch.nn.Parameter(
-                 moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-             )
-
-             moe.w1 = torch.nn.Parameter(
-                 moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                 requires_grad=False,
-             )
-             moe.w2 = torch.nn.Parameter(
-                 moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                 requires_grad=False,
-             )
-             moe.w3 = torch.nn.Parameter(
-                 moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                 requires_grad=False,
-             )
-
-             # transpose again to recover the original weights
-             moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False)
-             moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False)
-             moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False)
-
-             # make sure it runs
-             after = moe(input)
-
-             self.assertEqual(before, after)
-
-         state_dicts = [moe1.state_dict(), moe2.state_dict()]
-         # align the scale parameter so they can be concatenated
-         for key in ["w1", "w2", "w3"]:
-             weights = [st[key] for st in state_dicts]
-             for i in range(1, len(weights)):
-                 weights[i].scale = weights[0].scale
-
-         def process_key(key: str) -> torch.Tensor:
-             tensors = [s[key] for s in state_dicts]
-             # Note: we have a hacky implementation for cat in user codebase
-             # since it is not implemented correctly before
-             if key == "w2":
-                 return torch.cat(tensors, dim=-1)
-             else:
-                 return torch.cat(tensors, dim=-2)
-
-         new_state_dict = {}
-         for key in ["w1", "w2", "w3"]:
-             new_state_dict[key] = process_key(key)
-
-         moe_combined.w1 = torch.nn.Parameter(
-             moe_combined.w1.transpose(1, 2), requires_grad=False
-         )
-         moe_combined.w2 = torch.nn.Parameter(
-             moe_combined.w2.transpose(1, 2), requires_grad=False
-         )
-         moe_combined.w3 = torch.nn.Parameter(
-             moe_combined.w3.transpose(1, 2), requires_grad=False
-         )
-         moe_combined.load_state_dict(new_state_dict, assign=True)
-         # make sure it runs
-         moe_combined(input)
+         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+         self._test_moe_weight_reshape_ops(config)


common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
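
Note on the refactor above: the two tests now delegate their bodies to shared helpers on TorchAOIntegrationTestCase in torchao.testing.utils. For readers who do not want to chase the helper, the following is a minimal sketch of the narrow/copy_ pattern that _test_slice_and_copy_similar_to_vllm is assumed to exercise, mirroring the inline test deleted in this diff; the function name check_slice_and_copy and the hardcoded sizes are illustrative, not torchao API.

import torch

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_


def check_slice_and_copy(granularity, features=1024, shard_size=512):
    # sketch of the vLLM-style sharded weight load the shared helper is assumed to cover
    config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)

    # target layer whose quantized weight is filled shard by shard
    target = torch.nn.Linear(features, features, device="cuda", dtype=torch.bfloat16)
    # stand-in for a checkpoint weight, perturbed so it differs from the target
    loaded = torch.nn.Linear(features, features, device="cuda", dtype=torch.bfloat16)
    loaded.weight = torch.nn.Parameter(
        loaded.weight + 2 * torch.randn_like(loaded.weight), requires_grad=False
    )
    quantize_(target, config)
    quantize_(loaded, config)

    for tp_rank in (0, 1):
        start = tp_rank * shard_size
        # narrow both param.data and the loaded weight along the output dim,
        # then copy_ in place, as vLLM's weight loader does when sharding
        param_shard = target.weight.data.narrow(0, start, shard_size)
        loaded_shard = loaded.weight.narrow(0, start, shard_size)
        param_shard.copy_(loaded_shard)
        # both the quantized payload and the scales should now match
        assert param_shard.qdata[0][0] == loaded_shard.qdata[0][0]
        assert param_shard.scale[0] == loaded_shard.scale[0]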