
Conversation

@remi-or (Collaborator) commented Jul 2, 2025

This PR adds the is_causal attribute to some Attention modules in mllama and disables FA2 for MllamaVisionModel.
Some tests used to fail when is_causal was missing (MllamaForCausalLMModelTest::test_flash_attn_2_fp32_ln, MllamaForConditionalGenerationModelTest::test_eager_matches_fa2_generate, ...), and once it was added, FA2 failed for these tests on both MI355 and A100. This is because the vision model uses a 4D attention mask, which is not supported by FA2.
I'm not an expert in VLMs, so could you please check that is_causal is right, @zucchini-nlp?

I also added Expectations for AMD MI355. After these changes, we go from 15 failed, 185 passed, 94 skipped to 196 passed, 98 skipped on AMD MI355.
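
For context, here is a minimal sketch of what the is_causal attribute looks like on an attention module. The class names and layer layout are illustrative, not the exact Mllama code; the point is that the attention backend (eager/SDPA/FA2) reads self.is_causal instead of guessing whether to apply a causal mask.

```python
import torch.nn as nn

# Illustrative sketch: the text decoder's self-attention is causal, while
# vision (and cross-) attention is bidirectional, so each module carries an
# explicit flag that the attention backend can consult.
class CausalSelfAttentionSketch(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        self.is_causal = True  # mask future tokens in the text decoder


class BidirectionalAttentionSketch(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        self.is_causal = False  # vision attention attends to all patches
```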

@remi-or requested a review from zucchini-nlp on July 2, 2025 at 16:33
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

github-actions bot (Contributor) commented Jul 2, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: mllama

class MllamaVisionModel(MllamaPreTrainedModel):
    config_class = MllamaVisionConfig
    base_model_prefix = "vision_model"
    _supports_flash_attn_2 = False  # the vision model always adds a 4D attn mask which is not supported by FA2
Member commented:

Nice catch! IMO we can still run FA2 with the vision module, but we need to prepare the mask correctly. In text models, for FA2 we usually keep the 2D mask and don't expand it to 4D.

We can do a similar thing in Mllama and skip the "reshape to 2D and create 4D attention mask" part in the FA2 case (see the sketch below). We might need to check the cross-attention as well, which also uses a 4D mask.
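
A rough sketch of that suggestion, assuming a hypothetical helper (the name and signature are not the actual Mllama code): keep the 2D padding mask when the config selects FA2, and only build the additive 4D mask for the eager/SDPA paths. The cross-attention mask would need the same treatment, since it is also built as a 4D mask.

```python
import torch

def prepare_vision_attention_mask(
    attention_mask: torch.Tensor,  # (batch, seq_len), 1 = keep, 0 = padding
    attn_implementation: str,
    dtype: torch.dtype,
) -> torch.Tensor:
    if attn_implementation == "flash_attention_2":
        # FA2 consumes the 2D padding mask directly (it un-pads internally),
        # so we skip the 4D expansion entirely.
        return attention_mask
    # Eager / SDPA expect a broadcastable additive mask of shape
    # (batch, 1, query_len, key_len) with a large negative value on padded keys.
    inverted = 1.0 - attention_mask[:, None, None, :].to(dtype)
    return inverted * torch.finfo(dtype).min
```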

class MllamaModel(MllamaPreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
    _supports_flash_attn_2 = False  # the vision model does not support FA2
Member commented:

_supports_flash_attn_2 = False should be defined only on the module that doesn't support it, i.e. on MllamaVisionModel. I guess the issue with the tests is that they don't check all submodules for _supports_flash_attn_2.

@remi-or (Collaborator, Author) commented:

Ok, I was wondering about that. I will check it out and revert the change.
