
Commit 2c608d1

update
1 parent b69d099 commit 2c608d1

File tree: 4 files changed (+25, -39 lines)

tests/quantization/bnb/test_4bit.py

Lines changed: 5 additions & 4 deletions
@@ -45,7 +45,6 @@
     require_peft_backend,
     require_torch,
     require_torch_accelerator,
-    require_torch_version_greater,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -861,7 +860,7 @@ def test_fp4_double_safe(self):
         self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)


-@require_torch_version_greater("2.7.1")
+# @require_torch_version_greater("2.7.1")
 class Bnb4BitCompileTests(QuantCompileTests):
     quantization_config = PipelineQuantizationConfig(
         quant_backend="bitsandbytes_8bit",
@@ -880,5 +879,7 @@ def test_torch_compile(self):
     def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload_leaf_stream(quantization_config=self.quantization_config)
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, use_stream=True
+        )
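This hunk drops the require_torch_version_greater import and comments out the version gate on Bnb4BitCompileTests. For readers unfamiliar with that gate, below is a minimal sketch of how such a decorator is commonly written with unittest.skipUnless; it is only an illustration, not the actual helper from diffusers.utils.testing_utils.

    import unittest

    import torch
    from packaging import version


    def require_torch_version_greater(required_version: str):
        # Skip the decorated test unless the installed torch is strictly newer
        # than required_version (illustrative sketch, not the real diffusers helper).
        return unittest.skipUnless(
            version.parse(torch.__version__) > version.parse(required_version),
            f"test requires torch version greater than {required_version}",
        )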

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 3 additions & 3 deletions
@@ -844,7 +844,7 @@ def test_torch_compile_with_cpu_offload(self):
         )

     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload_leaf_stream(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
         )
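The pytest.mark.xfail marker kept above records the group-offload failure as an expected failure instead of breaking the suite. A minimal, self-contained sketch of that behaviour follows; the test name and exception are illustrative, not part of this commit.

    import pytest


    @pytest.mark.xfail(reason="Known offloading hook confusion in Accelerate.")
    def test_group_offload_known_issue():
        # The raised error is reported as "xfail"; if the underlying bug is fixed
        # and this passes, pytest reports it as "xpass".
        raise RuntimeError("hook confusion during offload")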

tests/quantization/test_torch_compile_utils.py

Lines changed: 4 additions & 23 deletions
@@ -64,7 +64,9 @@ def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_group_offload_leaf(self, quantization_config, torch_dtype=torch.bfloat16):
+    def _test_torch_compile_with_group_offload_leaf(
+        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
+    ):
         torch._dynamo.config.cache_size_limit = 10000

         pipe = self._init_pipeline(quantization_config, torch_dtype)
@@ -73,28 +75,7 @@ def _test_torch_compile_with_group_offload_leaf(self, quantization_config, torch
             "offload_device": torch.device("cpu"),
             "offload_type": "leaf_level",
             "num_blocks_per_group": 1,
-            "use_stream": False,
-        }
-        pipe.transformer.enable_group_offload(**group_offload_kwargs)
-        pipe.transformer.compile()
-        for name, component in pipe.components.items():
-            if name != "transformer" and isinstance(component, torch.nn.Module):
-                if torch.device(component.device).type == "cpu":
-                    component.to("cuda")
-
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
-
-    def _test_torch_compile_with_group_offload_leaf_stream(self, quantization_config, torch_dtype=torch.bfloat16):
-        torch._dynamo.config.cache_size_limit = 10000
-
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
-        group_offload_kwargs = {
-            "onload_device": torch.device("cuda"),
-            "offload_device": torch.device("cpu"),
-            "offload_type": "leaf_level",
-            "use_stream": True,
+            "use_stream": use_stream,
         }
         pipe.transformer.enable_group_offload(**group_offload_kwargs)
         pipe.transformer.compile()
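With the two helpers merged, a single code path now covers both offloading modes, toggled by use_stream. The sketch below shows the same flow outside the test harness; the checkpoint id is a placeholder and the real tests construct their pipeline via _init_pipeline.

    import torch
    from diffusers import DiffusionPipeline

    # Placeholder checkpoint id; the tests build their pipeline through _init_pipeline.
    pipe = DiffusionPipeline.from_pretrained("<model-id>", torch_dtype=torch.bfloat16)

    # Mirrors group_offload_kwargs from _test_torch_compile_with_group_offload_leaf.
    pipe.transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        num_blocks_per_group=1,
        use_stream=True,  # False exercises the former non-stream variant
    )
    pipe.transformer.compile()

    # Keep every non-transformer module on the accelerator so only the transformer is offloaded.
    for name, component in pipe.components.items():
        if name != "transformer" and isinstance(component, torch.nn.Module):
            if torch.device(component.device).type == "cpu":
                component.to("cuda")

    for _ in range(2):
        # Small resolutions keep execution fast, as in the helper.
        pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)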

tests/quantization/torchao/test_torchao.py

Lines changed: 13 additions & 9 deletions
@@ -19,6 +19,7 @@
 from typing import List

 import numpy as np
+from parameterized import parameterized
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import (
@@ -648,10 +649,17 @@ def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

     @unittest.skip(
-        "Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation "
-        "is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure."
+        """
+        For `use_stream=False`:
+        - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
+          is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
+        For `use_stream=True`:
+        Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
+        """
     )
+    @parameterized.expand([False, True])
     def test_torch_compile_with_group_offload_leaf(self):
+        # For use_stream=False:
         # If we run group offloading without compilation, we will see:
         # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
         # When running with compilation, the error ends up being different:
@@ -660,14 +668,10 @@ def test_torch_compile_with_group_offload_leaf(self):
         # Looks like something that will have to be looked into upstream.
         # for linear layers, weight.tensor_impl shows cuda... but:
         # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)

-    @unittest.skip(
-        "Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO."
-    )
-    def test_torch_compile_with_group_offload_leaf_stream(self):
-        # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
-        super()._test_torch_compile_with_group_offload_leaf_stream(quantization_config=self.quantization_config)
+        # For use_stream=True:
+        # # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
+        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)


 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
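The @parameterized.expand([False, True]) decorator added above generates one test case per listed value. A minimal sketch of that mechanism follows; the class and method names are illustrative, and the generated cases receive the value as an extra argument.

    import unittest

    from parameterized import parameterized


    class GroupOffloadLeafTests(unittest.TestCase):
        @parameterized.expand([False, True])
        def test_group_offload_leaf(self, use_stream):
            # parameterized.expand emits test_group_offload_leaf_0 (False) and
            # test_group_offload_leaf_1 (True), passing the value as an argument.
            self.assertIn(use_stream, (False, True))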
