Fix qwen encoder hidden states mask #12655
base: main
Conversation
Improves attention mask handling for the QwenImage transformer by:
- Adding support for variable-length sequence masking
- Implementing dynamic attention mask generation from encoder_hidden_states_mask
- Ensuring RoPE embedding works correctly with padded sequences
- Adding comprehensive test coverage for masked input scenarios

Performance and flexibility benefits:
- Enables more efficient processing of sequences with padding
- Prevents padding tokens from contributing to attention computations
- Maintains model performance with minimal overhead
Improves the file naming convention for the Qwen image mask performance benchmark script. Enhances code organization by using a more descriptive and consistent filename that clearly indicates the script's purpose.
@cdutr it's great that you have also included the benchmarking script for full transparency. But we can remove it from this PR and instead have it as a GitHub gist. The benchmark numbers make sense to me. Some comments:
Also, I think a natural next step would be to see how well this performs when combined with FA varlen. WDYT? @naykun what do you think about the changes?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks @sayakpaul! I removed the benchmark script and moved all tests to this gist.

torch.compile test

Also tested the performance with torch.compile. Tested on NVIDIA A100 80GB PCIe, and also validated on RTX 4050 6GB (laptop) with similar results (2.38x speedup). The mask implementation is fully compatible with torch.compile.

Image outputs

Tested end-to-end image generation: successfully generated images using QwenImagePipeline, and the pipeline runs without errors. Here is the output generated:
FA Varlen

FA varlen is the natural next step, yes! I'm interested in working on it. Should I keep iterating in this PR, or should we merge it and create a new issue? The mask infrastructure this PR adds would translate well to varlen: instead of masking padding tokens, we'd pack only valid tokens using the same sequence length information.
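For reference, a compile smoke test along these lines exercises the masked attention path under torch.compile; the checkpoint id, dtype, and prompt below are assumptions rather than the exact settings used:

```python
import torch
from diffusers import QwenImagePipeline

# Hypothetical smoke test for torch.compile compatibility of the masked attention path.
# "Qwen/Qwen-Image" and bfloat16 are assumptions; substitute the checkpoint actually used.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)

# The first call triggers compilation; subsequent calls reflect steady-state speed.
image = pipe("A photo of a cat", num_inference_steps=20).images[0]
image.save("compiled_output.png")
```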
Thanks for the results! Looks quite nice.
I think it's fine to first merge this PR and then work on it afterwards. We're adding easier support for Sage and FA2 in #12439, so after that's merged, it will be quite easy to work on that (thanks to the …).

Could we also check if the outputs deviate with and without the masks, i.e., the outputs we get on main vs. with this PR?
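A rough sketch of such a check: since the mask is effectively ignored on main, running the transformer once with the real mask and once with an all-ones mask approximates a main-vs-PR comparison. Every argument name here is a placeholder, and the forward call is simplified (the real model takes additional arguments):

```python
import torch

def mask_parity_check(transformer, latents, prompt_embeds, prompt_embeds_mask, timestep):
    """Compare transformer outputs with the real text mask vs. an all-ones mask.

    All arguments are placeholders for whatever the pipeline produces; the forward
    call below is simplified and omits arguments the real transformer may require.
    """
    with torch.no_grad():
        out_masked = transformer(
            hidden_states=latents,
            encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_mask=prompt_embeds_mask,
            timestep=timestep,
            return_dict=False,
        )[0]
        out_all_ones = transformer(
            hidden_states=latents,
            encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_mask=torch.ones_like(prompt_embeds_mask),
            timestep=timestep,
            return_dict=False,
        )[0]
    # With padded prompts, a nonzero difference is expected: it is exactly the
    # contribution of the padding tokens that this PR removes.
    return (out_masked - out_all_ones).abs().max().item()
```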
@dxqb would you maybe be interested in checking this PR out as well?

What does this PR do?
Fixes the QwenImage encoder to properly apply encoder_hidden_states_mask when passed to the model. Previously, the mask parameter was accepted but ignored, causing padding tokens to incorrectly influence attention computation.

Changes

Updated QwenDoubleStreamAttnProcessor2_0 to create a 2D attention mask from the 1D encoder_hidden_states_mask, properly masking text padding tokens while keeping all image tokens unmasked (a sketch of this construction follows the Impact paragraph below).

Impact
This fix enables proper Classifier-Free Guidance (CFG) batching with variable-length text sequences, which is common when batching conditional and unconditional prompts together.
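For illustration only, a minimal sketch of that kind of mask construction, assuming text tokens precede image tokens in the joint sequence; the helper name below is hypothetical, not the PR's literal code:

```python
import torch

def build_joint_attention_mask(encoder_hidden_states_mask, image_seq_len):
    """Expand a 1D text-token mask into a 2D joint attention mask.

    encoder_hidden_states_mask: (batch, text_seq_len), 1 for real tokens, 0 for padding.
    Returns a boolean mask of shape (batch, 1, joint_len, joint_len) where True means
    "may attend", suitable for torch.nn.functional.scaled_dot_product_attention.
    Assumes text tokens come before image tokens in the joint sequence.
    """
    batch_size, _ = encoder_hidden_states_mask.shape
    # Image tokens are never padded, so they are always attendable.
    image_mask = encoder_hidden_states_mask.new_ones((batch_size, image_seq_len))
    joint_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1).bool()
    joint_len = joint_mask.shape[1]
    # Broadcast the key mask across all queries: no query may attend to a padded key.
    return joint_mask[:, None, None, :].expand(batch_size, 1, joint_len, joint_len)
```

With a boolean mask like this, keys marked False are excluded from the softmax, so padded text tokens can no longer influence either the text or the image stream.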
Benchmark Results
Overhead: +2.8% for mask processing without padding, +18.7% with actual padding (realistic CFG scenario)
The higher overhead with padding is expected and acceptable, as it represents the cost of properly handling variable-length sequences in batched inference. This is a necessary correctness fix rather than an optimization. Tests were run on an RTX 4070 12GB.
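The benchmark itself lives in the gist linked in the conversation above; a warmup-then-average timing loop along the following lines can reproduce this kind of measurement (the function name and iteration counts are illustrative, not the gist's actual code):

```python
import time
import torch

def average_forward_time(transformer, inputs, n_iters=20, warmup=5):
    """Average wall-clock time of a transformer forward pass; `inputs` is a kwargs dict."""
    with torch.no_grad():
        for _ in range(warmup):
            transformer(**inputs)  # warm up kernels and any autotuning
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            transformer(**inputs)
        torch.cuda.synchronize()  # make sure all queued GPU work is counted
    return (time.perf_counter() - start) / n_iters

# Overhead is then the relative difference between a run with a padded
# encoder_hidden_states_mask and a run with the mask left out, on otherwise identical inputs.
```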
Fixes #12294
Before submitting
Who can review?
@yiyixuxu @sayakpaul - Would appreciate your review, especially regarding the benchmarking approach. I used a custom benchmark rather than BenchmarkMixin because:

Note: The benchmark file is named benchmarking_qwenimage_mask.py (with a "benchmarking" prefix) rather than benchmark_qwenimage_mask.py to prevent it from being picked up by run_all.py, since it doesn't use BenchmarkMixin and produces a different CSV schema. If you prefer, I can adapt it to use the standard format instead.

Happy to adjust the approach if you have suggestions!