diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index c0fa031b9faf..f23e20b1d855 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -330,6 +330,31 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
+        # Convert encoder_hidden_states_mask to 2D attention mask if provided.
+        if encoder_hidden_states_mask is not None and attention_mask is None:
+            batch_size = hidden_states.shape[0]
+            image_seq_len = hidden_states.shape[1]
+            text_seq_len = encoder_hidden_states.shape[1]
+
+            if encoder_hidden_states_mask.shape[0] != batch_size:
+                raise ValueError(
+                    f"encoder_hidden_states_mask batch size ({encoder_hidden_states_mask.shape[0]}) "
+                    f"must match hidden_states batch size ({batch_size})"
+                )
+            if encoder_hidden_states_mask.shape[1] != text_seq_len:
+                raise ValueError(
+                    f"encoder_hidden_states_mask sequence length ({encoder_hidden_states_mask.shape[1]}) "
+                    f"must match encoder_hidden_states sequence length ({text_seq_len})"
+                )
+
+            text_attention_mask = encoder_hidden_states_mask.bool()
+            image_attention_mask = torch.ones(
+                (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
+            )
+
+            joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1)
+            attention_mask = joint_attention_mask_1d[:, None, None, :] * joint_attention_mask_1d[:, None, :, None]
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -630,7 +655,15 @@ def forward(
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
 
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # Use padded sequence length for RoPE when mask is present.
+        # The attention mask will handle excluding padding tokens.
+        if encoder_hidden_states_mask is not None:
+            txt_seq_lens_for_rope = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+        else:
+            txt_seq_lens_for_rope = (
+                txt_seq_lens if txt_seq_lens is not None else [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+            )
+        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens_for_rope, device=hidden_states.device)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index b24fa90503ef..352037aa0534 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -91,6 +91,124 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"QwenImageTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
+    def test_attention_mask_with_padding(self):
+        """Test that encoder_hidden_states_mask properly handles padded sequences."""
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device).eval()
+
+        batch_size = 2
+        height = width = 4
+        num_latent_channels = embedding_dim = 16
+        text_seq_len = 7
+        vae_scale_factor = 4
+
+        # Create inputs with padding
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)
+
+        # First sample: 5 real tokens, 2 padding
+        # Second sample: 3 real tokens, 4 padding
+        encoder_hidden_states_mask = torch.tensor(
+            [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
+        ).to(torch_device)
+
+        # Zero out padding in embeddings
+        encoder_hidden_states = encoder_hidden_states * encoder_hidden_states_mask.unsqueeze(-1).float()
+
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+        txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()
+
+        inputs_with_mask = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        # Run with proper mask
+        with torch.no_grad():
+            output_with_mask = model(**inputs_with_mask).sample
+
+        # Run with all-ones mask (treating padding as real tokens)
+        inputs_without_mask = {
+            "hidden_states": hidden_states.clone(),
+            "encoder_hidden_states": encoder_hidden_states.clone(),
+            "encoder_hidden_states_mask": torch.ones_like(encoder_hidden_states_mask),
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": [text_seq_len] * batch_size,
+        }
+
+        with torch.no_grad():
+            output_without_mask = model(**inputs_without_mask).sample
+
+        # Outputs should differ when mask is applied correctly
+        diff = (output_with_mask - output_without_mask).abs().mean().item()
+        assert diff > 1e-5, f"Mask appears to be ignored (diff={diff})"
+
+    def test_attention_mask_padding_isolation(self):
+        """Test that changing padding content doesn't affect output when mask is used."""
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device).eval()
+
+        batch_size = 2
+        height = width = 4
+        num_latent_channels = embedding_dim = 16
+        text_seq_len = 7
+        vae_scale_factor = 4
+
+        # Create inputs
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)
+        encoder_hidden_states_mask = torch.tensor(
+            [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
+        ).to(torch_device)
+
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+        txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()
+
+        inputs1 = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        with torch.no_grad():
+            output1 = model(**inputs1).sample
+
+        # Modify padding content with large noise
+        encoder_hidden_states2 = encoder_hidden_states.clone()
+        mask = encoder_hidden_states_mask.unsqueeze(-1).float()
+        noise = torch.randn_like(encoder_hidden_states2) * 10.0
+        encoder_hidden_states2 = encoder_hidden_states2 + noise * (1 - mask)
+
+        inputs2 = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states2,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        with torch.no_grad():
+            output2 = model(**inputs2).sample
+
+        # Outputs should be nearly identical (padding is masked out)
+        diff = (output1 - output2).abs().mean().item()
+        assert diff < 1e-4, f"Padding content affected output (diff={diff})"
+
 
 class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = QwenImageTransformer2DModel
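For readers outside the patch context, here is a minimal standalone sketch (not part of the patch) of the 1D-to-2D mask expansion that the new attention-processor code performs. It assumes only PyTorch; the mask values mirror the ones used in the tests above, and the sequence lengths are illustrative placeholders rather than real model dimensions.

```python
# Standalone sketch of the 1D -> 2D joint attention mask expansion.
import torch

batch_size, text_seq_len, image_seq_len = 2, 7, 16  # illustrative sizes

# 1 = real text token, 0 = padding, matching the masks in the tests above.
encoder_hidden_states_mask = torch.tensor(
    [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
)

text_attention_mask = encoder_hidden_states_mask.bool()
# Image tokens are never padded, so they are all attendable.
image_attention_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool)

# Concatenate in the same text-then-image order used by the joint attention.
joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1)

# Outer product per sample: position (i, j) is attendable only if both i and j are real tokens.
attention_mask = (
    joint_attention_mask_1d[:, None, None, :] * joint_attention_mask_1d[:, None, :, None]
)

print(attention_mask.shape)        # torch.Size([2, 1, 23, 23])
print(attention_mask[0, 0, 0, :7]) # padding columns of the first text row are False
```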