From 4bd798302b23457bada2e3d4abb24bfa31583a47 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 15 Aug 2025 08:56:01 +0530
Subject: [PATCH 1/8] start removing flax stuff.

---
 .../autoencoders/test_models_vae_flax.py      |  39 -
 tests/models/test_modeling_common_flax.py     |  66 --
 .../models/unets/test_models_unet_2d_flax.py  | 104 --
 .../controlnet/test_flax_controlnet.py        | 127 ---
 .../test_stable_diffusion_flax.py             | 108 --
 .../test_stable_diffusion_flax_inpaint.py     |  82 --
 tests/pipelines/test_pipelines_flax.py        | 260 -----
 tests/schedulers/test_scheduler_flax.py       | 920 ------------------
 8 files changed, 1706 deletions(-)
 delete mode 100644 tests/models/autoencoders/test_models_vae_flax.py
 delete mode 100644 tests/models/test_modeling_common_flax.py
 delete mode 100644 tests/models/unets/test_models_unet_2d_flax.py
 delete mode 100644 tests/pipelines/controlnet/test_flax_controlnet.py
 delete mode 100644 tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
 delete mode 100644 tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
 delete mode 100644 tests/pipelines/test_pipelines_flax.py
 delete mode 100644 tests/schedulers/test_scheduler_flax.py

diff --git a/tests/models/autoencoders/test_models_vae_flax.py b/tests/models/autoencoders/test_models_vae_flax.py
deleted file mode 100644
index 8fedb85eccfc..000000000000
--- a/tests/models/autoencoders/test_models_vae_flax.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import unittest
-
-from diffusers import FlaxAutoencoderKL
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-from ..test_modeling_common_flax import FlaxModelTesterMixin
-
-
-if is_flax_available():
-    import jax
-
-
-@require_flax
-class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
-    model_class = FlaxAutoencoderKL
-
-    @property
-    def dummy_input(self):
-        batch_size = 4
-        num_channels = 3
-        sizes = (32, 32)
-
-        prng_key = jax.random.PRNGKey(0)
-        image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
-
-        return {"sample": image, "prng_key": prng_key}
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "block_out_channels": [32, 64],
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "latent_channels": 4,
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
diff --git a/tests/models/test_modeling_common_flax.py b/tests/models/test_modeling_common_flax.py
deleted file mode 100644
index 8945aed7c93f..000000000000
--- a/tests/models/test_modeling_common_flax.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import inspect
-
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-
-if is_flax_available():
-    import jax
-
-
-@require_flax
-class FlaxModelTesterMixin:
-    def test_output(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict)
-        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
-        jax.lax.stop_gradient(variables)
-
-        output = model.apply(variables, inputs_dict["sample"])
-
-        if isinstance(output, dict):
-            output = output.sample
-
-        self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
-    def test_forward_with_norm_groups(self):
-        init_dict, inputs_dict =
self.prepare_init_args_and_inputs_for_common() - - init_dict["norm_num_groups"] = 16 - init_dict["block_out_channels"] = (16, 32) - - model = self.model_class(**init_dict) - variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) - jax.lax.stop_gradient(variables) - - output = model.apply(variables, inputs_dict["sample"]) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_deprecated_kwargs(self): - has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters - has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 - - if has_kwarg_in_model_class and not has_deprecated_kwarg: - raise ValueError( - f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" - " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" - " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" - " []`" - ) - - if not has_kwarg_in_model_class and has_deprecated_kwarg: - raise ValueError( - f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" - " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" - f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" - " from `_deprecated_kwargs = []`" - ) diff --git a/tests/models/unets/test_models_unet_2d_flax.py b/tests/models/unets/test_models_unet_2d_flax.py deleted file mode 100644 index 69a0704dca9d..000000000000 --- a/tests/models/unets/test_models_unet_2d_flax.py +++ /dev/null @@ -1,104 +0,0 @@ -import gc -import unittest - -from parameterized import parameterized - -from diffusers import FlaxUNet2DConditionModel -from diffusers.utils import is_flax_available -from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow - - -if is_flax_available(): - import jax - import jax.numpy as jnp - - -@slow -@require_flax -class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): - def get_file_format(self, seed, shape): - return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" - - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - - def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) - return image - - def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - revision = "bf16" if fp16 else None - - model, params = FlaxUNet2DConditionModel.from_pretrained( - model_id, subfolder="unet", dtype=dtype, revision=revision - ) - return model, params - - def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) - return hidden_states - - @parameterized.expand( - [ - # fmt: off - [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], - [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], - [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], - [3, 
1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], - # fmt: on - ] - ) - def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): - model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) - latents = self.get_latents(seed, fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) - - sample = model.apply( - {"params": params}, - latents, - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=encoder_hidden_states, - ).sample - - assert sample.shape == latents.shape - - output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) - expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) - - # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware - assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) - - @parameterized.expand( - [ - # fmt: off - [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], - [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], - [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], - [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], - # fmt: on - ] - ) - def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): - model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) - latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) - - sample = model.apply( - {"params": params}, - latents, - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=encoder_hidden_states, - ).sample - - assert sample.shape == latents.shape - - output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) - expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) - - # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware - assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) diff --git a/tests/pipelines/controlnet/test_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py deleted file mode 100644 index 07d3a09e5d27..000000000000 --- a/tests/pipelines/controlnet/test_flax_controlnet.py +++ /dev/null @@ -1,127 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import gc -import unittest - -from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline -from diffusers.utils import is_flax_available, load_image -from diffusers.utils.testing_utils import require_flax, slow - - -if is_flax_available(): - import jax - import jax.numpy as jnp - from flax.jax_utils import replicate - from flax.training.common_utils import shard - - -@slow -@require_flax -class FlaxControlNetPipelineIntegrationTests(unittest.TestCase): - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - - def test_canny(self): - controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( - "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16 - ) - pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 - ) - params["controlnet"] = controlnet_params - - prompts = "bird" - num_samples = jax.device_count() - prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) - - canny_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ) - processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) - - rng = jax.random.PRNGKey(0) - rng = jax.random.split(rng, jax.device_count()) - - p_params = replicate(params) - prompt_ids = shard(prompt_ids) - processed_image = shard(processed_image) - - images = pipe( - prompt_ids=prompt_ids, - image=processed_image, - params=p_params, - prng_seed=rng, - num_inference_steps=50, - jit=True, - ).images - assert images.shape == (jax.device_count(), 1, 768, 512, 3) - - images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) - image_slice = images[0, 253:256, 253:256, -1] - - output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) - expected_slice = jnp.array( - [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078] - ) - - assert jnp.abs(output_slice - expected_slice).max() < 1e-2 - - def test_pose(self): - controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( - "lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16 - ) - pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 - ) - params["controlnet"] = controlnet_params - - prompts = "Chef in the kitchen" - num_samples = jax.device_count() - prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) - - pose_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png" - ) - processed_image = pipe.prepare_image_inputs([pose_image] * num_samples) - - rng = jax.random.PRNGKey(0) - rng = jax.random.split(rng, jax.device_count()) - - p_params = replicate(params) - prompt_ids = shard(prompt_ids) - processed_image = shard(processed_image) - - images = pipe( - prompt_ids=prompt_ids, - image=processed_image, - params=p_params, - prng_seed=rng, - num_inference_steps=50, - jit=True, - ).images - assert images.shape == (jax.device_count(), 1, 768, 512, 3) - - images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) - image_slice = images[0, 253:256, 253:256, -1] - - output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) - expected_slice = jnp.array( - [[0.271484, 0.261719, 0.275391, 0.277344, 
0.279297, 0.291016, 0.294922, 0.302734, 0.302734]] - ) - - assert jnp.abs(output_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py deleted file mode 100644 index 77014bd7a518..000000000000 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py +++ /dev/null @@ -1,108 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gc -import unittest - -from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline -from diffusers.utils import is_flax_available -from diffusers.utils.testing_utils import nightly, require_flax - - -if is_flax_available(): - import jax - import jax.numpy as jnp - from flax.jax_utils import replicate - from flax.training.common_utils import shard - - -@nightly -@require_flax -class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase): - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - - def test_stable_diffusion_flax(self): - sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-2", - variant="bf16", - dtype=jnp.bfloat16, - ) - - prompt = "A painting of a squirrel eating a burger" - num_samples = jax.device_count() - prompt = num_samples * [prompt] - prompt_ids = sd_pipe.prepare_inputs(prompt) - - params = replicate(params) - prompt_ids = shard(prompt_ids) - - prng_seed = jax.random.PRNGKey(0) - prng_seed = jax.random.split(prng_seed, jax.device_count()) - - images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0] - assert images.shape == (jax.device_count(), 1, 768, 768, 3) - - images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) - image_slice = images[0, 253:256, 253:256, -1] - - output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) - expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512]) - - assert jnp.abs(output_slice - expected_slice).max() < 1e-2 - - -@nightly -@require_flax -class FlaxStableDiffusion2PipelineNightlyTests(unittest.TestCase): - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - - def test_stable_diffusion_dpm_flax(self): - model_id = "stabilityai/stable-diffusion-2" - scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler") - sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained( - model_id, - scheduler=scheduler, - variant="bf16", - dtype=jnp.bfloat16, - ) - params["scheduler"] = scheduler_params - - prompt = "A painting of a squirrel eating a burger" - num_samples = jax.device_count() - prompt = num_samples * [prompt] - prompt_ids = sd_pipe.prepare_inputs(prompt) - - params = replicate(params) - prompt_ids = shard(prompt_ids) - - prng_seed = jax.random.PRNGKey(0) - prng_seed = jax.random.split(prng_seed, 
jax.device_count()) - - images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0] - assert images.shape == (jax.device_count(), 1, 768, 768, 3) - - images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) - image_slice = images[0, 253:256, 253:256, -1] - - output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) - expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297]) - - assert jnp.abs(output_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py deleted file mode 100644 index d83c69673676..000000000000 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py +++ /dev/null @@ -1,82 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gc -import unittest - -from diffusers import FlaxStableDiffusionInpaintPipeline -from diffusers.utils import is_flax_available, load_image -from diffusers.utils.testing_utils import require_flax, slow - - -if is_flax_available(): - import jax - import jax.numpy as jnp - from flax.jax_utils import replicate - from flax.training.common_utils import shard - - -@slow -@require_flax -class FlaxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - - def test_stable_diffusion_inpaint_pipeline(self): - init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/sd2-inpaint/init_image.png" - ) - mask_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png" - ) - - model_id = "xvjiarui/stable-diffusion-2-inpainting" - pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None) - - prompt = "Face of a yellow cat, high resolution, sitting on a park bench" - - prng_seed = jax.random.PRNGKey(0) - num_inference_steps = 50 - - num_samples = jax.device_count() - prompt = num_samples * [prompt] - init_image = num_samples * [init_image] - mask_image = num_samples * [mask_image] - prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image) - - # shard inputs and rng - params = replicate(params) - prng_seed = jax.random.split(prng_seed, jax.device_count()) - prompt_ids = shard(prompt_ids) - processed_masked_images = shard(processed_masked_images) - processed_masks = shard(processed_masks) - - output = pipeline( - prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True - ) - - images = output.images.reshape(num_samples, 512, 512, 3) - - image_slice = images[0, 253:256, 253:256, -1] - - output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) - expected_slice = jnp.array( - 
[0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084] - ) - - assert jnp.abs(output_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/test_pipelines_flax.py b/tests/pipelines/test_pipelines_flax.py deleted file mode 100644 index ffe43ac9d76d..000000000000 --- a/tests/pipelines/test_pipelines_flax.py +++ /dev/null @@ -1,260 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest - -import numpy as np - -from diffusers.utils import is_flax_available -from diffusers.utils.testing_utils import require_flax, slow - - -if is_flax_available(): - import jax - import jax.numpy as jnp - from flax.jax_utils import replicate - from flax.training.common_utils import shard - - from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline - - -@require_flax -class DownloadTests(unittest.TestCase): - def test_download_only_pytorch(self): - with tempfile.TemporaryDirectory() as tmpdirname: - # pipeline has Flax weights - _ = FlaxDiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname - ) - - all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] - files = [item for sublist in all_root_files for item in sublist] - - # None of the downloaded files should be a PyTorch file even if we have some here: - # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin - assert not any(f.endswith(".bin") for f in files) - - -@slow -@require_flax -class FlaxPipelineTests(unittest.TestCase): - def test_dummy_all_tpus(self): - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None - ) - - prompt = ( - "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" - " field, close up, split lighting, cinematic" - ) - - prng_seed = jax.random.PRNGKey(0) - num_inference_steps = 4 - - num_samples = jax.device_count() - prompt = num_samples * [prompt] - prompt_ids = pipeline.prepare_inputs(prompt) - - # shard inputs and rng - params = replicate(params) - prng_seed = jax.random.split(prng_seed, num_samples) - prompt_ids = shard(prompt_ids) - - images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images - - assert images.shape == (num_samples, 1, 64, 64, 3) - if jax.device_count() == 8: - assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3 - assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1 - - images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) - assert len(images_pil) == num_samples - - def test_stable_diffusion_v1_4(self): - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - 
"CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None - ) - - prompt = ( - "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" - " field, close up, split lighting, cinematic" - ) - - prng_seed = jax.random.PRNGKey(0) - num_inference_steps = 50 - - num_samples = jax.device_count() - prompt = num_samples * [prompt] - prompt_ids = pipeline.prepare_inputs(prompt) - - # shard inputs and rng - params = replicate(params) - prng_seed = jax.random.split(prng_seed, num_samples) - prompt_ids = shard(prompt_ids) - - images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images - - assert images.shape == (num_samples, 1, 512, 512, 3) - if jax.device_count() == 8: - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-2 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1 - - def test_stable_diffusion_v1_4_bfloat_16(self): - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16, safety_checker=None - ) - - prompt = ( - "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" - " field, close up, split lighting, cinematic" - ) - - prng_seed = jax.random.PRNGKey(0) - num_inference_steps = 50 - - num_samples = jax.device_count() - prompt = num_samples * [prompt] - prompt_ids = pipeline.prepare_inputs(prompt) - - # shard inputs and rng - params = replicate(params) - prng_seed = jax.random.split(prng_seed, num_samples) - prompt_ids = shard(prompt_ids) - - images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images - - assert images.shape == (num_samples, 1, 512, 512, 3) - if jax.device_count() == 8: - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1 - - def test_stable_diffusion_v1_4_bfloat_16_with_safety(self): - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16 - ) - - prompt = ( - "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" - " field, close up, split lighting, cinematic" - ) - - prng_seed = jax.random.PRNGKey(0) - num_inference_steps = 50 - - num_samples = jax.device_count() - prompt = num_samples * [prompt] - prompt_ids = pipeline.prepare_inputs(prompt) - - # shard inputs and rng - params = replicate(params) - prng_seed = jax.random.split(prng_seed, num_samples) - prompt_ids = shard(prompt_ids) - - images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images - - assert images.shape == (num_samples, 1, 512, 512, 3) - if jax.device_count() == 8: - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1 - - def test_stable_diffusion_v1_4_bfloat_16_ddim(self): - scheduler = FlaxDDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - set_alpha_to_one=False, - steps_offset=1, - ) - - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - variant="bf16", - dtype=jnp.bfloat16, - scheduler=scheduler, - safety_checker=None, - ) - scheduler_state = scheduler.create_state() - - params["scheduler"] = scheduler_state - - 
prompt = ( - "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" - " field, close up, split lighting, cinematic" - ) - - prng_seed = jax.random.PRNGKey(0) - num_inference_steps = 50 - - num_samples = jax.device_count() - prompt = num_samples * [prompt] - prompt_ids = pipeline.prepare_inputs(prompt) - - # shard inputs and rng - params = replicate(params) - prng_seed = jax.random.split(prng_seed, num_samples) - prompt_ids = shard(prompt_ids) - - images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images - - assert images.shape == (num_samples, 1, 512, 512, 3) - if jax.device_count() == 8: - assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 5e-2 - assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1 - - def test_jax_memory_efficient_attention(self): - prompt = ( - "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" - " field, close up, split lighting, cinematic" - ) - - num_samples = jax.device_count() - prompt = num_samples * [prompt] - prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples) - - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - variant="bf16", - dtype=jnp.bfloat16, - safety_checker=None, - ) - - params = replicate(params) - prompt_ids = pipeline.prepare_inputs(prompt) - prompt_ids = shard(prompt_ids) - images = pipeline(prompt_ids, params, prng_seed, jit=True).images - assert images.shape == (num_samples, 1, 512, 512, 3) - slice = images[2, 0, 256, 10:17, 1] - - # With memory efficient attention - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - variant="bf16", - dtype=jnp.bfloat16, - safety_checker=None, - use_memory_efficient_attention=True, - ) - - params = replicate(params) - prompt_ids = pipeline.prepare_inputs(prompt) - prompt_ids = shard(prompt_ids) - images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images - assert images_eff.shape == (num_samples, 1, 512, 512, 3) - slice_eff = images[2, 0, 256, 10:17, 1] - - # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum` - # over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now. - assert abs(slice_eff - slice).max() < 1e-2 diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py deleted file mode 100644 index c8121d334164..000000000000 --- a/tests/schedulers/test_scheduler_flax.py +++ /dev/null @@ -1,920 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import inspect -import tempfile -import unittest -from typing import Dict, List, Tuple - -from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler -from diffusers.utils import is_flax_available -from diffusers.utils.testing_utils import require_flax - - -if is_flax_available(): - import jax - import jax.numpy as jnp - from jax import random - - jax_device = jax.default_backend() - - -@require_flax -class FlaxSchedulerCommonTest(unittest.TestCase): - scheduler_classes = () - forward_default_kwargs = () - - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - key1, key2 = random.split(random.PRNGKey(0)) - sample = random.uniform(key1, (batch_size, num_channels, height, width)) - - return sample, key2 - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = jnp.arange(num_elems) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - return jnp.transpose(sample, (3, 0, 1, 2)) - - def get_scheduler_config(self): - raise NotImplementedError - - def dummy_model(self): - def model(sample, t, *args): - return sample * t / (t + 1) - - return model - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, 
key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, key = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample - output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
- ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, key = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_deprecated_kwargs(self): - for scheduler_class in self.scheduler_classes: - has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters - has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 - - if has_kwarg_in_model_class and not has_deprecated_kwarg: - raise ValueError( - f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" - " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" - " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" - " []`" - ) - - if not has_kwarg_in_model_class and has_deprecated_kwarg: - raise ValueError( - f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" - " kwargs under the `_deprecated_kwargs` class attribute. 
Make sure to either add the `**kwargs`" - f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" - " deprecated argument from `_deprecated_kwargs = []`" - ) - - -@require_flax -class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxDDPMScheduler,) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - "variance_type": "fixed_small", - "clip_sample": True, - } - - config.update(**kwargs) - return config - - def test_timesteps(self): - for timesteps in [1, 5, 100, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_variance_type(self): - for variance in ["fixed_small", "fixed_large", "other"]: - self.check_over_configs(variance_type=variance) - - def test_clip_sample(self): - for clip_sample in [True, False]: - self.check_over_configs(clip_sample=clip_sample) - - def test_time_indices(self): - for t in [0, 500, 999]: - self.check_over_forward(time_step=t) - - def test_variance(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5 - - def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_trained_timesteps = len(scheduler) - - model = self.dummy_model() - sample = self.dummy_sample_deter - key1, key2 = random.split(random.PRNGKey(0)) - - for t in reversed(range(num_trained_timesteps)): - # 1. predict noise residual - residual = model(sample, t) - - # 2. 
predict previous mean of sample x_t-1 - output = scheduler.step(state, residual, t, sample, key1) - pred_prev_sample = output.prev_sample - state = output.state - key1, key2 = random.split(key2) - - # if t > 0: - # noise = self.dummy_sample_deter - # variance = scheduler.get_variance(t) ** (0.5) * noise - # - # sample = pred_prev_sample + variance - sample = pred_prev_sample - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 255.0714) < 1e-2 - assert abs(result_mean - 0.332124) < 1e-3 - else: - assert abs(result_sum - 270.2) < 1e-1 - assert abs(result_mean - 0.3519494) < 1e-3 - - -@require_flax -class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxDDIMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50),) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - } - - config.update(**kwargs) - return config - - def full_loop(self, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - key1, key2 = random.split(random.PRNGKey(0)) - - num_inference_steps = 10 - - model = self.dummy_model() - sample = self.dummy_sample_deter - - state = scheduler.set_timesteps(state, num_inference_steps) - - for t in state.timesteps: - residual = model(sample, t) - output = scheduler.step(state, residual, t, sample) - sample = output.prev_sample - state = output.state - key1, key2 = random.split(key2) - - return sample - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, 
"set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
- ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample - output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_timesteps(self): - for timesteps in [100, 500, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_steps_offset(self): - for steps_offset in [0, 1]: - self.check_over_configs(steps_offset=steps_offset) - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(steps_offset=1) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_time_indices(self): - for t in [1, 10, 49]: - self.check_over_forward(time_step=t) - - def test_inference_steps(self): - for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): - self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) - - def test_variance(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 - assert 
jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5 - - def test_full_loop_no_noise(self): - sample = self.full_loop() - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_sum - 172.0067) < 1e-2 - assert abs(result_mean - 0.223967) < 1e-3 - - def test_full_loop_with_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 149.8409) < 1e-2 - assert abs(result_mean - 0.1951) < 1e-3 - else: - assert abs(result_sum - 149.8295) < 1e-2 - assert abs(result_mean - 0.1951) < 1e-3 - - def test_full_loop_with_no_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - pass - # FIXME: both result_sum and result_mean are nan on TPU - # assert jnp.isnan(result_sum) - # assert jnp.isnan(result_mean) - else: - assert abs(result_sum - 149.0784) < 1e-2 - assert abs(result_mean - 0.1941) < 1e-3 - - def test_prediction_type(self): - for prediction_type in ["epsilon", "sample", "v_prediction"]: - self.check_over_configs(prediction_type=prediction_type) - - -@require_flax -class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxPNDMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50),) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - } - - config.update(**kwargs) - return config - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample, _ = self.dummy_sample - residual = 0.1 * sample - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - # copy over dummy past residuals - state = state.replace(ets=dummy_past_residuals[:]) - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) - # copy over dummy past residuals - new_state = new_state.replace(ets=dummy_past_residuals[:]) - - (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) - (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" - - output, _ = 
scheduler.step_plms(state, residual, time_step, sample, **kwargs) - new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - @unittest.skip("Test not supported.") - def test_from_save_pretrained(self): - pass - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." - ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample, _ = self.dummy_sample - residual = 0.1 * sample - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - # copy over dummy past residuals (must be after setting timesteps) - scheduler.ets = dummy_past_residuals[:] - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - # copy over dummy past residuals - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, 
shape=sample.shape) - - # copy over dummy past residual (must be after setting timesteps) - new_state.replace(ets=dummy_past_residuals[:]) - - output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs) - new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) - new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def full_loop(self, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - for i, t in enumerate(state.prk_timesteps): - residual = model(sample, t) - sample, state = scheduler.step_prk(state, residual, t, sample) - - for i, t in enumerate(state.plms_timesteps): - residual = model(sample, t) - sample, state = scheduler.step_plms(state, residual, t, sample) - - return sample - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - # copy over dummy past residuals (must be done after set_timesteps) - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - state = state.replace(ets=dummy_past_residuals[:]) - - output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) - output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) - output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_timesteps(self): - for timesteps in [100, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_steps_offset(self): - for steps_offset in [0, 1]: - self.check_over_configs(steps_offset=steps_offset) - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(steps_offset=1) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 10, shape=()) - assert jnp.equal( - state.timesteps, - jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), - ).all() - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 
0.02]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_time_indices(self): - for t in [1, 5, 10]: - self.check_over_forward(time_step=t) - - def test_inference_steps(self): - for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): - self.check_over_forward(num_inference_steps=num_inference_steps) - - def test_pow_of_3_inference_steps(self): - # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 - num_inference_steps = 27 - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - # before power of 3 fix, would error on first step, so we only need to do two - for i, t in enumerate(state.prk_timesteps[:2]): - sample, state = scheduler.step_prk(state, residual, t, sample) - - def test_inference_plms_no_past_residuals(self): - with self.assertRaises(ValueError): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 198.1275) < 1e-2 - assert abs(result_mean - 0.2580) < 1e-3 - else: - assert abs(result_sum - 198.1318) < 1e-2 - assert abs(result_mean - 0.2580) < 1e-3 - - def test_full_loop_with_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 186.83226) < 1e-2 - assert abs(result_mean - 0.24327) < 1e-3 - else: - assert abs(result_sum - 186.9466) < 1e-2 - assert abs(result_mean - 0.24342) < 1e-3 - - def test_full_loop_with_no_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 186.83226) < 1e-2 - assert abs(result_mean - 0.24327) < 1e-3 - else: - assert abs(result_sum - 186.9482) < 1e-2 - assert abs(result_mean - 0.2434) < 1e-3 From d1b1bcd837f5206c6100c6d035b5431a6602c61e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Aug 2025 09:01:17 +0530 Subject: [PATCH 2/8] add deprecation warning. --- src/diffusers/models/modeling_flax_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py index 010b7377451c..915b3fc13503 100644 --- a/src/diffusers/models/modeling_flax_utils.py +++ b/src/diffusers/models/modeling_flax_utils.py @@ -290,6 +290,10 @@ def from_pretrained( You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. 
``` """ + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) From 585fed4a9363b43e152b6e898ecd1ad372cd2258 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Aug 2025 09:03:52 +0530 Subject: [PATCH 3/8] add warning messages. --- src/diffusers/pipelines/pipeline_flax_utils.py | 5 +++++ src/diffusers/schedulers/scheduling_utils_flax.py | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index ea2c0763d93a..fa1ff9a38a37 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -312,6 +312,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P >>> dpm_params["scheduler"] = dpmpp_state ``` """ + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + cache_dir = kwargs.pop("cache_dir", None) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index e6ac78f63ee7..59f3d563734a 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -22,9 +22,11 @@ import jax.numpy as jnp from huggingface_hub.utils import validate_hf_hub_args -from ..utils import BaseOutput, PushToHubMixin +from ..utils import BaseOutput, PushToHubMixin, logging +logger = logging.get_logger(__name__) + SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -133,6 +135,10 @@ def from_pretrained( """ + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) config, kwargs = cls.load_config( pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, From 044d75aab30a413b5da02a6b5d7cdb880fdde9f4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Aug 2025 09:19:30 +0530 Subject: [PATCH 4/8] more warnings. 
--- src/diffusers/models/attention_flax.py | 30 +++++++++++ .../models/controlnets/controlnet_flax.py | 15 +++++- src/diffusers/models/embeddings_flax.py | 15 ++++++ src/diffusers/models/resnet_flax.py | 20 +++++++ .../models/unets/unet_2d_blocks_flax.py | 29 ++++++++++ .../models/unets/unet_2d_condition_flax.py | 10 +++- src/diffusers/models/vae_flax.py | 54 ++++++++++++++++++- 7 files changed, 170 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 17e6f33df051..015603f968bf 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -19,6 +19,11 @@ import jax import jax.numpy as jnp +from ..utils import logging + + +logger = logging.get_logger(__name__) + def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" @@ -151,6 +156,11 @@ class FlaxAttention(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + inner_dim = self.dim_head * self.heads self.scale = self.dim_head**-0.5 @@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module): split_head_dim: bool = False def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + # self attention (or cross_attention if only_cross_attention is True) self.attn1 = FlaxAttention( self.dim, @@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module): split_head_dim: bool = False def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) inner_dim = self.n_heads * self.d_head @@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + # The second linear layer needs to be called # net_2 for now to match the index of the Sequential layer self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) @@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) + inner_dim = self.dim * 4 self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) self.dropout_layer = nn.Dropout(rate=self.dropout) diff --git a/src/diffusers/models/controlnets/controlnet_flax.py b/src/diffusers/models/controlnets/controlnet_flax.py index 4b2148666ebf..7bbbd6eafa24 100644 --- a/src/diffusers/models/controlnets/controlnet_flax.py +++ b/src/diffusers/models/controlnets/controlnet_flax.py @@ -20,7 +20,7 @@ from flax.core.frozen_dict import FrozenDict from ...configuration_utils import ConfigMixin, flax_register_to_config -from ...utils import BaseOutput +from ...utils import BaseOutput, logging from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from ..modeling_flax_utils import FlaxModelMixin from ..unets.unet_2d_blocks_flax import ( @@ -30,6 +30,9 @@ ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class FlaxControlNetOutput(BaseOutput): """ @@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self) -> None: + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + self.conv_in = nn.Conv( self.block_out_channels[0], kernel_size=(3, 3), @@ -184,6 +192,11 @@ def init_weights(self, rng: jax.Array) -> FrozenDict: return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"] def setup(self) -> None: + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 1e7e84edeaeb..3b2a089c3c06 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -16,6 +16,11 @@ import flax.linen as nn import jax.numpy as jnp +from ..utils import logging + + +logger = logging.get_logger(__name__) + def get_sinusoidal_embeddings( timesteps: jnp.ndarray, @@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module): The data type for the embedding parameters. """ + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + time_embed_dim: int = 32 dtype: jnp.dtype = jnp.float32 @@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module): flip_sin_to_cos: bool = False freq_shift: float = 1 + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + @nn.compact def __call__(self, timesteps): return get_sinusoidal_embeddings( diff --git a/src/diffusers/models/resnet_flax.py b/src/diffusers/models/resnet_flax.py index 9c80932c5c5d..84fb8653989b 100644 --- a/src/diffusers/models/resnet_flax.py +++ b/src/diffusers/models/resnet_flax.py @@ -15,12 +15,22 @@ import jax import jax.numpy as jnp +from ..utils import logging + + +logger = logging.get_logger(__name__) + class FlaxUpsample2D(nn.Module): out_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) + self.conv = nn.Conv( self.out_channels, kernel_size=(3, 3), @@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + self.conv = nn.Conv( self.out_channels, kernel_size=(3, 3), @@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + out_channels = self.in_channels if self.out_channels is None else self.out_channels self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) diff --git a/src/diffusers/models/unets/unet_2d_blocks_flax.py b/src/diffusers/models/unets/unet_2d_blocks_flax.py index abd025165ecf..ed5307ef98e9 100644 --- a/src/diffusers/models/unets/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unets/unet_2d_blocks_flax.py @@ -15,10 +15,14 @@ import flax.linen as nn import jax.numpy as jnp +from ...utils import logging from ..attention_flax import FlaxTransformer2DModel from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D +logger = logging.get_logger(__name__) + + class FlaxCrossAttnDownBlock2D(nn.Module): r""" Cross Attention 2D Downsizing block - original architecture from Unet transformers: @@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module): transformer_layers_per_block: int = 1 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + resnets = [] attentions = [] @@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + resnets = [] for i in range(self.num_layers): @@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module): transformer_layers_per_block: int = 1 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + resnets = [] attentions = [] @@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + resnets = [] for i in range(self.num_layers): @@ -356,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): transformer_layers_per_block: int = 1 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) + # there is always at least one resnet resnets = [ FlaxResnetBlock2D( diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py index 7c21ddb690ae..132b52c8c2ad 100644 --- a/src/diffusers/models/unets/unet_2d_condition_flax.py +++ b/src/diffusers/models/unets/unet_2d_condition_flax.py @@ -20,7 +20,7 @@ from flax.core.frozen_dict import FrozenDict from ...configuration_utils import ConfigMixin, flax_register_to_config -from ...utils import BaseOutput +from ...utils import BaseOutput, logging from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from ..modeling_flax_utils import FlaxModelMixin from .unet_2d_blocks_flax import ( @@ -32,6 +32,9 @@ ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): """ @@ -163,6 +166,11 @@ def init_weights(self, rng: jax.Array) -> FrozenDict: return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"] def setup(self) -> None: + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 93398a51eac7..968ef6e55ff7 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -25,10 +25,13 @@ from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, logging from .modeling_flax_utils import FlaxModelMixin +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class FlaxDecoderOutput(BaseOutput): """ @@ -73,6 +76,10 @@ class FlaxUpsample2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.conv = nn.Conv( self.in_channels, kernel_size=(3, 3), @@ -107,6 +114,11 @@ class FlaxDownsample2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + self.conv = nn.Conv( self.in_channels, kernel_size=(3, 3), @@ -149,6 +161,11 @@ class FlaxResnetBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + out_channels = self.in_channels if self.out_channels is None else self.out_channels self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) @@ -221,6 +238,11 @@ class FlaxAttentionBlock(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) + self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1 dense = partial(nn.Dense, self.channels, dtype=self.dtype) @@ -302,6 +324,11 @@ class FlaxDownEncoderBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + resnets = [] for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels @@ -359,6 +386,11 @@ class FlaxUpDecoderBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + resnets = [] for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels @@ -413,6 +445,11 @@ class FlaxUNetMidBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) # there is always at least one resnet @@ -504,6 +541,11 @@ class FlaxEncoder(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + block_out_channels = self.block_out_channels # in self.conv_in = nn.Conv( @@ -616,6 +658,11 @@ class FlaxDecoder(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + block_out_channels = self.block_out_channels # z to block_in @@ -788,6 +835,11 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): dtype: jnp.dtype = jnp.float32 def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + self.encoder = FlaxEncoder( in_channels=self.config.in_channels, out_channels=self.config.latent_channels, From 50a8c71c59d55d9973bec1e1451a621a3008de10 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Aug 2025 09:22:05 +0530 Subject: [PATCH 5/8] remove dockerfiles. 
--- docker/diffusers-flax-cpu/Dockerfile | 49 -------------------------- docker/diffusers-flax-tpu/Dockerfile | 51 ---------------------------- 2 files changed, 100 deletions(-) delete mode 100644 docker/diffusers-flax-cpu/Dockerfile delete mode 100644 docker/diffusers-flax-tpu/Dockerfile diff --git a/docker/diffusers-flax-cpu/Dockerfile b/docker/diffusers-flax-cpu/Dockerfile deleted file mode 100644 index 051008aa9a2e..000000000000 --- a/docker/diffusers-flax-cpu/Dockerfile +++ /dev/null @@ -1,49 +0,0 @@ -FROM ubuntu:20.04 -LABEL maintainer="Hugging Face" -LABEL repository="diffusers" - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get -y update \ - && apt-get install -y software-properties-common \ - && add-apt-repository ppa:deadsnakes/ppa - -RUN apt install -y bash \ - build-essential \ - git \ - git-lfs \ - curl \ - ca-certificates \ - libsndfile1-dev \ - libgl1 \ - python3.10 \ - python3-pip \ - python3.10-venv && \ - rm -rf /var/lib/apt/lists - -# make sure to use venv -RUN python3.10 -m venv /opt/venv -ENV PATH="/opt/venv/bin:$PATH" - -# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) -# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container -RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ - python3 -m uv pip install --upgrade --no-cache-dir \ - clu \ - "jax[cpu]>=0.2.16,!=0.3.2" \ - "flax>=0.4.1" \ - "jaxlib>=0.1.65" && \ - python3 -m uv pip install --no-cache-dir \ - accelerate \ - datasets \ - hf-doc-builder \ - huggingface-hub \ - Jinja2 \ - librosa \ - numpy==1.26.4 \ - scipy \ - tensorboard \ - transformers \ - hf_transfer - -CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/diffusers-flax-tpu/Dockerfile b/docker/diffusers-flax-tpu/Dockerfile deleted file mode 100644 index 405f068923b7..000000000000 --- a/docker/diffusers-flax-tpu/Dockerfile +++ /dev/null @@ -1,51 +0,0 @@ -FROM ubuntu:20.04 -LABEL maintainer="Hugging Face" -LABEL repository="diffusers" - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get -y update \ - && apt-get install -y software-properties-common \ - && add-apt-repository ppa:deadsnakes/ppa - -RUN apt install -y bash \ - build-essential \ - git \ - git-lfs \ - curl \ - ca-certificates \ - libsndfile1-dev \ - libgl1 \ - python3.10 \ - python3-pip \ - python3.10-venv && \ - rm -rf /var/lib/apt/lists - -# make sure to use venv -RUN python3.10 -m venv /opt/venv -ENV PATH="/opt/venv/bin:$PATH" - -# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) -# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container -RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ - python3 -m pip install --no-cache-dir \ - "jax[tpu]>=0.2.16,!=0.3.2" \ - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \ - python3 -m uv pip install --upgrade --no-cache-dir \ - clu \ - "flax>=0.4.1" \ - "jaxlib>=0.1.65" && \ - python3 -m uv pip install --no-cache-dir \ - accelerate \ - datasets \ - hf-doc-builder \ - huggingface-hub \ - Jinja2 \ - librosa \ - numpy==1.26.4 \ - scipy \ - tensorboard \ - transformers \ - hf_transfer - -CMD ["/bin/bash"] \ No newline at end of file From a3601c33465acdfa0be3c3440b1b0d4ac0702cfd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 15 Aug 2025 09:24:42 +0530 Subject: [PATCH 6/8] remove more. 
--- .github/workflows/pr_flax_dependency_test.yml | 38 ------------------- 1 file changed, 38 deletions(-) delete mode 100644 .github/workflows/pr_flax_dependency_test.yml diff --git a/.github/workflows/pr_flax_dependency_test.yml b/.github/workflows/pr_flax_dependency_test.yml deleted file mode 100644 index e091b5f2d7b3..000000000000 --- a/.github/workflows/pr_flax_dependency_test.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Run Flax dependency tests - -on: - pull_request: - branches: - - main - paths: - - "src/diffusers/**.py" - push: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - check_flax_dependencies: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - name: Install dependencies - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m pip install --upgrade pip uv - python -m uv pip install -e . - python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2" - python -m uv pip install "flax>=0.4.1" - python -m uv pip install "jaxlib>=0.1.65" - python -m uv pip install pytest - - name: Check for soft dependencies - run: | - python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - pytest tests/others/test_dependencies.py From 0605c836f7504565ddc3053f561abf6433811cfc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 22 Aug 2025 22:00:26 +0530 Subject: [PATCH 7/8] Update src/diffusers/models/attention_flax.py Co-authored-by: Dhruv Nair --- src/diffusers/models/attention_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 015603f968bf..c23f6a0b2df5 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -157,7 +157,7 @@ class FlaxAttention(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0 We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) From 85378eb46b4cd5398c1c7aeedd48890c78cc3805 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 22 Aug 2025 22:02:21 +0530 Subject: [PATCH 8/8] up --- src/diffusers/models/attention_flax.py | 10 +++++----- .../models/controlnets/controlnet_flax.py | 4 ++-- src/diffusers/models/embeddings_flax.py | 4 ++-- src/diffusers/models/modeling_flax_utils.py | 2 +- src/diffusers/models/resnet_flax.py | 6 +++--- .../models/unets/unet_2d_blocks_flax.py | 10 +++++----- .../models/unets/unet_2d_condition_flax.py | 2 +- src/diffusers/models/vae_flax.py | 20 +++++++++---------- .../pipelines/pipeline_flax_utils.py | 2 +- .../schedulers/scheduling_utils_flax.py | 2 +- 10 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index c23f6a0b2df5..1bde62e5c666 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -157,7 +157,7 @@ class FlaxAttention(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1.0.0 We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
) @@ -288,7 +288,7 @@ class FlaxBasicTransformerBlock(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -381,7 +381,7 @@ class FlaxTransformer2DModel(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -475,7 +475,7 @@ class FlaxFeedForward(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -510,7 +510,7 @@ class FlaxGEGLU(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) diff --git a/src/diffusers/models/controlnets/controlnet_flax.py b/src/diffusers/models/controlnets/controlnet_flax.py index 7bbbd6eafa24..f7a8b98fa2f0 100644 --- a/src/diffusers/models/controlnets/controlnet_flax.py +++ b/src/diffusers/models/controlnets/controlnet_flax.py @@ -54,7 +54,7 @@ class FlaxControlNetConditioningEmbedding(nn.Module): def setup(self) -> None: logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -193,7 +193,7 @@ def init_weights(self, rng: jax.Array) -> FrozenDict: def setup(self) -> None: logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 3b2a089c3c06..3790905e583c 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -82,7 +82,7 @@ class FlaxTimestepEmbedding(nn.Module): """ logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -115,7 +115,7 @@ class FlaxTimesteps(nn.Module): freq_shift: float = 1 logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py index 915b3fc13503..573828dc4b03 100644 --- a/src/diffusers/models/modeling_flax_utils.py +++ b/src/diffusers/models/modeling_flax_utils.py @@ -291,7 +291,7 @@ def from_pretrained( ``` """ logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. 
We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) config = kwargs.pop("config", None) diff --git a/src/diffusers/models/resnet_flax.py b/src/diffusers/models/resnet_flax.py index 84fb8653989b..9bedaa9a36b6 100644 --- a/src/diffusers/models/resnet_flax.py +++ b/src/diffusers/models/resnet_flax.py @@ -27,7 +27,7 @@ class FlaxUpsample2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -56,7 +56,7 @@ class FlaxDownsample2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -84,7 +84,7 @@ class FlaxResnetBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) diff --git a/src/diffusers/models/unets/unet_2d_blocks_flax.py b/src/diffusers/models/unets/unet_2d_blocks_flax.py index ed5307ef98e9..6e6005afdc31 100644 --- a/src/diffusers/models/unets/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unets/unet_2d_blocks_flax.py @@ -65,7 +65,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -145,7 +145,7 @@ class FlaxDownBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -223,7 +223,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -308,7 +308,7 @@ class FlaxUpBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -381,7 +381,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
) diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py index 132b52c8c2ad..8d9a309afbcc 100644 --- a/src/diffusers/models/unets/unet_2d_condition_flax.py +++ b/src/diffusers/models/unets/unet_2d_condition_flax.py @@ -167,7 +167,7 @@ def init_weights(self, rng: jax.Array) -> FrozenDict: def setup(self) -> None: logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 968ef6e55ff7..13653b90372a 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -77,7 +77,7 @@ class FlaxUpsample2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) self.conv = nn.Conv( @@ -115,7 +115,7 @@ class FlaxDownsample2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -162,7 +162,7 @@ class FlaxResnetBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -239,7 +239,7 @@ class FlaxAttentionBlock(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -325,7 +325,7 @@ class FlaxDownEncoderBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -387,7 +387,7 @@ class FlaxUpDecoderBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -446,7 +446,7 @@ class FlaxUNetMidBlock2D(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -542,7 +542,7 @@ class FlaxEncoder(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
) @@ -659,7 +659,7 @@ class FlaxDecoder(nn.Module): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) @@ -836,7 +836,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): def setup(self): logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index fa1ff9a38a37..f69968022ed7 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -313,7 +313,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ``` """ logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index 59f3d563734a..ffbe3b90207b 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -136,7 +136,7 @@ def from_pretrained( """ logger.warning( - "Flax classes are deprecated and will be removed in Diffusers v1. We " + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " "recommend migrating to PyTorch classes or pinning your version of Diffusers." ) config, kwargs = cls.load_config(