
Propose to refactor output normalization in several transformers #11850

Draft · wants to merge 11 commits into base: main
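All five diffs below apply the same refactor: the hand-rolled output modulation (a bare `LayerNorm` plus a learned `scale_shift_table` parameter) is replaced by the shared `AdaLayerNorm` module, and a `_load_from_state_dict` hook remaps the old checkpoint key onto `norm_out.linear` so existing weights keep loading. The sketch below is illustrative only, not the exact code of any one model; `inner_dim`, the random tensors, and the toy shapes are assumptions made for the example.

```python
import torch
import torch.nn as nn
from diffusers.models.normalization import AdaLayerNorm

inner_dim = 1152  # hypothetical width, chosen only for the example
hidden_states = torch.randn(2, 256, inner_dim)   # (batch, tokens, dim)
embedded_timestep = torch.randn(2, inner_dim)    # (batch, dim)

# Before: bare LayerNorm plus a learned (2, inner_dim) shift/scale table.
norm = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
scale_shift_table = torch.randn(2, inner_dim) / inner_dim**0.5

shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
out_old = norm(hidden_states) * (1 + scale) + shift

# After: AdaLayerNorm owns both the LayerNorm and the timestep-conditioned modulation.
norm_out = AdaLayerNorm(
    embedding_dim=inner_dim,
    output_dim=2 * inner_dim,
    norm_elementwise_affine=False,
    norm_eps=1e-6,
    chunk_dim=1,
)
out_new = norm_out(hidden_states, temb=embedded_timestep)  # same (2, 256, inner_dim) shape
```

The two paths are not numerically identical with freshly initialized weights; checkpoint compatibility is handled by the `_load_from_state_dict` remapping shown in each file.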
28 changes: 21 additions & 7 deletions src/diffusers/models/transformers/latte_transformer_3d.py
@@ -23,11 +23,12 @@
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
from ..normalization import AdaLayerNorm, AdaLayerNormSingle


class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["norm_out"]

"""
A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
@@ -149,8 +150,13 @@ def __init__(

# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.norm_out = AdaLayerNorm(
embedding_dim=inner_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=False,
norm_eps=1e-6,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

# 5. Latte other blocks.
@@ -165,6 +171,17 @@ def __init__(

self.gradient_checkpointing = False

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
if prefix + "scale_shift_table" in state_dict:
scale_shift_table = state_dict.pop(prefix + "scale_shift_table")
state_dict[prefix + "norm_out.linear.weight"] = torch.eye(scale_shift_table.shape[-1]).repeat(2, 1)
state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table.flatten()
return super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def forward(
self,
hidden_states: torch.Tensor,
@@ -305,10 +322,7 @@ def forward(
embedded_timestep = embedded_timestep.repeat_interleave(
num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
).view(-1, embedded_timestep.shape[-1])
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
hidden_states = self.proj_out(hidden_states)

# unpatchify
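The `_load_from_state_dict` hooks rebuild `norm_out.linear` from the old table: a twice-stacked identity weight plus the flattened table as bias reproduces the old `scale_shift_table + temb` broadcast before chunking into shift and scale. A small worked check of that construction (values only; it deliberately ignores any activation `AdaLayerNorm` applies to `temb` internally, and the toy `inner_dim` is an assumption for the example):

```python
import torch

inner_dim = 8                                   # tiny toy width
scale_shift_table = torch.randn(2, inner_dim)   # row 0: shift, row 1: scale
temb = torch.randn(3, inner_dim)                # (batch, inner_dim)

# Old formulation: broadcast-add the table to the timestep embedding, then chunk.
shift_old, scale_old = (scale_shift_table[None] + temb[:, None]).chunk(2, dim=1)

# New formulation: a Linear whose weight is two stacked identities and whose
# bias is the flattened table yields the same two halves.
weight = torch.eye(inner_dim).repeat(2, 1)               # (2 * inner_dim, inner_dim)
bias = scale_shift_table.reshape(2, inner_dim).flatten()
out = temb @ weight.T + bias                             # what nn.Linear computes
shift_new, scale_new = out.chunk(2, dim=1)

torch.testing.assert_close(shift_old.squeeze(1), shift_new)
torch.testing.assert_close(scale_old.squeeze(1), scale_new)
```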
31 changes: 21 additions & 10 deletions src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -23,7 +23,7 @@
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
from ..normalization import AdaLayerNorm, AdaLayerNormSingle


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -78,7 +78,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed", "norm_out"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]

@register_to_config
@@ -171,8 +171,13 @@ def __init__(
)

# 3. Output blocks.
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.norm_out = AdaLayerNorm(
embedding_dim=self.inner_dim,
output_dim=2 * self.inner_dim,
norm_elementwise_affine=False,
norm_eps=1e-6,
chunk_dim=1,
)
self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)

self.adaln_single = AdaLayerNormSingle(
@@ -184,6 +189,17 @@
in_features=self.config.caption_channels, hidden_size=self.inner_dim
)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
if prefix + "scale_shift_table" in state_dict:
scale_shift_table = state_dict.pop(prefix + "scale_shift_table")
state_dict[prefix + "norm_out.linear.weight"] = torch.eye(scale_shift_table.shape[-1]).repeat(2, 1)
state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table.flatten()
return super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -406,12 +422,7 @@ def forward(
)

# 3. Output
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)

29 changes: 21 additions & 8 deletions src/diffusers/models/transformers/transformer_allegro.py
@@ -28,7 +28,7 @@
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
from ..normalization import AdaLayerNorm, AdaLayerNormSingle


logger = logging.get_logger(__name__)
@@ -175,6 +175,7 @@ def forward(

class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["norm_out"]

"""
A 3D Transformer model for video-like data.
@@ -292,8 +293,13 @@ def __init__(
)

# 3. Output projection & norm
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.norm_out = AdaLayerNorm(
embedding_dim=self.inner_dim,
output_dim=2 * self.inner_dim,
norm_elementwise_affine=False,
norm_eps=1e-6,
chunk_dim=1,
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)

# 4. Timestep embeddings
@@ -304,6 +310,17 @@

self.gradient_checkpointing = False

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
if prefix + "scale_shift_table" in state_dict:
scale_shift_table = state_dict.pop(prefix + "scale_shift_table")
state_dict[prefix + "norm_out.linear.weight"] = torch.eye(scale_shift_table.shape[-1]).repeat(2, 1)
state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table.flatten()
return super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def forward(
self,
hidden_states: torch.Tensor,
@@ -393,11 +410,7 @@ def forward(
)

# 4. Output normalization & projection
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)

# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)

41 changes: 33 additions & 8 deletions src/diffusers/models/transformers/transformer_ltx.py
@@ -30,7 +30,7 @@
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from ..normalization import AdaLayerNorm, AdaLayerNormSingle, RMSNorm


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_no_split_modules = ["norm_out"]
_repeated_blocks = ["LTXVideoTransformerBlock"]

@register_to_config
@@ -356,7 +357,6 @@ def __init__(

self.proj_in = nn.Linear(in_channels, inner_dim)

self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)

self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
@@ -389,11 +389,40 @@ def __init__(
]
)

self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.norm_out = AdaLayerNorm(
embedding_dim=inner_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=False,
norm_eps=1e-6,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, out_channels)

self.gradient_checkpointing = False

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
key = "scale_shift_table"
if prefix + key in state_dict:
scale_shift_table = state_dict.pop(prefix + key)
inner_dim = scale_shift_table.shape[-1]

weight = torch.eye(inner_dim).repeat(2, 1)
bias = scale_shift_table.reshape(2, inner_dim).flatten()

state_dict[prefix + "norm_out.linear.weight"] = weight
state_dict[prefix + "norm_out.linear.bias"] = bias

if prefix + "norm_out.weight" in state_dict:
state_dict.pop(prefix + "norm_out.weight")
if prefix + "norm_out.bias" in state_dict:
state_dict.pop(prefix + "norm_out.bias")

return super(LTXVideoTransformer3DModel, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def forward(
self,
hidden_states: torch.Tensor,
@@ -464,11 +493,7 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
)

scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.norm_out(hidden_states, temb=embedded_timestep.squeeze(1))
output = self.proj_out(hidden_states)

if USE_PEFT_BACKEND:
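Because the refactor renames parameters, a quick backward-compatibility check is to load a pre-refactor checkpoint and confirm the keys were remapped. A hedged sketch; the Hub id, subfolder, and dtype are assumptions for illustration, not part of this PR:

```python
import torch
from diffusers import LTXVideoTransformer3DModel

# Assumed checkpoint location; any pre-refactor LTX-Video transformer behaves the same way.
model = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)

# The old additive table is gone; its values are folded into norm_out.linear at load time.
assert not hasattr(model, "scale_shift_table")
print(model.norm_out.linear.weight.shape)  # expected: (2 * inner_dim, inner_dim)
```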
47 changes: 33 additions & 14 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -29,7 +29,7 @@
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from ..normalization import AdaLayerNorm, FP32LayerNorm


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -372,7 +372,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["WanTransformerBlock"]
_no_split_modules = ["WanTransformerBlock", "norm_out"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
@@ -428,12 +428,40 @@ def __init__(
)

# 4. Output norm & projection
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.norm_out = AdaLayerNorm(
embedding_dim=inner_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=False,
norm_eps=eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

self.gradient_checkpointing = False

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
key = "scale_shift_table"
if prefix + key in state_dict:
scale_shift_table = state_dict.pop(prefix + key)
inner_dim = scale_shift_table.shape[-1]

weight = torch.eye(inner_dim).repeat(2, 1)
bias = scale_shift_table.reshape(2, inner_dim).flatten()

state_dict[prefix + "norm_out.linear.weight"] = weight
state_dict[prefix + "norm_out.linear.bias"] = bias

if prefix + "norm_out.weight" in state_dict:
state_dict.pop(prefix + "norm_out.weight")
if prefix + "norm_out.bias" in state_dict:
state_dict.pop(prefix + "norm_out.bias")

return super(WanTransformer3DModel, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def forward(
self,
hidden_states: torch.Tensor,
@@ -488,16 +516,7 @@ def forward(
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)

hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(
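For Wan, the old forward moved `shift`/`scale` to `hidden_states.device` by hand to survive sharded multi-GPU inference; adding `"norm_out"` to `_no_split_modules` presumably makes that unnecessary, since the whole `AdaLayerNorm` (norm plus linear) then stays on a single device and accelerate's hooks move its inputs there. A hedged sketch of the sharded-loading path this targets; the Hub id and `device_map` usage are assumptions about the user's setup, not part of this diff:

```python
from diffusers import WanTransformer3DModel

# Shard the transformer across available devices; "norm_out" is kept whole on one device.
model = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", device_map="auto"
)
print({name: str(p.device) for name, p in model.norm_out.named_parameters()})
```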