@@ -251,7 +251,6 @@ def __call__(
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
-        max_sequence_length: int = 128,
         step_callback: Callable[[PipelineIntermediateState], None] = None,
     ):
         r"""
@@ -342,7 +341,6 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
-            max_sequence_length=max_sequence_length,
         )

         self._guidance_scale = guidance_scale
@@ -416,15 +414,15 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

         # Init Invoke step callback
-        # step_callback(
-        #     PipelineIntermediateState(
-        #         step=0,
-        #         order=1,
-        #         total_steps=num_inference_steps,
-        #         timestep=int(timesteps[0]),
-        #         latents=latents,
-        #     ),
-        # )
+        step_callback(
+            PipelineIntermediateState(
+                step=0,
+                order=1,
+                total_steps=num_inference_steps,
+                timestep=int(timesteps[0]),
+                latents=latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128),
+            ),
+        )

         # EYAL - added the CFG loop
         # 7. Denoising loop
@@ -513,15 +511,15 @@ def __call__(
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    # step_callback(
-                    #     PipelineIntermediateState(
-                    #         step=i + 1,
-                    #         order=1,
-                    #         total_steps=num_inference_steps,
-                    #         timestep=int(t),
-                    #         latents=latents,
-                    #     ),
-                    # )
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=i + 1,
+                            order=1,
+                            total_steps=num_inference_steps,
+                            timestep=int(t),
+                            latents=latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128),
+                        ),
+                    )

         if output_type == "latent":
             image = latents
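
Note on the latents handed to the callback: the chained view/permute/reshape above appears to unpack 2x2 patch-packed latents back into a standard (1, 4, 128, 128) latent for previewing. A minimal sketch of that transform as a standalone helper, assuming the packed latents have shape (batch, 64*64, 4*2*2); the name unpack_latents and its generalized arguments are illustrative only, not part of this PR:

import torch

def unpack_latents(latents: torch.Tensor, height: int = 64, width: int = 64,
                   channels: int = 4, patch: int = 2) -> torch.Tensor:
    # Assumed packed shape: (batch, height * width, channels * patch * patch).
    batch = latents.shape[0]
    # Split each packed token into its channel and 2x2 spatial patch components.
    latents = latents.view(batch, height, width, channels, patch, patch)
    # Reorder to (batch, channels, height, patch_h, width, patch_w) ...
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    # ... then merge the patch dims back into the spatial dims.
    return latents.reshape(batch, channels, height * patch, width * patch)

With the defaults this reproduces the hard-coded (1, 4, 128, 128) result used in both step_callback calls.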