Commit 2527917

FIX set_lora_device when target layers differ (#11844)
* FIX set_lora_device when target layers differ

  Resolves #11833

  Fixes a bug that occurs after calling set_lora_device when multiple LoRA adapters are loaded that target different layers.

  Note: Technically, the accompanying test does not require a GPU because the bug is triggered even if the parameters are already on the corresponding device, i.e. loading on CPU and then changing the device to CPU is sufficient to cause the bug. However, this may be optimized away in the future, so I decided to test with GPU.

* Update docstring to warn about device mismatch

* Extend docstring with an example

* Fix docstring

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent e6639fe commit 2527917
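For context, here is a minimal sketch of the failure mode this commit addresses, assembled from the new regression test further down. The checkpoint id, adapter names, and target modules are taken from that test; the exact exception raised before the fix is an inference from the code, not something stated in this commit.

```python
# Sketch only: reproduces the set_lora_device bug with adapters that target
# different layers (see #11833). Per the commit note, a GPU is not strictly
# required -- the adapter lookup itself is what used to fail.
import torch
from peft import LoraConfig
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Two adapters that only partly overlap in the layers they target.
pipe.unet.add_adapter(LoraConfig(target_modules=["to_k", "to_v"]), adapter_name="adapter-0")
pipe.unet.add_adapter(LoraConfig(target_modules=["to_k", "to_q"]), adapter_name="adapter-1")

# Before this fix, set_lora_device indexed module.lora_A["adapter-0"] on every
# LoRA layer, including layers that only carry "adapter-1", which raised a key
# lookup error; with the fix those layers are skipped.
pipe.set_lora_device(["adapter-0"], "cpu")
```

The fix itself simply skips tuner layers that do not hold the requested adapter, as shown in the `lora_base.py` hunk below.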

2 files changed: +74, −2 lines changed


src/diffusers/loaders/lora_base.py

Lines changed: 25 additions & 0 deletions
````diff
@@ -934,6 +934,27 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
         Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
         you want to load multiple adapters and free some GPU memory.
 
+        After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+        can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+        GPU before using those LoRA adapters for inference.
+
+        ```python
+        >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+        >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+        >>> pipe.set_adapters("adapter-1")
+        >>> image_1 = pipe(**kwargs)
+        >>> # switch to adapter-2, offload adapter-1
+        >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+        >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-2")
+        >>> image_2 = pipe(**kwargs)
+        >>> # switch back to adapter-1, offload adapter-2
+        >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+        >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-1")
+        >>> ...
+        ```
+
         Args:
             adapter_names (`List[str]`):
                 List of adapters to send device to.
@@ -949,6 +970,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
             for module in model.modules():
                 if isinstance(module, BaseTunerLayer):
                     for adapter_name in adapter_names:
+                        if adapter_name not in module.lora_A:
+                            # it is sufficient to check lora_A
+                            continue
+
                         module.lora_A[adapter_name].to(device)
                         module.lora_B[adapter_name].to(device)
                         # this is a param, not a module, so device placement is not in-place -> re-assign
````
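After moving adapters around as in the docstring example above, it can be useful to confirm where a given adapter's weights actually ended up. Below is a minimal sketch in the spirit of the new test; the helper name `lora_devices` is hypothetical, and it assumes the adapter wraps `nn.Linear`-style layers, mirroring the guarded loop in this hunk.

```python
import torch
from peft.tuners.tuners_utils import BaseTunerLayer


def lora_devices(model, adapter_name):
    # Hypothetical helper: collect the devices holding this adapter's A/B matrices.
    # Checking membership in module.lora_A mirrors the fix above, so layers that
    # do not carry `adapter_name` are simply skipped.
    devices = set()
    for module in model.modules():
        if isinstance(module, BaseTunerLayer) and adapter_name in module.lora_A:
            devices.add(module.lora_A[adapter_name].weight.device)
            devices.add(module.lora_B[adapter_name].weight.device)
    return devices


# e.g. after pipe.set_lora_device(["adapter-0"], "cpu"):
# assert lora_devices(pipe.unet, "adapter-0") == {torch.device("cpu")}
```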

tests/lora/test_lora_layers_sd.py

Lines changed: 49 additions & 2 deletions
````diff
@@ -120,7 +120,7 @@ def test_integration_move_lora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         # We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ def test_integration_move_lora_dora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
        )
 
         for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ def test_integration_move_lora_dora_cpu(self):
             if "lora_" in name:
                 self.assertNotEqual(param.device, torch.device("cpu"))
 
+    @slow
+    @require_torch_accelerator
+    def test_integration_set_lora_device_different_target_layers(self):
+        # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+        # layers, see #11833
+        from peft import LoraConfig
+
+        path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        # configs partly target the same, partly different layers
+        config0 = LoraConfig(target_modules=["to_k", "to_v"])
+        config1 = LoraConfig(target_modules=["to_k", "to_q"])
+        pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+        pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+        pipe = pipe.to(torch_device)
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+        modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+        modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+        self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+        self.assertTrue(modules_adapter_0 - modules_adapter_1)
+        self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+        # setting both separately works
+        pipe.set_lora_device(["adapter-0"], "cpu")
+        pipe.set_lora_device(["adapter-1"], "cpu")
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+
+        # setting both at once also works
+        pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+
     @slow
     @nightly
````
