Commit 39faf5f

use property instead
1 parent 2f3df7d commit 39faf5f
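
This commit replaces the class-level `quantization_config` attribute on the compile test classes with a read-only property; the base class now raises NotImplementedError until a subclass overrides it. The mechanical difference, sketched below with made-up class names that are not part of the commit, is when the config object gets built: a class attribute is evaluated once at import time, while a property defers construction until the attribute is read on an instance and returns a freshly built object on every access.

# Minimal sketch (illustrative only; EagerConfig/LazyConfig are hypothetical classes)
# of a class attribute vs. a property.

class EagerConfig:
    # Built once, when the module defining the class is imported.
    config = {"load_in_8bit": True}


class LazyConfig:
    @property
    def config(self):
        # Built only when an instance reads `self.config`;
        # a new dict is returned on every access.
        return {"load_in_8bit": True}


assert LazyConfig().config == EagerConfig.config
assert LazyConfig().config is not LazyConfig().config  # fresh object each time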

4 files changed: 30 additions, 20 deletions


tests/quantization/bnb/test_4bit.py

Lines changed: 11 additions & 9 deletions
@@ -866,15 +866,17 @@ def test_fp4_double_safe(self):
 
 @require_torch_version_greater("2.7.1")
 class Bnb4BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={
-            "load_in_4bit": True,
-            "bnb_4bit_quant_type": "nf4",
-            "bnb_4bit_compute_dtype": torch.bfloat16,
-        },
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={
+                "load_in_4bit": True,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_compute_dtype": torch.bfloat16,
+            },
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 7 additions & 5 deletions
@@ -831,11 +831,13 @@ def test_serialization_sharded(self):
 
 @require_torch_version_greater_equal("2.6.0")
 class Bnb8BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={"load_in_8bit": True},
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={"load_in_8bit": True},
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True

tests/quantization/test_torch_compile_utils.py

Lines changed: 5 additions & 1 deletion
@@ -24,7 +24,11 @@
 @require_torch_gpu
 @slow
 class QuantCompileTests(unittest.TestCase):
-    quantization_config = None
+    @property
+    def quantization_config(self):
+        raise NotImplementedError(
+            "This property should be implemented in the subclass to return the appropriate quantization config."
+        )
 
     def setUp(self):
         super().setUp()

tests/quantization/torchao/test_torchao.py

Lines changed: 7 additions & 5 deletions
@@ -631,11 +631,13 @@ def test_int_a16w8_cpu(self):
 
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoCompileTest(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_mapping={
-            "transformer": TorchAoConfig(quant_type="int8_weight_only"),
-        },
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": TorchAoConfig(quant_type="int8_weight_only"),
+            },
+        )
 
     def test_torch_compile(self):
         super()._test_torch_compile(quantization_config=self.quantization_config)
