Fix fp32_ln for various models #41605
Merged
This PR fixes the test `test_flash_attn_2_fp32_ln` for several models:

- `bark` was failing the test because it calls `_flash_attention_forward` directly without checking the queries' dtype, so the test could fail when the dtype was `torch.float32`. To fix this, we refactored a code block into a function `get_target_dtype` that takes care of inferring whether to cast the fp32 tensor to fp16 or bf16, and added a call to it before the call to FA (a sketch of the idea is shown after this list).
- `stablelm`
- `mllama` was failing the test because `MllamaTextSelfAttention` lacks the `is_causal` attribute, which was added and set to `True` (it's a text attention, so it's causal, as discussed in Mllama fixes #39182).
- `kosmos2`, but the test still fails for many other reasons.
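For illustration, here is a minimal sketch of what a `get_target_dtype`-style helper could look like. The exact signature, the autocast/quantization checks, and the fallback to the module's own weight dtype are assumptions for this sketch, not necessarily the code merged in this PR.

```python
import torch


def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype:
    """Infer the dtype flash attention should run in when the queries are fp32.

    Flash attention only supports fp16/bf16, so fp32 activations (e.g. when
    layer norms are upcast) need to be cast down before the kernel call.
    Hypothetical helper: names and priority order are assumptions.
    """
    if query.dtype != torch.float32:
        # Already a dtype flash attention can handle.
        return query.dtype
    if torch.is_autocast_enabled():
        # Respect an active autocast context.
        return torch.get_autocast_gpu_dtype()
    config = getattr(module, "config", None)
    if config is not None and hasattr(config, "_pre_quantization_dtype"):
        # Quantized models keep the original (pre-quantization) dtype here.
        return config._pre_quantization_dtype
    # Fall back to the dtype of the module's own weights.
    return next(module.parameters()).dtype


# Usage sketch before calling _flash_attention_forward (variable names assumed):
# target_dtype = get_target_dtype(query_states, self)
# if query_states.dtype == torch.float32:
#     query_states = query_states.to(target_dtype)
#     key_states = key_states.to(target_dtype)
#     value_states = value_states.to(target_dtype)
```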
The list of fixed tests is here: