Merged · Changes from all commits
38 changes: 0 additions & 38 deletions .github/workflows/pr_flax_dependency_test.yml

This file was deleted.

49 changes: 0 additions & 49 deletions docker/diffusers-flax-cpu/Dockerfile

This file was deleted.

51 changes: 0 additions & 51 deletions docker/diffusers-flax-tpu/Dockerfile

This file was deleted.

30 changes: 30 additions & 0 deletions src/diffusers/models/attention_flax.py
@@ -19,6 +19,11 @@
import jax
import jax.numpy as jnp

from ..utils import logging


logger = logging.get_logger(__name__)


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
dtype: jnp.dtype = jnp.float32

def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

Review comment from the PR author on this line: "Thought of adding it to all the public Flax classes, can remove."

inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
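
Because the notice is emitted from setup(), it fires each time one of these modules is initialized, so a full Flax UNet built from many attention blocks will log it repeatedly. Downstream users who want to stay on Flax until v1.0.0 can raise the logging threshold; a small sketch, assuming the diffusers.utils.logging verbosity helpers (the same API the PyTorch side exposes):

import diffusers

# Suppress WARNING-level messages, including this deprecation notice.
diffusers.utils.logging.set_verbosity_error()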

@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
split_head_dim: bool = False

def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
split_head_dim: bool = False

def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
dtype: jnp.dtype = jnp.float32

def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
dtype: jnp.dtype = jnp.float32

def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
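
The inner_dim * 2 projection above is the GEGLU trick: one half of the projection is the value path, the other half is passed through GELU and used as a multiplicative gate. A self-contained sketch of the gating (an illustrative re-statement, not the library class):

import flax.linen as nn
import jax.numpy as jnp

class GEGLUSketch(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, x):
        # Project to twice the inner width, then split into value and gate halves.
        h = nn.Dense(self.dim * 4 * 2)(x)
        value, gate = jnp.split(h, 2, axis=-1)
        return value * nn.gelu(gate)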
15 changes: 14 additions & 1 deletion src/diffusers/models/controlnets/controlnet_flax.py
@@ -20,7 +20,7 @@
from flax.core.frozen_dict import FrozenDict

from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@
)


logger = logging.get_logger(__name__)


@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
"""
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
dtype: jnp.dtype = jnp.float32

def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
@@ -184,6 +192,11 @@ def init_weights(self, rng: jax.Array) -> FrozenDict:
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

def setup(self) -> None:
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4

15 changes: 15 additions & 0 deletions src/diffusers/models/embeddings_flax.py
@@ -16,6 +16,11 @@
import flax.linen as nn
import jax.numpy as jnp

from ..utils import logging


logger = logging.get_logger(__name__)


def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
The data type for the embedding parameters.
"""

logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32

@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
flip_sin_to_cos: bool = False
freq_shift: float = 1

logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(
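For reference, FlaxTimesteps delegates to get_sinusoidal_embeddings, which produces the standard transformer-style timestep embedding: sines and cosines at log-spaced frequencies. A simplified, self-contained sketch that omits the library's freq_shift, flip_sin_to_cos, and scaling options:

import jax.numpy as jnp

def sinusoidal_embeddings_sketch(timesteps, embedding_dim, max_period=10000.0):
    # timesteps: [batch] -> [batch, embedding_dim] (assumes an even embedding_dim)
    half = embedding_dim // 2
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    args = timesteps.astype(jnp.float32)[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)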
4 changes: 4 additions & 0 deletions src/diffusers/models/modeling_flax_utils.py
@@ -290,6 +290,10 @@ def from_pretrained(
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
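With this change, every Flax from_pretrained call logs the notice before loading weights. A hedged usage sketch (the checkpoint name is illustrative; any repository that ships Flax weights behaves the same):

from diffusers import FlaxUNet2DConditionModel

# Logs: "Flax classes are deprecated and will be removed in Diffusers v1.0.0. ..."
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)

Projects that cannot migrate yet can follow the warning's other suggestion and pin a pre-1.0 release, for example with pip install "diffusers<1.0".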
20 changes: 20 additions & 0 deletions src/diffusers/models/resnet_flax.py
@@ -15,12 +15,22 @@
import jax
import jax.numpy as jnp

from ..utils import logging


logger = logging.get_logger(__name__)


class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32

def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32

def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32

def setup(self):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)

out_channels = self.in_channels if self.out_channels is None else self.out_channels

self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
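For context (unchanged by this diff), FlaxUpsample2D, shown earlier in this file, pairs its 3x3 convolution with a nearest-neighbor resize that doubles the spatial resolution. A self-contained sketch of that forward pass in NHWC layout (an illustrative re-statement, not the library class):

import jax
import flax.linen as nn

class Upsample2DSketch(nn.Module):
    out_channels: int

    @nn.compact
    def __call__(self, x):  # x: [batch, height, width, channels]
        batch, height, width, channels = x.shape
        # Double spatial resolution with nearest-neighbor resizing, then convolve.
        x = jax.image.resize(x, (batch, height * 2, width * 2, channels), method="nearest")
        return nn.Conv(self.out_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)))(x)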