Mllama fixes #39182
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
[For maintainers] Suggested jobs to run (before merge): run-slow: mllama
```python
class MllamaVisionModel(MllamaPreTrainedModel):
    config_class = MllamaVisionConfig
    base_model_prefix = "vision_model"
    _supports_flash_attn_2 = False  # the vision model always adds a 4D attn mask which is not supported by FA2
```
Nice catch! IMO we can still run FA2 with the vision module, but we need to prepare the mask correctly. In text models, for FA2 we usually keep the 2D mask and don't expand it to 4D. We can do a similar thing in Mllama and skip the "Reshape to 2D and create 4D attention mask" part in the FA2 case. We might need to check cross attention as well, which also uses a 4D mask.
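For illustration, a minimal sketch of that mask-preparation branch as a standalone helper; the function name, signature, and placement are hypothetical rather than the actual Mllama code:

```python
import torch


def prepare_vision_attention_mask(
    attention_mask: torch.Tensor,  # 2D padding mask of shape (batch, seq_len); 1 = keep, 0 = pad
    dtype: torch.dtype,
    attn_implementation: str,
) -> torch.Tensor:
    # Hypothetical helper, not the actual Mllama implementation.
    if attn_implementation == "flash_attention_2":
        # FA2 consumes the unexpanded 2D padding mask directly, so skip the 4D expansion.
        return attention_mask

    # Eager / SDPA: build a broadcastable additive 4D mask of shape (batch, 1, 1, seq_len),
    # where padded positions get the most negative representable value.
    inverted = 1.0 - attention_mask[:, None, None, :].to(dtype)
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)
```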
```python
class MllamaModel(MllamaPreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
    _supports_flash_attn_2 = False  # the vision model does not support FA2
```
_supports_flash_attn_2 = False should be defined only on the module that doesn't support it, i.e. on MllamaVisionModel. I guess the issue with the tests is that they don't check all submodules for _supports_flash_attn_2.
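As a rough sketch of what the test side could do under that reading, where FA2 support is the conjunction over all submodules; model_supports_fa2 is a hypothetical helper, not an existing transformers utility:

```python
from transformers import PreTrainedModel


def model_supports_fa2(model: PreTrainedModel) -> bool:
    # Hypothetical check: FA2 is usable only if every PreTrainedModel submodule allows it.
    # Defaulting to True for modules that don't define the flag is a sketch assumption.
    return all(
        getattr(submodule, "_supports_flash_attn_2", True)
        for submodule in model.modules()
        if isinstance(submodule, PreTrainedModel)
    )
```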
Ok, I was wondering about that. I will check it out and revert the change
This PR adds the is_causal attribute to some Attention modules in Mllama and disables FA2 for the MllamaVisionModel.

Some tests used to fail when is_causal was missing: MllamaForCausalLMModelTest::test_flash_attn_2_fp32_ln, MllamaForConditionalGenerationModelTest::test_eager_matches_fa2_generate, ... Once it was added, FA2 failed for these tests on both MI355 and A100, because the vision model uses a 4D attn mask which is not supported by FA2. I'm not an expert in VLMs, so could you please check that is_causal is right, @zucchini-nlp?

I also added Expectations for AMD MI355. After these changes, we go from 15 failed, 185 passed, 94 skipped to 196 passed, 98 skipped on AMD MI355.
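For context, a minimal sketch of the kind of attribute involved, on an illustrative SDPA-style attention module; the class name, wiring, and shapes are assumptions, not the exact Mllama code:

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class VisionSelfAttention(nn.Module):
    """Illustrative self-attention block; vision self-attention attends bidirectionally."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.is_causal = False  # read by the SDPA/FA2 dispatch; vision tokens are not causally masked
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        bsz, seq_len, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # is_causal is forwarded to the attention kernel instead of being hard-coded per call site.
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, is_causal=self.is_causal)
        attn = attn.transpose(1, 2).reshape(bsz, seq_len, -1)
        return self.o_proj(attn)
```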