Skip to content

Commit fe1af35

Browse files
committed
update
1 parent fc6fb85 commit fe1af35

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/lora/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,35 @@ class PeftLoraLoaderMixinTests:
127127
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
128128
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
129129

130+
def test_simple_inference_save_pretrained_with_text_lora(self):
131+
"""
132+
Tests a simple usecase where users could use saving utilities for text encoder (only)
133+
LoRA through save_pretrained.
134+
"""
135+
if not any("text_encoder" in k for k in self.pipeline_class._lora_loadable_modules):
136+
pytest.skip("Test not supported.")
137+
for scheduler_cls in self.scheduler_classes:
138+
pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output(
139+
scheduler_cls
140+
)
141+
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
142+
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
143+
144+
with tempfile.TemporaryDirectory() as tmpdirname:
145+
pipe.save_pretrained(tmpdirname)
146+
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
147+
pipe_from_pretrained.to(torch_device)
148+
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=False)
149+
150+
for module_name, module in modules_to_save.items():
151+
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
152+
153+
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
154+
self.assertTrue(
155+
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
156+
"Loading from saved checkpoints should give same results.",
157+
)
158+
130159
def test_low_cpu_mem_usage_with_injection(self):
131160
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
132161
for scheduler_cls in self.scheduler_classes:

0 commit comments

Comments
 (0)