Skip to content

Commit f20b83a

Browse files
enable cpu offloading of new pipelines on XPU & use device agnostic empty to make pipelines work on XPU (#11671)
* commit 1 Signed-off-by: YAO Matrix <[email protected]> * patch 2 Signed-off-by: YAO Matrix <[email protected]> * Update pipeline_pag_sana.py * Update pipeline_sana.py * Update pipeline_sana_controlnet.py * Update pipeline_sana_sprint_img2img.py * Update pipeline_sana_sprint.py * fix style Signed-off-by: YAO Matrix <[email protected]> * fix fat-thumb while merge conflict Signed-off-by: YAO Matrix <[email protected]> * fix ci issues Signed-off-by: YAO Matrix <[email protected]> --------- Signed-off-by: YAO Matrix <[email protected]> Co-authored-by: Ilyas Moutawwakil <[email protected]>
1 parent ee40088 commit f20b83a

33 files changed

+127
-80
lines changed

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
replace_example_docstring,
4242
)
4343
from ...utils.import_utils import is_transformers_version
44-
from ...utils.torch_utils import randn_tensor
44+
from ...utils.torch_utils import empty_device_cache, randn_tensor
4545
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
4646
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
4747

@@ -267,9 +267,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
267267

268268
if self.device.type != "cpu":
269269
self.to("cpu", silence_dtype_warnings=True)
270-
device_mod = getattr(torch, device.type, None)
271-
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
272-
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
270+
empty_device_cache(device.type)
273271

274272
model_sequence = [
275273
self.text_encoder.text_model,

src/diffusers/pipelines/consisid/consisid_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):
294294
295295
Parameters:
296296
- model_path: Path to the directory containing model files.
297-
- device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
297+
- device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
298298
- dtype: Data type (e.g., torch.float32) for model inference.
299299
300300
Returns:

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
scale_lora_layers,
3838
unscale_lora_layers,
3939
)
40-
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
40+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
4141
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4242
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
4343
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1339,7 +1339,7 @@ def __call__(
13391339
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
13401340
self.unet.to("cpu")
13411341
self.controlnet.to("cpu")
1342-
torch.cuda.empty_cache()
1342+
empty_device_cache()
13431343

13441344
if not output_type == "latent":
13451345
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
scale_lora_layers,
3737
unscale_lora_layers,
3838
)
39-
from ...utils.torch_utils import is_compiled_module, randn_tensor
39+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
4040
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4141
from ..stable_diffusion import StableDiffusionPipelineOutput
4242
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1311,7 +1311,7 @@ def __call__(
13111311
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
13121312
self.unet.to("cpu")
13131313
self.controlnet.to("cpu")
1314-
torch.cuda.empty_cache()
1314+
empty_device_cache()
13151315

13161316
if not output_type == "latent":
13171317
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
scale_lora_layers,
3939
unscale_lora_layers,
4040
)
41-
from ...utils.torch_utils import is_compiled_module, randn_tensor
41+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
4242
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4343
from ..stable_diffusion import StableDiffusionPipelineOutput
4444
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1500,7 +1500,7 @@ def __call__(
15001500
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
15011501
self.unet.to("cpu")
15021502
self.controlnet.to("cpu")
1503-
torch.cuda.empty_cache()
1503+
empty_device_cache()
15041504

15051505
if not output_type == "latent":
15061506
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
scale_lora_layers,
5252
unscale_lora_layers,
5353
)
54-
from ...utils.torch_utils import is_compiled_module, randn_tensor
54+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
5555
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
5656
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
5757

@@ -1858,7 +1858,7 @@ def denoising_value_valid(dnv):
18581858
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
18591859
self.unet.to("cpu")
18601860
self.controlnet.to("cpu")
1861-
torch.cuda.empty_cache()
1861+
empty_device_cache()
18621862

18631863
if not output_type == "latent":
18641864
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1465,7 +1465,11 @@ def __call__(
14651465

14661466
# Relevant thread:
14671467
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1468-
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
1468+
if (
1469+
torch.cuda.is_available()
1470+
and (is_unet_compiled and is_controlnet_compiled)
1471+
and is_torch_higher_equal_2_1
1472+
):
14691473
torch._inductor.cudagraph_mark_step_begin()
14701474
# expand the latents if we are doing classifier free guidance
14711475
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
scale_lora_layers,
5454
unscale_lora_layers,
5555
)
56-
from ...utils.torch_utils import is_compiled_module, randn_tensor
56+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
5757
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
5858
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
5959

@@ -921,7 +921,7 @@ def prepare_latents(
921921
# Offload text encoder if `enable_model_cpu_offload` was enabled
922922
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
923923
self.text_encoder_2.to("cpu")
924-
torch.cuda.empty_cache()
924+
empty_device_cache()
925925

926926
image = image.to(device=device, dtype=dtype)
927927

@@ -1632,7 +1632,7 @@ def __call__(
16321632
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
16331633
self.unet.to("cpu")
16341634
self.controlnet.to("cpu")
1635-
torch.cuda.empty_cache()
1635+
empty_device_cache()
16361636

16371637
if not output_type == "latent":
16381638
# make sure the VAE is in float32 mode, as it overflows in float16

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
scale_lora_layers,
5252
unscale_lora_layers,
5353
)
54-
from ...utils.torch_utils import is_compiled_module, randn_tensor
54+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
5555
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
5656
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
5757

@@ -1766,7 +1766,7 @@ def denoising_value_valid(dnv):
17661766
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
17671767
self.unet.to("cpu")
17681768
self.controlnet.to("cpu")
1769-
torch.cuda.empty_cache()
1769+
empty_device_cache()
17701770

17711771
if not output_type == "latent":
17721772
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
scale_lora_layers,
5454
unscale_lora_layers,
5555
)
56-
from ...utils.torch_utils import is_compiled_module, randn_tensor
56+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
5757
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
5858
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
5959

@@ -876,7 +876,7 @@ def prepare_latents(
876876
# Offload text encoder if `enable_model_cpu_offload` was enabled
877877
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
878878
self.text_encoder_2.to("cpu")
879-
torch.cuda.empty_cache()
879+
empty_device_cache()
880880

881881
image = image.to(device=device, dtype=dtype)
882882

@@ -1574,7 +1574,7 @@ def __call__(
15741574
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
15751575
self.unet.to("cpu")
15761576
self.controlnet.to("cpu")
1577-
torch.cuda.empty_cache()
1577+
empty_device_cache()
15781578

15791579
if not output_type == "latent":
15801580
# make sure the VAE is in float32 mode, as it overflows in float16

0 commit comments

Comments
 (0)