Skip to content

Commit 16f6e03

Browse files
Revert "Fix dtype issues of JAX CPU in SD3. (#2338)" (#2344)
This reverts commit ecef6d1.
1 parent ecef6d1 commit 16f6e03

File tree

1 file changed

+2
-17
lines changed

1 file changed

+2
-17
lines changed

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,11 @@ def setUp(self):
3131
64,
3232
"quick_gelu",
3333
-2,
34-
# TODO: JAX CPU doesn't support float16 for
35-
# `nn.dot_product_attention`. We set dtype to float32 despite the
36-
# model defaulting to float16.
37-
dtype="float32",
34+
dtype="float16",
3835
name="clip_l",
3936
)
4037
clip_g = CLIPTextEncoder(
41-
20,
42-
64,
43-
64,
44-
2,
45-
2,
46-
128,
47-
"gelu",
48-
-2,
49-
# TODO: JAX CPU doesn't support float16 for
50-
# `nn.dot_product_attention`. We set dtype to float32 despite the
51-
# model defaulting to float16.
52-
dtype="float32",
53-
name="clip_g",
38+
20, 64, 64, 2, 2, 128, "gelu", -2, dtype="float16", name="clip_g"
5439
)
5540
self.init_kwargs = {
5641
"mmdit_patch_size": 2,

0 commit comments

Comments
 (0)