@@ -127,6 +127,37 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
+    def test_simple_inference_save_pretrained_with_text_lora(self):
+        """
+        Tests a simple use case where users save a text-encoder-only LoRA
+        through save_pretrained and reload it.
+        """
+        if not any("text_encoder" in k for k in self.pipeline_class._lora_loadable_modules):
+            pytest.skip("Test not supported.")
+        for scheduler_cls in self.scheduler_classes:
+            pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output(
+                scheduler_cls
+            )
+            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
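+            # Save the full pipeline, reload it, and verify the LoRA layers survived the round trip.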
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pipe.save_pretrained(tmpdirname)
+                pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+                pipe_from_pretrained.to(torch_device)
+                modules_to_save = self._get_modules_to_save(pipe_from_pretrained, has_denoiser=False)
+
+                for module_name, module in modules_to_save.items():
+                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
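+                # The reloaded pipeline should reproduce the original LoRA outputs.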
+                images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(
+                    np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+                    "Loading from saved checkpoints should give same results.",
+                )
+
     def test_low_cpu_mem_usage_with_injection(self):
         """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
         for scheduler_cls in self.scheduler_classes: