
Commit 8adc600

Ednaordinary, DN6, and github-actions[bot] authored
Chroma Pipeline (#11698)
* working state from hameerabbasi and iddl
* working state form hameerabbasi and iddl (transformer)
* working state (normalization)
* working state (embeddings)
* add chroma loader
* add chroma to mappings
* add chroma to transformer init
* take out variant stuff
* get decently far in changing variant stuff
* add chroma init
* make chroma output class
* add chroma transformer to dummy tp
* add chroma to init
* add chroma to init
* fix single file
* update
* update
* add chroma to auto pipeline
* add chroma to pipeline init
* change to chroma transformer
* take out variant from blocks
* swap embedder location
* remove prompt_2
* work on swapping text encoders
* remove mask function
* dont modify mask (for now)
* wrap attn mask
* no attn mask (can't get it to work)
* remove pooled prompt embeds
* change to my own unpooled embeddeer
* fix load
* take pooled projections out of transformer
* ensure correct dtype for chroma embeddings
* update
* use dn6 attn mask + fix true_cfg_scale
* use chroma pipeline output
* use DN6 embeddings
* remove guidance
* remove guidance embed (pipeline)
* remove guidance from embeddings
* don't return length
* dont change dtype
* remove unused stuff, fix up docs
* add chroma autodoc
* add .md (oops)
* initial chroma docs
* undo don't change dtype
* undo arxiv change unsure why that happened
* fix hf papers regression in more places
* Update docs/source/en/api/pipelines/chroma.md
  Co-authored-by: Dhruv Nair <[email protected]>
* do_cfg -> self.do_classifier_free_guidance
* Update docs/source/en/api/models/chroma_transformer.md
  Co-authored-by: Dhruv Nair <[email protected]>
* Update chroma.md
* Move chroma layers into transformer
* Remove pruned AdaLayerNorms
* Add chroma fast tests
* (untested) batch cond and uncond
* Add # Copied from for shift
* Update # Copied from statements
* update norm imports
* Revert cond + uncond batching
* Add transformer tests
* move chroma test (oops)
* chroma init
* fix chroma pipeline fast tests
* Update src/diffusers/models/transformers/transformer_chroma.py
  Co-authored-by: Dhruv Nair <[email protected]>
* Move Approximator and Embeddings
* Fix auto pipeline + make style, quality
* make style
* Apply style fixes
* switch to new input ids
* fix # Copied from error
* remove # Copied from on protected members
* try to fix import
* fix import
* make fix-copes
* revert style fix
* update chroma transformer params
* update chroma transformer approximator init params
* update to pad tokens
* fix batch inference
* Make more pipeline tests work
* Make most transformer tests work
* fix docs
* make style, make quality
* skip batch tests
* fix test skipping
* fix test skipping again
* fix for tests
* Fix all pipeline test
* update
* push local changes, fix docs
* add encoder test, remove pooled dim
* default proj dim
* fix tests
* fix equal size list input
* update
* push local changes, fix docs
* add encoder test, remove pooled dim
* default proj dim
* fix tests
* fix equal size list input
* Revert "fix equal size list input"
  This reverts commit 3fe4ad6.
* update
* update
* update
* update
* update

---------

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 9f91305 commit 8adc600

23 files changed (+2336, -5 lines)

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -283,6 +283,8 @@
       title: AllegroTransformer3DModel
     - local: api/models/aura_flow_transformer2d
       title: AuraFlowTransformer2DModel
+    - local: api/models/chroma_transformer
+      title: ChromaTransformer2DModel
     - local: api/models/cogvideox_transformer3d
       title: CogVideoXTransformer3DModel
     - local: api/models/cogview3plus_transformer2d
@@ -405,6 +407,8 @@
       title: AutoPipeline
     - local: api/pipelines/blip_diffusion
       title: BLIP-Diffusion
+    - local: api/pipelines/chroma
+      title: Chroma
     - local: api/pipelines/cogvideox
       title: CogVideoX
     - local: api/pipelines/cogview3

docs/source/en/api/models/chroma_transformer.md

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# ChromaTransformer2DModel
+
+A modified Flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma).
+
+## ChromaTransformer2DModel
+
+[[autodoc]] ChromaTransformer2DModel

docs/source/en/api/pipelines/chroma.md

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# Chroma
+
+<div class="flex flex-wrap space-x-1">
+  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+  <img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white"/>
+</div>
+
+Chroma is a text-to-image generation model based on Flux.
+
+Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
+
+<Tip>
+
+Chroma can use all the same optimizations as Flux.
+
+</Tip>
+
+## Inference (Single File)
+
+`ChromaTransformer2DModel` supports loading checkpoints in the original format, which is also useful when loading finetunes or quantized versions of the model published by the community.
+
+The following example demonstrates how to run Chroma from a single file checkpoint:
+
+```python
+import torch
+from diffusers import ChromaTransformer2DModel, ChromaPipeline
+from transformers import T5EncoderModel, T5Tokenizer
+
+bfl_repo = "black-forest-labs/FLUX.1-dev"
+dtype = torch.bfloat16
+
+transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
+
+text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2")
+
+pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
+
+pipe.enable_model_cpu_offload()
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+    prompt,
+    guidance_scale=4.0,
+    output_type="pil",
+    num_inference_steps=26,
+    generator=torch.Generator("cpu").manual_seed(0)
+).images[0]
+
+image.save("image.png")
+```
+
+## ChromaPipeline
+
+[[autodoc]] ChromaPipeline
+  - all
+  - __call__
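
Note (not part of the diff): the doc above mentions community finetunes and quantized checkpoints. Since the single-file loader is registered for `ChromaTransformer2DModel`, a quantized export could in principle be loaded through diffusers' existing GGUF support. This is only a minimal sketch; the GGUF file URL below is hypothetical and should be replaced with a real community export.

```python
import torch
from diffusers import ChromaPipeline, ChromaTransformer2DModel, GGUFQuantizationConfig
from transformers import T5EncoderModel, T5Tokenizer

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# Hypothetical GGUF export of the Chroma transformer -- substitute a real community file.
gguf_url = "https://huggingface.co/<user>/<repo>/blob/main/chroma-unlocked-v35-Q8_0.gguf"

transformer = ChromaTransformer2DModel.from_single_file(
    gguf_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=dtype),
    torch_dtype=dtype,
)

text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2")

pipe = ChromaPipeline.from_pretrained(
    bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype
)
pipe.enable_model_cpu_offload()

image = pipe("A cat holding a sign that says hello world", num_inference_steps=26).images[0]
image.save("chroma_gguf.png")
```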

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -159,6 +159,7 @@
         "AutoencoderTiny",
         "AutoModel",
         "CacheMixin",
+        "ChromaTransformer2DModel",
         "CogVideoXTransformer3DModel",
         "CogView3PlusTransformer2DModel",
         "CogView4Transformer2DModel",
@@ -352,6 +353,7 @@
         "AuraFlowPipeline",
         "BlipDiffusionControlNetPipeline",
         "BlipDiffusionPipeline",
+        "ChromaPipeline",
         "CLIPImageProjection",
         "CogVideoXFunControlPipeline",
         "CogVideoXImageToVideoPipeline",
@@ -770,6 +772,7 @@
         AutoencoderTiny,
         AutoModel,
         CacheMixin,
+        ChromaTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
         CogView4Transformer2DModel,
@@ -942,6 +945,7 @@
         AudioLDM2UNet2DConditionModel,
         AudioLDMPipeline,
         AuraFlowPipeline,
+        ChromaPipeline,
         CLIPImageProjection,
         CogVideoXFunControlPipeline,
         CogVideoXImageToVideoPipeline,

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@
     "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
     "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
     "WanVACETransformer3DModel": lambda model_cls, weights: weights,
+    "ChromaTransformer2DModel": lambda model_cls, weights: weights,
 }
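
Note (not part of the diff): this mapping feeds the model-level PEFT helpers (weight expansion in `set_adapters`). A minimal sketch of attaching a LoRA directly to the transformer follows; the adapter repo id is hypothetical and `peft` must be installed.

```python
import torch
from diffusers import ChromaTransformer2DModel

transformer = ChromaTransformer2DModel.from_single_file(
    "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors",
    torch_dtype=torch.bfloat16,
)

# Hypothetical LoRA checkpoint in diffusers/PEFT format.
transformer.load_lora_adapter("some-user/chroma-style-lora", adapter_name="style")
transformer.set_adapters(["style"], weights=[0.8])
```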

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,7 @@
     convert_animatediff_checkpoint_to_diffusers,
     convert_auraflow_transformer_checkpoint_to_diffusers,
     convert_autoencoder_dc_checkpoint_to_diffusers,
+    convert_chroma_transformer_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_hidream_transformer_to_diffusers,
@@ -97,6 +98,10 @@
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "ChromaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "LTXVideoTransformer3DModel": {
         "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",

src/diffusers/loaders/single_file_utils.py

Lines changed: 169 additions & 0 deletions
@@ -3310,3 +3310,172 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     return checkpoint
+
+
+def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
+    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
+    num_guidance_layers = (
+        list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1  # noqa: C401
+    )
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+    # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    # guidance
+    converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
+        "distilled_guidance_layer.in_proj.bias"
+    )
+    converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
+        "distilled_guidance_layer.in_proj.weight"
+    )
+    converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
+        "distilled_guidance_layer.out_proj.bias"
+    )
+    converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
+        "distilled_guidance_layer.out_proj.weight"
+    )
+    for i in range(num_guidance_layers):
+        block_prefix = f"distilled_guidance_layer.layers.{i}."
+        converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.in_layer.bias"
+        )
+        converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.in_layer.weight"
+        )
+        converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.out_layer.bias"
+        )
+        converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.out_layer.weight"
+        )
+        converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.norms.{i}.scale"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # Q, K, V, mlp
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    return converted_state_dict
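
Note (not part of the diff): a minimal sketch of exercising the converter above on a raw single-file checkpoint; the local path is hypothetical.

```python
from safetensors.torch import load_file

from diffusers.loaders.single_file_utils import convert_chroma_transformer_checkpoint_to_diffusers

# Hypothetical local copy of an original-format Chroma checkpoint.
state_dict = load_file("chroma-unlocked-v35.safetensors")

# Remap original keys (double_blocks.*, single_blocks.*, distilled_guidance_layer.*) to diffusers names.
converted = convert_chroma_transformer_checkpoint_to_diffusers(state_dict)
print(len(converted), "tensors in diffusers key layout")
print(sorted(converted)[:5])
```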

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -74,6 +74,7 @@
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
     _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
+    _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
     _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
@@ -151,6 +152,7 @@
     from .transformers import (
         AllegroTransformer3DModel,
         AuraFlowTransformer2DModel,
+        ChromaTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
         CogView4Transformer2DModel,

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ def get_timestep_embedding(
     downscale_freq_shift: float = 1,
     scale: float = 1,
     max_period: int = 10000,
-):
+) -> torch.Tensor:
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
 
@@ -1325,7 +1325,7 @@ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale
 
-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
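
Note (not part of the diff): a small sketch of the Tensor-in/Tensor-out contract these annotations document, using both the function and the `Timesteps` module.

```python
import torch
from diffusers.models.embeddings import Timesteps, get_timestep_embedding

t = torch.tensor([0, 250, 999])

# Functional form: returns a sinusoidal embedding of shape (batch, embedding_dim).
emb_fn = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)

# Module form: Timesteps.forward wraps the same function with fixed config.
emb_mod = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)(t)

print(emb_fn.shape, emb_mod.shape)  # torch.Size([3, 256]) torch.Size([3, 256])
```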

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
     from .transformer_allegro import AllegroTransformer3DModel
+    from .transformer_chroma import ChromaTransformer2DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_cogview4 import CogView4Transformer2DModel
     from .transformer_cosmos import CosmosTransformer3DModel

0 commit comments
