From 3428cc3820cea349a75985bfe7eef70170d91b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 17 Aug 2025 22:28:09 +0300 Subject: [PATCH 01/20] fix: update SkyReels-V2 documentation and moving into attn dispatcher --- docs/source/en/api/pipelines/skyreels_v2.md | 20 +++++++++---------- .../transformers/transformer_skyreels_v2.py | 3 ++- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index cd94f2a75c08..ee55db6b4ba7 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -22,7 +22,7 @@ # SkyReels-V2: Infinite-length Film Generative model -[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team. +[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team from Skywork AI. *Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).* @@ -145,7 +145,6 @@ From the original repo: >You can use --ar_step 5 to enable asynchronous inference. When asynchronous inference, --causal_block_size 5 is recommended while it is not supposed to be set for synchronous generation... Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance. 
```py -# pip install ftfy import torch from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler from diffusers.utils import export_to_video @@ -177,7 +176,7 @@ output = pipeline( overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations addnoise_condition=20, # Improves consistency in long video generation ).frames[0] -export_to_video(output, "T2V.mp4", fps=24, quality=8) +export_to_video(output, "video.mp4", fps=24, quality=8) ``` @@ -239,7 +238,7 @@ prompt = "CG animation style, a small blue bird takes off from the ground, flapp output = pipeline( image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0 ).frames[0] -export_to_video(output, "output.mp4", fps=24, quality=8) +export_to_video(output, "video.mp4", fps=24, quality=8) ``` @@ -261,7 +260,7 @@ from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPi from diffusers.utils import export_to_video, load_video -model_id = "Skywork/SkyReels-V2-DF-14B-540P-Diffusers" +model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers" vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained( model_id, vae=vae, torch_dtype=torch.bfloat16 @@ -275,11 +274,11 @@ video = load_video("input_video.mp4") prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." 
output = pipeline( - video=video, prompt=prompt, height=544, width=960, guidance_scale=5.0, - num_inference_steps=30, num_frames=257, base_num_frames=97#, ar_step=5, causal_block_size=5, + video=video, prompt=prompt, height=720, width=1280, guidance_scale=5.0, overlap_history=17, + num_inference_steps=30, num_frames=257, base_num_frames=121#, ar_step=5, causal_block_size=5, ).frames[0] -export_to_video(output, "output.mp4", fps=24, quality=8) -# Total frames will be the number of frames of given video + 257 +export_to_video(output, "video.mp4", fps=24, quality=8) +# Total frames will be the number of frames of the given video + 257 ``` @@ -294,7 +293,6 @@ export_to_video(output, "output.mp4", fps=24, quality=8) Show example code ```py - # pip install ftfy import torch from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline from diffusers.utils import export_to_video @@ -326,7 +324,7 @@ export_to_video(output, "output.mp4", fps=24, quality=8) num_frames=97, guidance_scale=6.0, ).frames[0] - export_to_video(output, "output.mp4", fps=24) + export_to_video(output, "video.mp4", fps=24) ``` diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 236fca690a90..334503834a89 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -23,7 +23,8 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward -from ..attention_processor import Attention +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( PixArtAlphaTextProjection, From 42113fcbcee88f36ccf68b459a645bc99856e88e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 18 Aug 2025 20:00:24 +0300 Subject: [PATCH 02/20] Refactors SkyReelsV2's attention implementation --- .../transformers/transformer_skyreels_v2.py | 232 ++++++++++++++---- 1 file changed, 183 insertions(+), 49 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 334503834a89..78d142219a43 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -1,4 +1,4 @@ -# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The SkyReels Team, The Wan Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -39,21 +39,53 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections +def _get_qkv_projections(attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections +def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states_img: torch.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img + + +class SkyReelsV2AttnProcessor: + _attention_backend = None -class SkyReelsV2AttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + "SkyReelsV2AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." 
) def __call__( self, - attn: Attention, + attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: @@ -61,21 +93,15 @@ def __call__( image_context_length = encoder_hidden_states.shape[1] - 512 encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] encoder_hidden_states = encoder_hidden_states[:, image_context_length:] - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) + query = attn.norm_q(query) + key = attn.norm_k(key) - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) if rotary_emb is not None: @@ -90,29 +116,35 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: - key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) key_img = attn.norm_added_k(key_img) - value_img = attn.add_v_proj(encoder_hidden_states_img) - key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - - hidden_states_img = F.scaled_dot_product_attention( - query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) - hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - hidden_states = F.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, + backend=self._attention_backend, ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) if hidden_states_img is not None: @@ -123,6 +155,115 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): return hidden_states +class SkyReelsV2Attention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = SkyReelsV2AttnProcessor + _available_processors = [SkyReelsV2AttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + processor=None, + is_cross_attention=None, + ): + super().__init__() + + self.inner_dim = 
dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.ModuleList( + [ + torch.nn.Linear(self.inner_dim, dim, bias=True), + torch.nn.Dropout(dropout), + ] + ) + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + self.is_cross_attention = cross_attention_dim_head is not None + + self.set_processor(processor) + + # Copied from diffusers.models.transformers.transformer_wan.WanAttention.fuse_projections + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if self.cross_attention_dim_head is None: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_qkv = nn.Linear(in_features, out_features, bias=True) + self.to_qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_kv = nn.Linear(in_features, out_features, bias=True) + self.to_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + # Copied from diffusers.models.transformers.transformer_wan.WanAttention.unfuse_projections + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + delattr(self, "to_added_kv") + + self.fused_projections = False + + # Copied from diffusers.models.transformers.transformer_wan.WanAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, 
torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + + # Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding class SkyReelsV2ImageEmbedding(torch.nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): @@ -255,6 +396,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return freqs +@maybe_allow_in_graph class SkyReelsV2TransformerBlock(nn.Module): def __init__( self, @@ -270,33 +412,24 @@ def __init__( # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.attn1 = Attention( - query_dim=dim, + self.attn1 = SkyReelsV2Attention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, - processor=SkyReelsV2AttnProcessor2_0(), + cross_attention_dim_head=None, + processor=SkyReelsV2AttnProcessor(), ) # 2. Cross-attention - self.attn2 = Attention( - query_dim=dim, + self.attn2 = SkyReelsV2Attention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, added_kv_proj_dim=added_kv_proj_dim, - added_proj_bias=True, - processor=SkyReelsV2AttnProcessor2_0(), + cross_attention_dim_head=dim // num_heads, + processor=SkyReelsV2AttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() @@ -322,15 +455,15 @@ def forward( # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim) e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] + # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1( - hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask - ) + attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb) hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) hidden_states = hidden_states + attn_output # 3. Feed-forward @@ -339,10 +472,11 @@ def forward( ) ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + return hidden_states -class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): +class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): r""" A Transformer model for video-like data used in the Wan-based SkyReels-V2 model. 
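A note on the fused-projection path introduced in this patch: the sketch below is not the diffusers implementation, only a minimal, self-contained illustration (the sizes and variable names are made up for the example) of why concatenating the `to_q`/`to_k`/`to_v` weights and biases into one linear layer and chunking its output reproduces the three separate projections, which is the property `fuse_projections` relies on.

```py
# Minimal sketch, independent of diffusers: a fused QKV linear is equivalent
# to three separate projections. Sizes and names here are illustrative only.
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = inner_dim = 8

# Unfused path: three separate projections, as in the self-attention layers above.
to_q = nn.Linear(dim, inner_dim, bias=True)
to_k = nn.Linear(dim, inner_dim, bias=True)
to_v = nn.Linear(dim, inner_dim, bias=True)

# Fused path: concatenate weights and biases along the output dimension.
to_qkv = nn.Linear(dim, 3 * inner_dim, bias=True)
to_qkv.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
to_qkv.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])

x = torch.randn(2, 16, dim)  # (batch, sequence, dim)
query, key, value = to_qkv(x).chunk(3, dim=-1)

# The single matmul reproduces the three separate ones.
assert torch.allclose(query, to_q(x), atol=1e-5)
assert torch.allclose(key, to_k(x), atol=1e-5)
assert torch.allclose(value, to_v(x), atol=1e-5)
```

In the patch itself, the fused layer is created on the `meta` device and then filled via `load_state_dict(..., assign=True)`, so the concatenated tensors are adopted directly instead of being copied into freshly allocated weights.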
From 7e237adbda51d149c61957e856b976d3d98252aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 18 Aug 2025 20:01:14 +0300 Subject: [PATCH 03/20] style --- .../models/transformers/transformer_skyreels_v2.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 78d142219a43..02808f7a531a 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention import FeedForward +from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin @@ -39,8 +39,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + # Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections -def _get_qkv_projections(attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): +def _get_qkv_projections( + attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor +): # encoder_hidden_states is only passed for cross-attention if encoder_hidden_states is None: encoder_hidden_states = hidden_states @@ -455,7 +458,7 @@ def forward( # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim) e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] - + # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb) @@ -476,7 +479,9 @@ def forward( return hidden_states -class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): +class SkyReelsV2Transformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): r""" A Transformer model for video-like data used in the Wan-based SkyReels-V2 model. From 4d72277da4636c6a84c3c122f502e49389797e3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 18 Aug 2025 20:06:37 +0300 Subject: [PATCH 04/20] up --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 02808f7a531a..a4895228c177 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -40,7 +40,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections +# TODO: Copied from doesn't work here? 
def _get_qkv_projections( attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor ): From 92dbf97c7faa0e6e48b54e869d94b6aa3199b54c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 18 Aug 2025 20:20:11 +0300 Subject: [PATCH 05/20] Fixes formatting in SkyReels-V2 documentation Wraps the visual demonstration section in a Markdown code block. This change corrects the rendering of ASCII diagrams and examples, improving the overall readability of the document. --- docs/source/en/api/pipelines/skyreels_v2.md | 176 ++++++++++---------- 1 file changed, 89 insertions(+), 87 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index ee55db6b4ba7..9231a7425fa0 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -44,93 +44,95 @@ The following SkyReels-V2 models are supported in Diffusers: ### A _Visual_ Demonstration - An example with these parameters: - base_num_frames=97, num_frames=97, num_inference_steps=30, ar_step=5, causal_block_size=5 - - vae_scale_factor_temporal -> 4 - num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each - - base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 → blocks = 25//5 = 5 blocks - This 5 blocks means the maximum context length of the model is 25 frames in the latent space. - - Asynchronous Processing Timeline: - ┌─────────────────────────────────────────────────────────────────┐ - │ Steps: 1 6 11 16 21 26 31 36 41 46 50 │ - │ Block 1: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ - │ Block 2: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ - │ Block 3: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ - │ Block 4: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ - │ Block 5: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ - └─────────────────────────────────────────────────────────────────┘ - - For Long Videos (num_frames > base_num_frames): - base_num_frames acts as the "sliding window size" for processing long videos. - - Example: 257-frame video with base_num_frames=97, overlap_history=17 - ┌──── Iteration 1 (frames 1-97) ────┐ - │ Processing window: 97 frames │ → 5 blocks, async processing - │ Generates: frames 1-97 │ - └───────────────────────────────────┘ - ┌────── Iteration 2 (frames 81-177) ──────┐ - │ Processing window: 97 frames │ - │ Overlap: 17 frames (81-97) from prev │ → 5 blocks, async processing - │ Generates: frames 98-177 │ - └─────────────────────────────────────────┘ - ┌────── Iteration 3 (frames 161-257) ──────┐ - │ Processing window: 97 frames │ - │ Overlap: 17 frames (161-177) from prev │ → 5 blocks, async processing - │ Generates: frames 178-257 │ - └──────────────────────────────────────────┘ - - Each iteration independently runs the asynchronous processing with its own 5 blocks. - base_num_frames controls: - 1. Memory usage (larger window = more VRAM) - 2. Model context length (must match training constraints) - 3. Number of blocks per iteration (base_num_latent_frames // causal_block_size) - - Each block takes 30 steps to complete denoising. - Block N starts at step: 1 + (N-1) x ar_step - Total steps: 30 + (5-1) x 5 = 50 steps - - - Synchronous mode (ar_step=0) would process all blocks/frames simultaneously: - ┌──────────────────────────────────────────────┐ - │ Steps: 1 ... 
30 │ - │ All blocks: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ - └──────────────────────────────────────────────┘ - Total steps: 30 steps - - - An example on how the step matrix is constructed for asynchronous processing: - Given the parameters: (num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5) - - num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25 - - step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948, - 941, 932, 922, 912, 901, 888, 874, 859, 841, 822, - 799, 773, 743, 708, 666, 615, 551, 470, 363, 216] - - The algorithm creates a 50x25 step_matrix where: - - Row 1: [999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] - - Row 2: [995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] - - Row 3: [991, 991, 991, 991, 991, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] - - ... - - Row 7: [969, 969, 969, 969, 969, 995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] - - ... - - Row 21: [799, 799, 799, 799, 799, 888, 888, 888, 888, 888, 941, 941, 941, 941, 941, 975, 975, 975, 975, 975, 999, 999, 999, 999, 999] - - ... - - Row 35: [ 0, 0, 0, 0, 0, 216, 216, 216, 216, 216, 666, 666, 666, 666, 666, 822, 822, 822, 822, 822, 901, 901, 901, 901, 901] - - ... - - Row 42: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 551, 551, 551, 551, 551, 773, 773, 773, 773, 773] - - ... - - Row 50: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 216] - - Detailed Row 6 Analysis: - - step_matrix[5]: [ 975, 975, 975, 975, 975, 999, 999, 999, 999, 999, 999, ..., 999] - - step_index[5]: [ 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 0, ..., 0] - - step_update_mask[5]: [True,True,True,True,True,True,True,True,True,True,False, ...,False] - - valid_interval[5]: (0, 25) - - Key Pattern: Block i lags behind Block i-1 by exactly ar_step=5 timesteps, creating the - staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks. +``` +An example with these parameters: +base_num_frames=97, num_frames=97, num_inference_steps=30, ar_step=5, causal_block_size=5 + +vae_scale_factor_temporal -> 4 +num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each + +base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 → blocks = 25//5 = 5 blocks +This 5 blocks means the maximum context length of the model is 25 frames in the latent space. + +Asynchronous Processing Timeline: +┌─────────────────────────────────────────────────────────────────┐ +│ Steps: 1 6 11 16 21 26 31 36 41 46 50 │ +│ Block 1: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ +│ Block 2: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ +│ Block 3: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ +│ Block 4: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ +│ Block 5: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ +└─────────────────────────────────────────────────────────────────┘ + +For Long Videos (num_frames > base_num_frames): +base_num_frames acts as the "sliding window size" for processing long videos. 
+ +Example: 257-frame video with base_num_frames=97, overlap_history=17 +┌──── Iteration 1 (frames 1-97) ────┐ +│ Processing window: 97 frames │ → 5 blocks, +│ Generates: frames 1-97 │ async processing +└───────────────────────────────────┘ + ┌────── Iteration 2 (frames 81-177) ──────┐ + │ Processing window: 97 frames │ + │ Overlap: 17 frames (81-97) from prev │ → 5 blocks, + │ Generates: frames 98-177 │ async processing + └─────────────────────────────────────────┘ + ┌────── Iteration 3 (frames 161-257) ──────┐ + │ Processing window: 97 frames │ + │ Overlap: 17 frames (161-177) from prev │ → 5 blocks, + │ Generates: frames 178-257 │ async processing + └──────────────────────────────────────────┘ + +Each iteration independently runs the asynchronous processing with its own 5 blocks. +base_num_frames controls: +1. Memory usage (larger window = more VRAM) +2. Model context length (must match training constraints) +3. Number of blocks per iteration (base_num_latent_frames // causal_block_size) + +Each block takes 30 steps to complete denoising. +Block N starts at step: 1 + (N-1) x ar_step +Total steps: 30 + (5-1) x 5 = 50 steps + + +Synchronous mode (ar_step=0) would process all blocks/frames simultaneously: +┌──────────────────────────────────────────────┐ +│ Steps: 1 ... 30 │ +│ All blocks: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │ +└──────────────────────────────────────────────┘ +Total steps: 30 steps + + +An example on how the step matrix is constructed for asynchronous processing: +Given the parameters: (num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5) +- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25 +- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948, + 941, 932, 922, 912, 901, 888, 874, 859, 841, 822, + 799, 773, 743, 708, 666, 615, 551, 470, 363, 216] + +The algorithm creates a 50x25 step_matrix where: +- Row 1: [999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] +- Row 2: [995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] +- Row 3: [991, 991, 991, 991, 991, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] +- ... +- Row 7: [969, 969, 969, 969, 969, 995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] +- ... +- Row 21: [799, 799, 799, 799, 799, 888, 888, 888, 888, 888, 941, 941, 941, 941, 941, 975, 975, 975, 975, 975, 999, 999, 999, 999, 999] +- ... +- Row 35: [ 0, 0, 0, 0, 0, 216, 216, 216, 216, 216, 666, 666, 666, 666, 666, 822, 822, 822, 822, 822, 901, 901, 901, 901, 901] +- ... +- Row 42: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 551, 551, 551, 551, 551, 773, 773, 773, 773, 773] +- ... +- Row 50: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 216] + +Detailed Row 6 Analysis: +- step_matrix[5]: [ 975, 975, 975, 975, 975, 999, 999, 999, 999, 999, 999, ..., 999] +- step_index[5]: [ 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 0, ..., 0] +- step_update_mask[5]: [True,True,True,True,True,True,True,True,True,True,False, ...,False] +- valid_interval[5]: (0, 25) + +Key Pattern: Block i lags behind Block i-1 by exactly ar_step=5 timesteps, creating the +staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks. 
+``` ### Text-to-Video Generation From f09be25a722a459e919bb36531c8216d0cef36e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 18 Aug 2025 20:43:05 +0300 Subject: [PATCH 06/20] Docs: Condense example arrays in skyreels_v2 guide MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improves the readability of the `step_matrix` examples by replacing long sequences of repeated numbers with a more compact `value×count` notation. This change makes the underlying data patterns in the examples easier to understand at a glance. --- docs/source/en/api/pipelines/skyreels_v2.md | 24 ++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 9231a7425fa0..928a2e8f934c 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -110,25 +110,25 @@ Given the parameters: (num_inference_steps=30, flow_shift=8, num_frames=97, ar_s 799, 773, 743, 708, 666, 615, 551, 470, 363, 216] The algorithm creates a 50x25 step_matrix where: -- Row 1: [999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] -- Row 2: [995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] -- Row 3: [991, 991, 991, 991, 991, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] +- Row 1: [999×5, 999×5, 999×5, 999×5, 999×5] +- Row 2: [995×5, 999×5, 999×5, 999×5, 999×5] +- Row 3: [991×5, 999×5, 999×5, 999×5, 999×5] - ... -- Row 7: [969, 969, 969, 969, 969, 995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999] +- Row 7: [969×5, 995×5, 999×5, 999×5, 999×5] - ... -- Row 21: [799, 799, 799, 799, 799, 888, 888, 888, 888, 888, 941, 941, 941, 941, 941, 975, 975, 975, 975, 975, 999, 999, 999, 999, 999] +- Row 21: [799×5, 888×5, 941×5, 975×5, 999×5] - ... -- Row 35: [ 0, 0, 0, 0, 0, 216, 216, 216, 216, 216, 666, 666, 666, 666, 666, 822, 822, 822, 822, 822, 901, 901, 901, 901, 901] +- Row 35: [ 0×5, 216×5, 666×5, 822×5, 901×5] - ... -- Row 42: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 551, 551, 551, 551, 551, 773, 773, 773, 773, 773] +- Row 42: [ 0×5, 0×5, 0×5, 551×5, 773×5] - ... -- Row 50: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 216] +- Row 50: [ 0×5, 0×5, 0×5, 0×5, 216×5] Detailed Row 6 Analysis: -- step_matrix[5]: [ 975, 975, 975, 975, 975, 999, 999, 999, 999, 999, 999, ..., 999] -- step_index[5]: [ 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 0, ..., 0] -- step_update_mask[5]: [True,True,True,True,True,True,True,True,True,True,False, ...,False] -- valid_interval[5]: (0, 25) +- step_matrix[5]: [ 975×5, 999×5, 999×5, 999×5, 999×5] +- step_index[5]: [ 6×5, 1×5, 0×5, 0×5, 0×5] +- step_update_mask[5]: [True×5, True×5, False×5, False×5, False×5] +- valid_interval[5]: (0, 25) Key Pattern: Block i lags behind Block i-1 by exactly ar_step=5 timesteps, creating the staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks. 
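The staggered schedule described in these two documentation patches can be reproduced in a few lines. The snippet below is only an illustrative reconstruction under the parameters of the example (30 inference steps, 5 latent blocks of 5 frames each, ar_step=5); it is not the pipeline's internal implementation, but it yields exactly the rows listed above in the condensed `value×count` notation.

```py
# Illustrative reconstruction of the staggered step matrix described above.
# Parameters follow the documentation example; this is not the pipeline's code.
import torch

num_inference_steps, ar_step, causal_block_size, num_blocks = 30, 5, 5, 5
step_template = torch.tensor(
    [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
     941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
     799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
)

num_rows = num_inference_steps + (num_blocks - 1) * ar_step  # 30 + 20 = 50 rows
step_matrix = torch.zeros(num_rows, num_blocks * causal_block_size, dtype=torch.long)
update_mask = torch.zeros_like(step_matrix, dtype=torch.bool)

for row in range(num_rows):
    for block in range(num_blocks):
        steps_taken = row - block * ar_step  # each block starts ar_step rows after its predecessor
        if steps_taken < 0:
            value = step_template[0]          # block has not started: still at the initial noise level
        elif steps_taken >= num_inference_steps:
            value = 0                         # block is fully denoised
        else:
            value = step_template[steps_taken]
        cols = slice(block * causal_block_size, (block + 1) * causal_block_size)
        step_matrix[row, cols] = value
        update_mask[row, cols] = 0 <= steps_taken < num_inference_steps

print(step_matrix[5].tolist())   # Row 6:  [975×5, 999×5, 999×5, 999×5, 999×5]
print(step_matrix[20].tolist())  # Row 21: [799×5, 888×5, 941×5, 975×5, 999×5]
print(step_matrix[49].tolist())  # Row 50: [  0×5,   0×5,   0×5,   0×5, 216×5]
```

Reading down a column shows a single block walking through the full 30-step template; reading across a row shows later blocks sitting at noisier timesteps than earlier ones, which is the conditioning structure the "diffusion forcing" notes above refer to.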
From 6856ee64419c7a4a5518d4cb432e85c094dd5771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 19 Aug 2025 11:06:08 +0300 Subject: [PATCH 07/20] Add _repeated_blocks attribute to SkyReelsV2Transformer3DModel --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index a4895228c177..adb1e18e47fa 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -529,6 +529,7 @@ class SkyReelsV2Transformer3DModel( _no_split_modules = ["SkyReelsV2TransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["SkyReelsV2TransformerBlock"] @register_to_config def __init__( From a7e7b2fa01e1e25365a6c5fcb1e630a69f6c2572 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 19 Aug 2025 11:56:56 +0300 Subject: [PATCH 08/20] Refactor rotary embedding calculations in SkyReelsV2 to separate cosine and sine frequencies --- .../transformers/transformer_skyreels_v2.py | 83 +++++++++++++------ 1 file changed, 56 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index adb1e18e47fa..b60e7bf3b5d1 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -108,13 +108,21 @@ def __call__( if rotary_emb is not None: - def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): - x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2))) - x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) - return x_out.type_as(hidden_states) - - query = apply_rotary_emb(query, rotary_emb) - key = apply_rotary_emb(key, rotary_emb) + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) # I2V task hidden_states_img = None @@ -358,7 +366,11 @@ def forward( class SkyReelsV2RotaryPosEmbed(nn.Module): def __init__( - self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, ): super().__init__() @@ -368,35 +380,52 @@ def __init__( h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] - freqs = [] for dim in [t_dim, h_dim, w_dim]: - freq = get_1d_rotary_pos_embed( - dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32 + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, ) - freqs.append(freq) - 
self.freqs = torch.cat(freqs, dim=1) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - freqs = self.freqs.to(hidden_states.device) - freqs = freqs.split_with_sizes( - [ - self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), - self.attention_head_dim // 6, - self.attention_head_dim // 6, - ], - dim=1, - ) + split_sizes = [ + self.attention_head_dim - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 3, + self.attention_head_dim // 3, + ] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) - freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) - freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) - freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) - freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) - return freqs + return freqs_cos, freqs_sin @maybe_allow_in_graph From 07ac70d5d09b1816ff6449d8835409f1b17aa143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 19 Aug 2025 17:36:04 +0300 Subject: [PATCH 09/20] Enhance SkyReels-V2 documentation: update model loading for GPU support and remove outdated notes --- docs/source/en/api/pipelines/skyreels_v2.md | 78 ++++++--------------- 1 file changed, 22 insertions(+), 56 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 928a2e8f934c..24c219df6b6b 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -150,19 +150,30 @@ From the original repo: import torch from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler from diffusers.utils import export_to_video +# For faster loading into the GPU +os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes" -vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32) -transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) + +model_id = "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers" +vae = AutoModel.from_pretrained(model_id, + subfolder="vae", + torch_dtype=torch.float32, + device_map="cuda") pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - "Skywork/SkyReels-V2-DF-14B-540P-Diffusers", + model_id, 
vae=vae, - transformer=transformer, - torch_dtype=torch.bfloat16 + torch_dtype=torch.bfloat16, + device_map="cuda" ) flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift) -pipeline = pipeline.to("cuda") + +# Some acceleration helpers +# Be sure to install Flash Attention: https://github.com/Dao-AILab/flash-attention#installation-and-features +# Normally 14 min., with compile_repeated_blocks(fullgraph=True) 12 min., with Flash Attention too 5.5 min. +#pipeline.transformer.set_attention_backend("flash") +#pipeline.transformer.compile_repeated_blocks(fullgraph=True) prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." @@ -200,13 +211,12 @@ from diffusers.utils import export_to_video, load_image model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers" -vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32, device_map="cuda") pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( - model_id, vae=vae, torch_dtype=torch.bfloat16 + model_id, vae=vae, torch_dtype=torch.bfloat16, device_map="cuda" ) flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift) -pipeline.to("cuda") first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png") last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png") @@ -263,13 +273,12 @@ from diffusers.utils import export_to_video, load_video model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers" -vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32, device_map="cuda") pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained( - model_id, vae=vae, torch_dtype=torch.bfloat16 + model_id, vae=vae, torch_dtype=torch.bfloat16, device_map="cuda" ) flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift) -pipeline.to("cuda") video = load_video("input_video.mp4") @@ -286,50 +295,7 @@ export_to_video(output, "video.mp4", fps=24, quality=8) - -## Notes - -- SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`]. - -
- Show example code - - ```py - import torch - from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline - from diffusers.utils import export_to_video - - vae = AutoModel.from_pretrained( - "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32 - ) - pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, torch_dtype=torch.bfloat16 - ) - pipeline.to("cuda") - - pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie") - pipeline.set_adapters("steamboat-willie") - - pipeline.enable_model_cpu_offload() - - # use "steamboat willie style" to trigger the LoRA - prompt = """ - steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, - revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in - for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. - Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic - shadows and warm highlights. Medium composition, front view, low angle, with depth of field. - """ - - output = pipeline( - prompt=prompt, - num_frames=97, - guidance_scale=6.0, - ).frames[0] - export_to_video(output, "video.mp4", fps=24) - ``` - -
+`SkyReelsV2Pipeline` and `SkyReelsV2ImageToVideoPipeline` are also available without Diffusion Forcing framework applied. ## SkyReelsV2DiffusionForcingPipeline From 6e4cc723f0cdf74d4755b6b6cfb179ef7bd9a06a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 19 Aug 2025 17:44:41 +0300 Subject: [PATCH 10/20] up --- docs/source/en/api/pipelines/skyreels_v2.md | 2 +- src/diffusers/models/transformers/transformer_skyreels_v2.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 24c219df6b6b..ffe37854bb3d 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -171,7 +171,7 @@ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.conf # Some acceleration helpers # Be sure to install Flash Attention: https://github.com/Dao-AILab/flash-attention#installation-and-features -# Normally 14 min., with compile_repeated_blocks(fullgraph=True) 12 min., with Flash Attention too 5.5 min. +# Normally 14 min., with compile_repeated_blocks(fullgraph=True) 12 min., with Flash Attention too 5.5 min at A100. #pipeline.transformer.set_attention_backend("flash") #pipeline.transformer.compile_repeated_blocks(fullgraph=True) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index b60e7bf3b5d1..51e690ae0fa4 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -40,7 +40,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# TODO: Copied from doesn't work here? def _get_qkv_projections( attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor ): From dbe245479b14787b8d19e0f017e9c3c60f1889e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 19 Aug 2025 17:47:27 +0300 Subject: [PATCH 11/20] up --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 51e690ae0fa4..559b8b2b3dab 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -274,7 +274,7 @@ def forward( return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) -# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding +# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with Wan -> SkyReelsV2 class SkyReelsV2ImageEmbedding(torch.nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): super().__init__() @@ -363,6 +363,7 @@ def forward( return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image +# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed with Wan -> SkyReelsV2 class SkyReelsV2RotaryPosEmbed(nn.Module): def __init__( self, From 4743c7e09a3329a9cf2eb64dd991f254ad0d2c5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 19 Aug 2025 17:51:21 +0300 Subject: [PATCH 12/20] Update model_id in SkyReels-V2 documentation --- docs/source/en/api/pipelines/skyreels_v2.md | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index ffe37854bb3d..10a0f925b7dc 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -210,7 +210,7 @@ from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPi from diffusers.utils import export_to_video, load_image -model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers" +model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers" vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32, device_map="cuda") pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( model_id, vae=vae, torch_dtype=torch.bfloat16, device_map="cuda" @@ -272,7 +272,7 @@ from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPi from diffusers.utils import export_to_video, load_video -model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers" +model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers" vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32, device_map="cuda") pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained( model_id, vae=vae, torch_dtype=torch.bfloat16, device_map="cuda" From aaf247072206afdc9e3fdb2ab271a6fef0ab3344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 19 Aug 2025 18:54:04 +0300 Subject: [PATCH 13/20] up --- docs/source/en/api/pipelines/skyreels_v2.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 10a0f925b7dc..7f6e202c010c 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -44,7 +44,7 @@ The following SkyReels-V2 models are supported in Diffusers: ### A _Visual_ Demonstration -``` +```text An example with these parameters: base_num_frames=97, num_frames=97, num_inference_steps=30, ar_step=5, causal_block_size=5 From c88cb1630d37d6916a6ab0da6295c4510119b662 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 21 Aug 2025 17:44:47 +0300 Subject: [PATCH 14/20] refactor: remove device_map parameter for model loading and add pipeline.to("cuda") for GPU allocation --- docs/source/en/api/pipelines/skyreels_v2.md | 26 ++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 7f6e202c010c..a18226cdf57e 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -150,29 +150,27 @@ From the original repo: import torch from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler from diffusers.utils import export_to_video -# For faster loading into the GPU -os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes" model_id = "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers" -vae = AutoModel.from_pretrained(model_id, - subfolder="vae", - torch_dtype=torch.float32, - device_map="cuda") +vae = AutoModel.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( model_id, vae=vae, torch_dtype=torch.bfloat16, - device_map="cuda" ) +pipeline.to("cuda") flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V pipeline.scheduler = 
UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift) # Some acceleration helpers # Be sure to install Flash Attention: https://github.com/Dao-AILab/flash-attention#installation-and-features -# Normally 14 min., with compile_repeated_blocks(fullgraph=True) 12 min., with Flash Attention too 5.5 min at A100. -#pipeline.transformer.set_attention_backend("flash") +# Normally 14 min., with compile_repeated_blocks(fullgraph=True) 12 min., with Flash Attention too less min. at A100. +# If you want to follow the original implementation in terms of attentions: +#for block in pipeline.transformer.blocks: +# block.attn1.set_attention_backend("_native_cudnn") +# block.attn2.set_attention_backend("flash_varlen") # or "_flash_varlen_3" #pipeline.transformer.compile_repeated_blocks(fullgraph=True) prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." @@ -211,10 +209,11 @@ from diffusers.utils import export_to_video, load_image model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers" -vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32, device_map="cuda") +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( - model_id, vae=vae, torch_dtype=torch.bfloat16, device_map="cuda" + model_id, vae=vae, torch_dtype=torch.bfloat16 ) +pipeline.to("cuda") flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift) @@ -273,10 +272,11 @@ from diffusers.utils import export_to_video, load_video model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers" -vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32, device_map="cuda") +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained( - model_id, vae=vae, torch_dtype=torch.bfloat16, device_map="cuda" + model_id, vae=vae, torch_dtype=torch.bfloat16 ) +pipeline.to("cuda") flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift) From eed29536bbeb9103d443158efc1e2e6ba5b41e33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 21 Aug 2025 20:17:31 +0300 Subject: [PATCH 15/20] fix: update copyright year to 2025 in skyreels_v2.md --- docs/source/en/api/pipelines/skyreels_v2.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index a18226cdf57e..a55f2315221b 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -1,4 +1,4 @@ -