File tree Expand file tree Collapse file tree 4 files changed +30
-20
lines changed Expand file tree Collapse file tree 4 files changed +30
-20
lines changed Original file line number Diff line number Diff line change @@ -866,15 +866,17 @@ def test_fp4_double_safe(self):
866
866
867
867
@require_torch_version_greater ("2.7.1" )
868
868
class Bnb4BitCompileTests (QuantCompileTests ):
869
- quantization_config = PipelineQuantizationConfig (
870
- quant_backend = "bitsandbytes_8bit" ,
871
- quant_kwargs = {
872
- "load_in_4bit" : True ,
873
- "bnb_4bit_quant_type" : "nf4" ,
874
- "bnb_4bit_compute_dtype" : torch .bfloat16 ,
875
- },
876
- components_to_quantize = ["transformer" , "text_encoder_2" ],
877
- )
869
+ @property
870
+ def quantization_config (self ):
871
+ return PipelineQuantizationConfig (
872
+ quant_backend = "bitsandbytes_8bit" ,
873
+ quant_kwargs = {
874
+ "load_in_4bit" : True ,
875
+ "bnb_4bit_quant_type" : "nf4" ,
876
+ "bnb_4bit_compute_dtype" : torch .bfloat16 ,
877
+ },
878
+ components_to_quantize = ["transformer" , "text_encoder_2" ],
879
+ )
878
880
879
881
def test_torch_compile (self ):
880
882
torch ._dynamo .config .capture_dynamic_output_shape_ops = True
Original file line number Diff line number Diff line change @@ -831,11 +831,13 @@ def test_serialization_sharded(self):
831
831
832
832
@require_torch_version_greater_equal ("2.6.0" )
833
833
class Bnb8BitCompileTests (QuantCompileTests ):
834
- quantization_config = PipelineQuantizationConfig (
835
- quant_backend = "bitsandbytes_8bit" ,
836
- quant_kwargs = {"load_in_8bit" : True },
837
- components_to_quantize = ["transformer" , "text_encoder_2" ],
838
- )
834
+ @property
835
+ def quantization_config (self ):
836
+ return PipelineQuantizationConfig (
837
+ quant_backend = "bitsandbytes_8bit" ,
838
+ quant_kwargs = {"load_in_8bit" : True },
839
+ components_to_quantize = ["transformer" , "text_encoder_2" ],
840
+ )
839
841
840
842
def test_torch_compile (self ):
841
843
torch ._dynamo .config .capture_dynamic_output_shape_ops = True
Original file line number Diff line number Diff line change 24
24
@require_torch_gpu
25
25
@slow
26
26
class QuantCompileTests (unittest .TestCase ):
27
- quantization_config = None
27
+ @property
28
+ def quantization_config (self ):
29
+ raise NotImplementedError (
30
+ "This property should be implemented in the subclass to return the appropriate quantization config."
31
+ )
28
32
29
33
def setUp (self ):
30
34
super ().setUp ()
Original file line number Diff line number Diff line change @@ -631,11 +631,13 @@ def test_int_a16w8_cpu(self):
631
631
632
632
@require_torchao_version_greater_or_equal ("0.7.0" )
633
633
class TorchAoCompileTest (QuantCompileTests ):
634
- quantization_config = PipelineQuantizationConfig (
635
- quant_mapping = {
636
- "transformer" : TorchAoConfig (quant_type = "int8_weight_only" ),
637
- },
638
- )
634
+ @property
635
+ def quantization_config (self ):
636
+ return PipelineQuantizationConfig (
637
+ quant_mapping = {
638
+ "transformer" : TorchAoConfig (quant_type = "int8_weight_only" ),
639
+ },
640
+ )
639
641
640
642
def test_torch_compile (self ):
641
643
super ()._test_torch_compile (quantization_config = self .quantization_config )
You can’t perform that action at this time.
0 commit comments