35 changes: 34 additions & 1 deletion src/diffusers/models/transformers/transformer_qwenimage.py
@@ -330,6 +330,31 @@ def __call__(
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        # Convert encoder_hidden_states_mask to 2D attention mask if provided.
        if encoder_hidden_states_mask is not None and attention_mask is None:
            batch_size = hidden_states.shape[0]
            image_seq_len = hidden_states.shape[1]
            text_seq_len = encoder_hidden_states.shape[1]

            if encoder_hidden_states_mask.shape[0] != batch_size:
                raise ValueError(
                    f"encoder_hidden_states_mask batch size ({encoder_hidden_states_mask.shape[0]}) "
                    f"must match hidden_states batch size ({batch_size})"
                )
            if encoder_hidden_states_mask.shape[1] != text_seq_len:
                raise ValueError(
                    f"encoder_hidden_states_mask sequence length ({encoder_hidden_states_mask.shape[1]}) "
                    f"must match encoder_hidden_states sequence length ({text_seq_len})"
                )

            text_attention_mask = encoder_hidden_states_mask.bool()
            image_attention_mask = torch.ones(
                (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
            )

            joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1)
            attention_mask = joint_attention_mask_1d[:, None, None, :] * joint_attention_mask_1d[:, None, :, None]

        # Compute joint attention
        joint_hidden_states = dispatch_attention_fn(
            joint_query,
@@ -630,7 +655,15 @@ def forward(
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
        # Use padded sequence length for RoPE when mask is present.
        # The attention mask will handle excluding padding tokens.
        if encoder_hidden_states_mask is not None:
            txt_seq_lens_for_rope = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
        else:
            txt_seq_lens_for_rope = (
                txt_seq_lens if txt_seq_lens is not None else [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
            )
        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens_for_rope, device=hidden_states.device)

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
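For reference, a minimal standalone sketch of the two changes above, outside the diffusers code path. The first part rebuilds the 1D-to-pairwise mask broadcast from the attention processor and checks, using torch.nn.functional.scaled_dot_product_attention as a stand-in for whatever backend dispatch_attention_fn selects (an assumption for illustration, not the actual diffusers call), that padded text keys receive exactly zero attention weight, so perturbing their content leaves every non-padded row unchanged. The second part shows why passing the padded length to RoPE is safe once padding is masked: rotary frequencies depend only on the absolute position index, so a longer table agrees with a shorter one on the shared positions. rope_freqs is a generic rotary-frequency table written only for this sketch, not QwenEmbedRope.

import torch
import torch.nn.functional as F

# --- 1D joint mask -> pairwise (query, key) mask, as in the processor above ---
batch, heads, head_dim = 1, 2, 8
text_len, image_len = 4, 6                                    # last 2 text tokens are padding
text_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool)
image_mask = torch.ones((batch, image_len), dtype=torch.bool)

joint_mask_1d = torch.cat([text_mask, image_mask], dim=1)                       # (B, S)
attn_mask = joint_mask_1d[:, None, None, :] * joint_mask_1d[:, None, :, None]   # (B, 1, S, S)

seq_len = text_len + image_len
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)
out_a = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Perturb key/value content at the padded text positions only.
k2, v2 = k.clone(), v.clone()
k2[:, :, 2:4] += 100.0
v2[:, :, 2:4] += 100.0
out_b = F.scaled_dot_product_attention(q, k2, v2, attn_mask=attn_mask)

# Non-padded rows are unaffected; the fully masked rows of the padded text
# queries come back as NaN and are excluded from the comparison.
valid = joint_mask_1d[0]
assert torch.allclose(out_a[:, :, valid], out_b[:, :, valid])

# --- why the padded length is safe for RoPE once padding is masked ---
def rope_freqs(n_pos, dim, theta=10000.0):
    inv = 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim)
    return torch.outer(torch.arange(n_pos).float(), inv)

# The first 5 rows are identical whether the table covers 5 or 7 positions.
assert torch.equal(rope_freqs(7, 16)[:5], rope_freqs(5, 16))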
118 changes: 118 additions & 0 deletions tests/models/transformers/test_models_transformer_qwenimage.py
@@ -91,6 +91,124 @@ def test_gradient_checkpointing_is_applied(self):
        expected_set = {"QwenImageTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_attention_mask_with_padding(self):
        """Test that encoder_hidden_states_mask properly handles padded sequences."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device).eval()

        batch_size = 2
        height = width = 4
        num_latent_channels = embedding_dim = 16
        text_seq_len = 7
        vae_scale_factor = 4

        # Create inputs with padding
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)

        # First sample: 5 real tokens, 2 padding
        # Second sample: 3 real tokens, 4 padding
        encoder_hidden_states_mask = torch.tensor(
            [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
        ).to(torch_device)

        # Zero out padding in embeddings
        encoder_hidden_states = encoder_hidden_states * encoder_hidden_states_mask.unsqueeze(-1).float()

        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
        orig_height = height * 2 * vae_scale_factor
        orig_width = width * 2 * vae_scale_factor
        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
        txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()

        inputs_with_mask = {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "timestep": timestep,
            "img_shapes": img_shapes,
            "txt_seq_lens": txt_seq_lens,
        }

        # Run with proper mask
        with torch.no_grad():
            output_with_mask = model(**inputs_with_mask).sample

        # Run with all-ones mask (treating padding as real tokens)
        inputs_without_mask = {
            "hidden_states": hidden_states.clone(),
            "encoder_hidden_states": encoder_hidden_states.clone(),
            "encoder_hidden_states_mask": torch.ones_like(encoder_hidden_states_mask),
            "timestep": timestep,
            "img_shapes": img_shapes,
            "txt_seq_lens": [text_seq_len] * batch_size,
        }

        with torch.no_grad():
            output_without_mask = model(**inputs_without_mask).sample

        # Outputs should differ when mask is applied correctly
        diff = (output_with_mask - output_without_mask).abs().mean().item()
        assert diff > 1e-5, f"Mask appears to be ignored (diff={diff})"

    def test_attention_mask_padding_isolation(self):
        """Test that changing padding content doesn't affect output when mask is used."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device).eval()

        batch_size = 2
        height = width = 4
        num_latent_channels = embedding_dim = 16
        text_seq_len = 7
        vae_scale_factor = 4

        # Create inputs
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)
        encoder_hidden_states_mask = torch.tensor(
            [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
        ).to(torch_device)

        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
        orig_height = height * 2 * vae_scale_factor
        orig_width = width * 2 * vae_scale_factor
        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
        txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()

        inputs1 = {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "timestep": timestep,
            "img_shapes": img_shapes,
            "txt_seq_lens": txt_seq_lens,
        }

        with torch.no_grad():
            output1 = model(**inputs1).sample

        # Modify padding content with large noise
        encoder_hidden_states2 = encoder_hidden_states.clone()
        mask = encoder_hidden_states_mask.unsqueeze(-1).float()
        noise = torch.randn_like(encoder_hidden_states2) * 10.0
        encoder_hidden_states2 = encoder_hidden_states2 + noise * (1 - mask)

        inputs2 = {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states2,
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "timestep": timestep,
            "img_shapes": img_shapes,
            "txt_seq_lens": txt_seq_lens,
        }

        with torch.no_grad():
            output2 = model(**inputs2).sample

        # Outputs should be nearly identical (padding is masked out)
        diff = (output1 - output2).abs().mean().item()
        assert diff < 1e-4, f"Padding content affected output (diff={diff})"


class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = QwenImageTransformer2DModel