|
8 | 8 | from diffusers import (
|
9 | 9 | AuraFlowPipeline,
|
10 | 10 | AuraFlowTransformer2DModel,
|
| 11 | + DiffusionPipeline, |
11 | 12 | FluxControlPipeline,
|
12 | 13 | FluxPipeline,
|
13 | 14 | FluxTransformer2DModel,
|
|
32 | 33 | require_big_accelerator,
|
33 | 34 | require_gguf_version_greater_or_equal,
|
34 | 35 | require_peft_backend,
|
| 36 | + require_torch_version_greater, |
35 | 37 | torch_device,
|
36 | 38 | )
|
37 | 39 |
|
| 40 | +from ..test_torch_compile_utils import QuantCompileTests |
| 41 | + |
38 | 42 |
|
39 | 43 | if is_gguf_available():
|
40 | 44 | from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
|
@@ -647,3 +651,31 @@ def get_dummy_inputs(self):
|
647 | 651 | ).to(torch_device, self.torch_dtype),
|
648 | 652 | "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
|
649 | 653 | }
|
| 654 | + |
| 655 | + |
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests):
    """torch.compile smoke tests for a GGUF-quantized Flux transformer.

    Reuses the shared ``QuantCompileTests`` drivers; this class only supplies
    the GGUF-specific quantization config and pipeline construction.
    """

    # Compute dtype used both for the quantization config and the pipeline.
    torch_dtype = torch.bfloat16
    # Q2_K single-file GGUF checkpoint of FLUX.1-dev.
    gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

    @property
    def quantization_config(self):
        # A property (not a class attribute) so every access hands out a
        # fresh, unshared config instance.
        return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)

    def test_torch_compile(self):
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload_leaf(self):
        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)

    def _init_pipeline(self, *args, **kwargs):
        # Hook used by the base-class test drivers; extra args are accepted to
        # match the hook signature but are not needed here.
        quantized_transformer = FluxTransformer2DModel.from_single_file(
            self.gguf_ckpt,
            quantization_config=self.quantization_config,
            torch_dtype=self.torch_dtype,
        )
        pipeline = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=quantized_transformer,
            torch_dtype=self.torch_dtype,
        )
        return pipeline
0 commit comments