
Commit abb6285

messing with attention mask
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 244232f commit abb6285

4 files changed, +26 -8 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

Lines changed: 5 additions & 1 deletion

@@ -71,6 +71,10 @@ def scaled_dot_product_attention(
     of the vanilla sdpa in a graph.
     """

+    if attn_mask is not None:
+        is_causal = True
+        attn_mask = None
+
     return F.scaled_dot_product_attention(
         query.contiguous(),
         key.contiguous(),
@@ -79,7 +83,7 @@ def scaled_dot_product_attention(
         dropout_p=dropout_p,
         is_causal=is_causal,
         scale=scale,
-        enable_gqa=False,
+        enable_gqa=enable_gqa,
     )
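A minimal sketch of what the patched op now does, not the repo's exact code and assuming PyTorch >= 2.5 (where F.scaled_dot_product_attention gained the enable_gqa flag): an explicit attention mask is dropped in favor of the built-in causal mask, and grouped-query attention is delegated to SDPA instead of being disabled.

import torch
import torch.nn.functional as F

def sdpa_like_patched(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False):
    # Sketch of the patched behavior: ignore an explicit mask and fall back
    # to the built-in causal mask; pass enable_gqa through instead of False.
    if attn_mask is not None:
        is_causal = True
        attn_mask = None
    return F.scaled_dot_product_attention(
        query.contiguous(), key.contiguous(), value.contiguous(),
        attn_mask=attn_mask, dropout_p=dropout_p,
        is_causal=is_causal, scale=scale, enable_gqa=enable_gqa,
    )

# 8 query heads attend over 2 KV heads (GQA) without a manual repeat_kv.
q = torch.randn(1, 8, 4, 16)
k = torch.randn(1, 2, 4, 16)
v = torch.randn(1, 2, 4, 16)
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
out = sdpa_like_patched(q, k, v, attn_mask=mask, enable_gqa=True)
print(out.shape)  # torch.Size([1, 8, 4, 16])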

tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
 from transformers import Llama4ForConditionalGeneration
 from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast

-from ...export.interface import BaseExportPatch, ExportPatchRegistry
+from ...export.interface import BaseExportPatch


 # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1651

tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py

Lines changed: 20 additions & 3 deletions

@@ -211,7 +211,10 @@ def test_build_run_llama4_vlm():
     )
     processor = AutoProcessor.from_pretrained(model_id)

-    config = AutoConfig.from_pretrained(model_id)
+    config = AutoConfig.from_pretrained(
+        model_id,
+        # attn_implementation="eager",
+    )
     config.text_config.num_hidden_layers = 2
     config.text_config.intermediate_size = 64
     config.text_config.intermediate_size_mlp = 128
@@ -251,7 +254,14 @@ def test_build_run_llama4_vlm():
     )

     def _run_with_and_without_image(model, use_none=False):
-        with apply_export_patches({"transformers_sdpa_mask": {}, "autocast_noop": {}}):
+        with apply_export_patches(
+            {
+                "transformers_sdpa_mask": {},
+                "autocast_noop": {},
+                # "sdpa": {},
+                "sdpa_kernel_noop": {},
+            }
+        ):
             with torch.inference_mode():
                 out_no_images = model(
                     inputs["input_ids"],
@@ -280,7 +290,14 @@ def _run_with_and_without_image(model, use_none=False):
         model,
         (inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"]),
         kwargs={},
-        patch_list=["transformers_sdpa_mask", "autocast_noop"],
+        patch_list=[
+            "transformers_sdpa_mask",
+            "autocast_noop",
+            "torch_where",
+            "tensor_meta_device",
+            "sdpa_kernel_noop",
+            "sdpa",
+        ],
     )
     move_to_device(gm, model.device)
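The test now runs under additional export patches ("sdpa_kernel_noop", plus "sdpa", "torch_where", and "tensor_meta_device" for the export itself). As a purely hypothetical illustration of the general "noop" patch idea, and not the repo's implementation, such a patch temporarily swaps a context manager for one that does nothing so the export tracer never has to handle the original:

import contextlib
import torch

@contextlib.contextmanager
def _noop(*args, **kwargs):
    # Accepts any arguments and does nothing.
    yield

@contextlib.contextmanager
def autocast_noop_patch():
    # Hypothetical sketch: replace torch.autocast with a do-nothing context
    # manager for the duration of the patch, then always restore it.
    original = torch.autocast
    torch.autocast = _noop
    try:
        yield
    finally:
        torch.autocast = original

with autocast_noop_patch():
    with torch.autocast("cuda"):  # now a no-op; nothing is autocast
        x = torch.ones(2, 2) + 1
print(x)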

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py

Lines changed: 0 additions & 3 deletions

@@ -78,9 +78,6 @@ def _joint_transform(gm: GraphModule) -> None:
     ["eager", "sdpa"],
 )
 def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str):
-    if attn_implementation == "sdpa":
-        pytest.skip("https://nvbugspro.nvidia.com/bug/5170222")
-
     def verify_matcher(gm: GraphModule):
         """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
         call in the graph. Also check that there is no repeat_kv pattern left.
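For reference, the "repeat_kv pattern" the matcher is expected to eliminate is the grouped-query-attention head expansion used in Hugging Face Llama-style models; a paraphrased, self-contained version looks roughly like this:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand each KV head n_rep times so K/V match the number of query heads.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

k = torch.randn(1, 2, 4, 16)      # 2 KV heads
print(repeat_kv(k, 4).shape)      # torch.Size([1, 8, 4, 16]) -> 8 heads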
