|
15 | 15 | import sys |
16 | 16 | import unittest |
17 | 17 |
|
| 18 | +import numpy as np |
18 | 19 | import torch |
19 | 20 | from transformers import AutoProcessor, Mistral3ForConditionalGeneration |
20 | 21 |
|
21 | 22 | from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel |
22 | 23 |
|
23 | | -from ..testing_utils import floats_tensor, require_peft_backend |
| 24 | +from ..testing_utils import floats_tensor, require_peft_backend, torch_device |
24 | 25 |
|
25 | 26 |
|
26 | 27 | sys.path.append(".") |
27 | 28 |
|
28 | | -from .utils import PeftLoraLoaderMixinTests # noqa: E402 |
| 29 | +from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 |
29 | 30 |
|
30 | 31 |
|
31 | 32 | @require_peft_backend |
@@ -94,6 +95,46 @@ def get_dummy_inputs(self, with_generator=True): |
94 | 95 |
|
95 | 96 | return noise, input_ids, pipeline_inputs |
96 | 97 |
|
| 98 | + # Overriding because (1) text encoder LoRAs are not supported in Flux 2 and (2) because the Flux 2 single block |
| 99 | + # QKV projections are always fused, it has no `to_q` param as expected by the original test. |
| 100 | + def test_lora_fuse_nan(self): |
| 101 | + components, _, denoiser_lora_config = self.get_dummy_components() |
| 102 | + pipe = self.pipeline_class(**components) |
| 103 | + pipe = pipe.to(torch_device) |
| 104 | + pipe.set_progress_bar_config(disable=None) |
| 105 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 106 | + |
| 107 | + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet |
| 108 | + denoiser.add_adapter(denoiser_lora_config, "adapter-1") |
| 109 | + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") |
| 110 | + |
| 111 | + # corrupt one LoRA weight with `inf` values |
| 112 | + with torch.no_grad(): |
| 113 | + possible_tower_names = ["transformer_blocks", "single_transformer_blocks"] |
| 114 | + filtered_tower_names = [ |
| 115 | + tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) |
| 116 | + ] |
| 117 | + if len(filtered_tower_names) == 0: |
| 118 | + reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}." |
| 119 | + raise ValueError(reason) |
| 120 | + for tower_name in filtered_tower_names: |
| 121 | + transformer_tower = getattr(pipe.transformer, tower_name) |
| 122 | + is_single = "single" in tower_name |
| 123 | + if is_single: |
| 124 | + transformer_tower[0].attn.to_qkv_mlp_proj.lora_A["adapter-1"].weight += float("inf") |
| 125 | + else: |
| 126 | + transformer_tower[0].attn.to_k.lora_A["adapter-1"].weight += float("inf") |
| 127 | + |
| 128 | + # with `safe_fusing=True` we should see an Error |
| 129 | + with self.assertRaises(ValueError): |
| 130 | + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) |
| 131 | + |
| 132 | + # without we should not see an error, but every image will be black |
| 133 | + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) |
| 134 | + out = pipe(**inputs)[0] |
| 135 | + |
| 136 | + self.assertTrue(np.isnan(out).all()) |
| 137 | + |
97 | 138 | @unittest.skip("Not supported in Flux2.") |
98 | 139 | def test_simple_inference_with_text_denoiser_block_scale(self): |
99 | 140 | pass |
|
0 commit comments