
Commit 754fe85

sayakpaul and DN6 authored
[tests] add compile + offload tests for GGUF. (#11740)
* add compile + offload tests for GGUF.
* quality
* add init.
* prop.
* change to flux.

---------

Co-authored-by: Dhruv Nair <[email protected]>
1 parent cc1f9a2 · commit 754fe85

File tree

2 files changed: +32 −0 lines changed


tests/quantization/gguf/__init__.py

Whitespace-only changes.

tests/quantization/gguf/test_gguf.py

Lines changed: 32 additions & 0 deletions
@@ -8,6 +8,7 @@
 from diffusers import (
     AuraFlowPipeline,
     AuraFlowTransformer2DModel,
+    DiffusionPipeline,
     FluxControlPipeline,
     FluxPipeline,
     FluxTransformer2DModel,
@@ -32,9 +33,12 @@
     require_big_accelerator,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
+    require_torch_version_greater,
     torch_device,
 )

+from ..test_torch_compile_utils import QuantCompileTests
+

 if is_gguf_available():
     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
@@ -647,3 +651,31 @@ def get_dummy_inputs(self):
         ).to(torch_device, self.torch_dtype),
         "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
     }
+
+
+@require_torch_version_greater("2.7.1")
+class GGUFCompileTests(QuantCompileTests):
+    torch_dtype = torch.bfloat16
+    gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+
+    @property
+    def quantization_config(self):
+        return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+
+    def test_torch_compile(self):
+        super()._test_torch_compile(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_cpu_offload(self):
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+
+    def _init_pipeline(self, *args, **kwargs):
+        transformer = FluxTransformer2DModel.from_single_file(
+            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+        )
+        pipe = DiffusionPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
+        )
+        return pipe
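For context, here is a minimal sketch (not part of this commit) of the user-facing path GGUFCompileTests exercises: loading a GGUF-quantized Flux transformer via from_single_file, enabling CPU offload, and compiling the transformer. The checkpoint URL, quantization config, and model ID come from the diff above; the prompt and step count are illustrative, and the enable_model_cpu_offload / compile calls follow standard diffusers and PyTorch patterns rather than the QuantCompileTests helpers themselves.

# Minimal sketch of what the new tests cover end to end.
# Assumption: standard diffusers/PyTorch APIs, not the test harness itself.
import torch

from diffusers import DiffusionPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

dtype = torch.bfloat16
ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

# Load the GGUF single-file checkpoint with the same quantization config the tests use.
transformer = FluxTransformer2DModel.from_single_file(
    ckpt, quantization_config=GGUFQuantizationConfig(compute_dtype=dtype), torch_dtype=dtype
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=dtype
)

# Offload model components to CPU between forward passes (requires accelerate).
pipe.enable_model_cpu_offload()

# Compile the quantized transformer; the tests gate compilation behind torch > 2.7.1.
pipe.transformer.compile()

image = pipe("a photo of a cat", num_inference_steps=4).images[0]

The group-offload variant swaps model-level CPU offload for leaf-level group offloading, which is what _test_torch_compile_with_group_offload_leaf covers in the shared QuantCompileTests harness.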

0 commit comments
