@@ -3310,3 +3310,172 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    return checkpoint
+
+
+def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
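+    # Block counts are inferred from the highest block index present in the checkpoint
+    # keys, so checkpoints with a non-default depth still convert correctly.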
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
+    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
+    num_guidance_layers = (
+        list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1  # noqa: C401
+    )
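+    # The backbone dimensions are fixed Flux-style values, assumed here rather than
+    # read from the checkpoint: hidden size 3072 with a 4x MLP expansion.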
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # The original SD3 implementation of AdaLayerNormContinuous splits the linear projection
+    # output into (shift, scale), while diffusers splits it into (scale, shift). Swap the two
+    # halves of the projection weights so the diffusers implementation can be used.
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
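+    # For example, for a weight packed as [shift; scale] along dim 0:
+    #   swap_scale_shift(torch.arange(4.0)) -> tensor([2., 3., 0., 1.])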
+
+    # guidance
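+    # Chroma replaces Flux's separate time/guidance/vector embedders with a single
+    # distilled approximator network, so these keys carry over nearly one-to-one.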
+ converted_state_dict ["distilled_guidance_layer.in_proj.bias" ] = checkpoint .pop (
3340
+ "distilled_guidance_layer.in_proj.bias"
3341
+ )
3342
+ converted_state_dict ["distilled_guidance_layer.in_proj.weight" ] = checkpoint .pop (
3343
+ "distilled_guidance_layer.in_proj.weight"
3344
+ )
3345
+ converted_state_dict ["distilled_guidance_layer.out_proj.bias" ] = checkpoint .pop (
3346
+ "distilled_guidance_layer.out_proj.bias"
3347
+ )
3348
+ converted_state_dict ["distilled_guidance_layer.out_proj.weight" ] = checkpoint .pop (
3349
+ "distilled_guidance_layer.out_proj.weight"
3350
+ )
3351
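+    # Each approximator layer is a two-layer MLP: in_layer/out_layer map to linear_1/linear_2,
+    # and the accompanying norm `.scale` parameters become `.weight` under diffusers naming.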
+    for i in range(num_guidance_layers):
+        block_prefix = f"distilled_guidance_layer.layers.{i}."
+        converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.in_layer.bias"
+        )
+        converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.in_layer.weight"
+        )
+        converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.out_layer.bias"
+        )
+        converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.out_layer.weight"
+        )
+        converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.norms.{i}.scale"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # Q, K, V
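+        # The checkpoint stores fused QKV projections for both the image (img_attn) and
+        # text (txt_attn) streams; split each into the separate projections diffusers expects.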
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+ converted_state_dict [f"{ block_prefix } attn.to_q.weight" ] = torch .cat ([sample_q ])
3392
+ converted_state_dict [f"{ block_prefix } attn.to_q.bias" ] = torch .cat ([sample_q_bias ])
3393
+ converted_state_dict [f"{ block_prefix } attn.to_k.weight" ] = torch .cat ([sample_k ])
3394
+ converted_state_dict [f"{ block_prefix } attn.to_k.bias" ] = torch .cat ([sample_k_bias ])
3395
+ converted_state_dict [f"{ block_prefix } attn.to_v.weight" ] = torch .cat ([sample_v ])
3396
+ converted_state_dict [f"{ block_prefix } attn.to_v.bias" ] = torch .cat ([sample_v_bias ])
3397
+ converted_state_dict [f"{ block_prefix } attn.add_q_proj.weight" ] = torch .cat ([context_q ])
3398
+ converted_state_dict [f"{ block_prefix } attn.add_q_proj.bias" ] = torch .cat ([context_q_bias ])
3399
+ converted_state_dict [f"{ block_prefix } attn.add_k_proj.weight" ] = torch .cat ([context_k ])
3400
+ converted_state_dict [f"{ block_prefix } attn.add_k_proj.bias" ] = torch .cat ([context_k_bias ])
3401
+ converted_state_dict [f"{ block_prefix } attn.add_v_proj.weight" ] = torch .cat ([context_v ])
3402
+ converted_state_dict [f"{ block_prefix } attn.add_v_proj.bias" ] = torch .cat ([context_v_bias ])
3403
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # Q, K, V, mlp
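+        # linear1 fuses Q, K, V, and the MLP up-projection into one matrix; split it at
+        # (inner_dim, inner_dim, inner_dim, inner_dim * mlp_ratio) = (3072, 3072, 3072, 12288).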
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+ converted_state_dict [f"{ block_prefix } attn.to_q.weight" ] = torch .cat ([q ])
3460
+ converted_state_dict [f"{ block_prefix } attn.to_q.bias" ] = torch .cat ([q_bias ])
3461
+ converted_state_dict [f"{ block_prefix } attn.to_k.weight" ] = torch .cat ([k ])
3462
+ converted_state_dict [f"{ block_prefix } attn.to_k.bias" ] = torch .cat ([k_bias ])
3463
+ converted_state_dict [f"{ block_prefix } attn.to_v.weight" ] = torch .cat ([v ])
3464
+ converted_state_dict [f"{ block_prefix } attn.to_v.bias" ] = torch .cat ([v_bias ])
3465
+ converted_state_dict [f"{ block_prefix } proj_mlp.weight" ] = torch .cat ([mlp ])
3466
+ converted_state_dict [f"{ block_prefix } proj_mlp.bias" ] = torch .cat ([mlp_bias ])
3467
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    return converted_state_dict