5 changes: 5 additions & 0 deletions src/transformers/models/mllama/modeling_mllama.py
@@ -225,6 +225,7 @@ def __init__(self, config: MllamaVisionConfig):
self.head_dim = config.hidden_size // config.attention_heads
self.scaling = self.head_dim**-0.5
self.num_key_value_groups = 1
self.is_causal = False

self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
@@ -584,6 +585,7 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int):
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.layer_idx = layer_idx
self.is_causal = True

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@@ -1028,6 +1030,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
class MllamaVisionModel(MllamaPreTrainedModel):
config_class = MllamaVisionConfig
base_model_prefix = "vision_model"
_supports_flash_attn_2 = False # the vision model always adds a 4D attn mask which is not supported by FA2
Member
Nice catch! IMO we can still run FA2 with the vision module, but we need to prepare the mask correctly. In text models, for FA2 we usually keep the 2D mask and don't expand it to 4D.

We can do a similar thing in Mllama and skip the "reshape to 2D and create 4D attention mask" part in the FA2 case. We might need to check cross attention as well, which also uses a 4D mask.
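A rough sketch of that suggestion (hypothetical helper, not the actual Mllama code): branch on the configured attention implementation, keep the 2D padding mask for FA2, and only build the 4D additive mask for the eager/SDPA paths.

```python
# Hypothetical sketch of the reviewer's suggestion, not the actual Mllama implementation.
import torch


def prepare_vision_attention_mask(attention_mask, dtype, attn_implementation):
    """attention_mask: (batch, seq_len) with 1 for visible tokens and 0 for padding."""
    if attn_implementation == "flash_attention_2":
        # FA2 consumes the unexpanded 2D padding mask (or None when nothing is masked).
        return attention_mask

    # Eager / SDPA: expand to a 4D additive mask of shape (batch, 1, seq_len, seq_len),
    # with a large negative value at masked key positions.
    batch_size, seq_len = attention_mask.shape
    inverted = 1.0 - attention_mask[:, None, None, :].to(dtype)  # (batch, 1, 1, seq_len)
    mask_4d = inverted * torch.finfo(dtype).min
    return mask_4d.expand(batch_size, 1, seq_len, seq_len)
```

The cross-attention mask mentioned above would need the same kind of branch, since it is also built as a 4D mask today.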


def __init__(self, config: MllamaVisionConfig):
super().__init__(config)
@@ -1617,6 +1620,7 @@ def forward(
class MllamaModel(MllamaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
_supports_quantized_cache = False # quant cache not supported in encoder-decoder setting
_supports_flash_attn_2 = False # the vision model does not support FA2
Member
_supports_flash_attn_2 = False should be defined only on the module that doesn't support it, i.e. on MllamaVisionModel. I guess the issue with the tests is that they don't check all submodules for _supports_flash_attn_2.
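One way to read that suggestion is that a composite-model check could walk the submodules instead of duplicating the flag on the wrapper classes; a hedged sketch with illustrative names, not the actual test code:

```python
# Illustrative sketch only: check every submodule's _supports_flash_attn_2 flag
# rather than redefining it on the composite model.
import torch.nn as nn


def all_submodules_support_fa2(model: nn.Module) -> bool:
    for module in model.modules():
        # Modules that don't define the flag are treated as supporting FA2.
        if getattr(module, "_supports_flash_attn_2", True) is False:
            return False
    return True
```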

Collaborator Author
Ok, I was wondering about that. I will check it out and revert the change


def __init__(self, config: MllamaConfig):
super().__init__(config)
@@ -1778,6 +1782,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
}
_supports_quantized_cache = False # quant cache not supported in encoder-decoder setting
_tied_weights_keys = ["lm_head.weight"]
_supports_flash_attn_2 = False # the vision model does not support FA2

def __init__(self, config: MllamaConfig):
super().__init__(config)
13 changes: 11 additions & 2 deletions tests/models/mllama/test_modeling_mllama.py
@@ -535,6 +535,7 @@ def test_11b_model_integration_generate(self):
("xpu", 3): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
("cuda", 7): "If I had to write a haiku for this one, it would be:.\\nA dock in the lake.\\nA mountain in the distance.\\nA long exposure.",
("cuda", 8): 'If I had to write a haiku for this one, it would be:.\\nA dock in the lake.\\nA mountain in the distance.\\nA long exposure.',
("rocm", (9, 5)): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
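For readers unfamiliar with the pattern used throughout these tests: the expected values are keyed by a device type plus a version key (a compute-capability major version, a (major, minor) tuple, or None as a wildcard), and get_expectation() returns the entry matching the current hardware. A simplified, self-contained illustration of that lookup, not the real transformers.testing_utils.Expectations class:

```python
# Simplified stand-in for the device-keyed expectation lookup used in these tests.
import torch


class SimpleExpectations(dict):
    def get_expectation(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        major = torch.cuda.get_device_capability()[0] if device == "cuda" else None
        # Prefer an exact (device, major) match, then fall back to a (device, None) wildcard.
        return self.get((device, major), self.get((device, None)))


expected = SimpleExpectations({
    ("cuda", 7): "string expected on compute capability 7.x",
    ("cuda", 8): "string expected on compute capability 8.x",
    ("cuda", None): "fallback for any other CUDA device",
}).get_expectation()
```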
@@ -582,6 +583,7 @@ def test_11b_model_integration_generate_text_only(self):
("xpu", 3): "If I had to write a haiku about my life, I would write:\nLife is a messy tapestry\n Threads of joy and sorrow\nWeft of memories",
("cuda", 7): "If I had to write a haiku about my life, I would write:\nLife is a messy stream\nRipples of joy and pain\nFlowing, ever",
("cuda", 8): "If I had to write a haiku about my life, I would write:\nLife is a messy stream\nRipples of joy and pain\nFlowing, ever",
("rocm", (9, 5)): "If I had to write a haiku about my cat, I would write:\nWhiskers twitching bright\nMoonlight dancing on her fur\nFurry little",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
@@ -621,6 +623,8 @@ def test_11b_model_integration_forward(self):
("xpu", 3): torch.tensor([9.1562, 8.9141, 5.0664, 1.6855, 3.2324], dtype=actual_logits.dtype),
("cuda", 7): torch.tensor([9.0781, 8.8750, 5.0781, 1.6221, 3.2207], dtype=actual_logits.dtype),
("cuda", 8): torch.tensor([9.0703, 8.8750, 5.0781, 1.6279, 3.2207], dtype=actual_logits.dtype),
# NOTE: rocm logits are quite a bit off, we should investigate. Generation makes sense though.
("rocm", (9, 5)): torch.tensor([9.3359, 9.1641, 5.3867, 2.2090, 3.3379], dtype=actual_logits.dtype),
}
)
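To make "quite a bit off" concrete: against the CUDA 8.x reference values above, the ROCm logits differ by up to roughly 0.58 elementwise, so a direct comparison only passes with a fairly wide absolute tolerance. A standalone illustration, with the values copied from the table above:

```python
# Quantifying the ROCm logit drift against the CUDA 8.x reference values listed above.
import torch

rocm_logits = torch.tensor([9.3359, 9.1641, 5.3867, 2.2090, 3.3379])
cuda8_logits = torch.tensor([9.0703, 8.8750, 5.0781, 1.6279, 3.2207])

print((rocm_logits - cuda8_logits).abs().max())  # ~0.58
torch.testing.assert_close(rocm_logits, cuda8_logits, rtol=0.0, atol=0.6)  # needs a wide atol to pass
```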

@@ -666,6 +670,7 @@ def test_11b_model_integration_batched_generate(self):
("xpu", 3): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
("cuda", 7): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
("cuda", 8): 'If I had to write a haiku for this one, it would be:.\\nA dock in the lake.\\nA mountain in the distance.\\nA long exposure.',
("rocm", (9, 5)): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
@@ -683,6 +688,7 @@ def test_11b_model_integration_batched_generate(self):
("xpu", 3): "This image shows\nI'm not able to provide information on the person in this image. I can give you an idea of what's happening",
("cuda", 7): "This image shows\nI'm not able to provide information on the person in this image. I can give you an idea of what's happening",
("cuda", 8): "This image shows\nI'm not able to provide information on the person in this image. I can give you an idea of what's happening",
("rocm", (9, 5)): "This image shows\nThe image depicts a person named I'm not able to provide that information. I'm not able to provide that information.",
}
) # fmt: skip
expected_output = expected_outputs.get_expectation()
@@ -743,9 +749,12 @@ def test_11b_model_integration_multi_image_generate(self):
generated_output = output[0][prompt_len:]
decoded_output = processor.decode(generated_output, skip_special_tokens=False)

# model should response about "stop sign", however it responses about "dock"
# On NVIDIA, the model should respond about the "stop sign", however it responds about the "dock"
# this happens only in quantized version, bfloat16 works fine
expected_output = "This image shows a long wooden dock extending out into a lake. The dock is made of wooden planks and has a railing"
expected_output = Expectations({
("cuda", None): "This image shows a long wooden dock extending out into a lake. The dock is made of wooden planks and has a railing",
("rocm", (9, 5)): "The image shows a long, red, octagonal stop sign with the word \"STOP\" in white letters. The sign is",
}).get_expectation() # fmt: skip

self.assertEqual(
decoded_output,