
Commit 244232f

wip for llama4 patch debugging
1 parent: 28c666e

File tree: 4 files changed (+92, -33 lines changed)


tensorrt_llm/_torch/auto_deploy/llm.py

Lines changed: 5 additions & 1 deletion
@@ -66,7 +66,11 @@ def __call__(
             # TODO: can we avoid the extra tolist() here eventually?
             token_ids = all_args.pop("input_ids")
             assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
-            return token_ids[0].tolist(), {"multimodal_data": all_args} if all_args else None
+            if all_args:
+                extra_processed_inputs = {"multimodal_data": all_args}
+            else:
+                extra_processed_inputs = None
+            return token_ids[0].tolist(), extra_processed_inputs
         else:
             token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
             return token_ids, None
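Note on the llm.py change above: the old one-liner and the new explicit branch are behaviorally equivalent, since the conditional expression only governs the second element of the returned tuple; the rewrite just makes that binding explicit while debugging. A minimal, self-contained sketch (with hypothetical token_ids/all_args values) illustrating the equivalence:

# Sketch only: hypothetical values standing in for the tensors in llm.py.
token_ids = [[101, 102, 103]]  # pretend this is input_ids with batch size 1
all_args = {}                  # pretend no extra multimodal inputs were produced

# Old form: the conditional expression applies only to the second tuple element.
old_result = token_ids[0], {"multimodal_data": all_args} if all_args else None

# New form from this commit: same behavior, spelled out explicitly.
if all_args:
    extra_processed_inputs = {"multimodal_data": all_args}
else:
    extra_processed_inputs = None
new_result = token_ids[0], extra_processed_inputs

assert old_result == new_result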

tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
     )


-@ExportPatchRegistry.register("hf_llama4_vision")
+# @ExportPatchRegistry.register("hf_llama4_vision")
 class Llama4VisionPatch(BaseExportPatch):
     """Patch for Llama4ForConditionalGeneration to make it compatible with torch.export.
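Note on the llama4.py change above: commenting out the decorator removes "hf_llama4_vision" from the export patch registry, so export runs that request the patch by name will no longer pick it up. A rough sketch of how a registered patch is normally selected by name (based on the apply_export_patches / patch_list usage visible in the test diff below; treat the exact call shapes as assumptions):

# Sketch only: how a registered patch is typically requested by name.
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm

# As a context manager with a per-patch config dict (empty config here):
with apply_export_patches({"hf_llama4_vision": {}}):
    pass  # run or export the model with the Llama4 vision patch active

# Or by name during export:
# gm = torch_export_to_gm(model, args=example_args, patch_list=["hf_llama4_vision"])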

tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py

Lines changed: 9 additions & 8 deletions
@@ -66,14 +66,15 @@ def _apply(
         cm.info.set_example_sequence(**factory.get_example_inputs())

         # export the model to a graph module
-        gm = torch_export_to_gm(
-            model,
-            args=cm.args,
-            dynamic_shapes=cm.dynamic_shapes,
-            clone=self.config.clone_state_dict,
-            strict=self.config.strict,
-            patch_list=self.config.patch_list,
-        )
+        if False:
+            gm = torch_export_to_gm(
+                model,
+                args=cm.args,
+                dynamic_shapes=cm.dynamic_shapes,
+                clone=self.config.clone_state_dict,
+                strict=self.config.strict,
+                patch_list=self.config.patch_list,
+            )

         # this is a clean graph by definition since it was just exported
         info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
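Note on the export_to_gm.py change above: guarding the export behind `if False:` skips it entirely for this debugging pass and, if the rest of `_apply` still references `gm`, would leave it unbound. A hypothetical, less destructive toggle (the `_DEBUG_SKIP_EXPORT` flag and the eager-model fallback are assumptions for illustration, not part of this commit):

# Hypothetical sketch inside _apply, reusing names already in scope there.
_DEBUG_SKIP_EXPORT = True  # assumed debug switch

if _DEBUG_SKIP_EXPORT:
    gm = model  # keep `gm` bound by falling back to the unexported model
else:
    gm = torch_export_to_gm(
        model,
        args=cm.args,
        dynamic_shapes=cm.dynamic_shapes,
        clone=self.config.clone_state_dict,
        strict=self.config.strict,
        patch_list=self.config.patch_list,
    )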

tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py

Lines changed: 77 additions & 23 deletions
@@ -8,7 +8,7 @@
 from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast
 from utils.llm_data import llm_models_root

-from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device

@@ -101,6 +101,42 @@ def _vision_branch(inputs_embeds, pixel_values, input_ids):

         return inputs_embeds.view(original_inputs_embeds_shape)

+    def _vision_branch2(inputs_embeds, pixel_values, input_ids):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=vision_feature_layer,
+            vision_feature_select_strategy=vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat).to(
+            inputs_embeds.device, inputs_embeds.dtype
+        )
+        # NOTE: get_placeholder_mask is not supported by torch.export due to numel check ###########
+        # special_image_mask = self.get_placeholder_mask(
+        #     input_ids, inputs_embeds=inputs_embeds, image_features=projected_vision_flat
+        # )
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(
+                    self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
+                )
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = (
+            special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        )
+        ### END OF get_placeholder_mask ############################################################
+
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)
+
+        return inputs_embeds
+
     def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
         return inputs_embeds

@@ -109,7 +145,7 @@ def _no_vision_branch(inputs_embeds, pixel_values, input_ids):

     inputs_embeds = torch.cond(
         has_image,
-        _vision_branch,
+        _vision_branch2,
         _no_vision_branch,
         (inputs_embeds, pixel_values, input_ids),
     )
@@ -132,7 +168,10 @@ def _no_vision_branch(inputs_embeds, pixel_values, input_ids):

     loss = None
     if labels is not None:
+        # Shift so that tokens < n predict n
         if attention_mask is not None:
+            # we use the input attention mask to shift the logits and labels, because it is 2D.
+            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
             shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
             shift_logits = logits[..., :-1, :][
                 shift_attention_mask.to(logits.device) != 0
@@ -141,6 +180,7 @@ def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
         else:
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
         loss_fct = nn.CrossEntropyLoss()
         loss = loss_fct(
             shift_logits.view(-1, shift_logits.size(-1)),
@@ -210,35 +250,49 @@ def test_build_run_llama4_vlm():
         .to(torch.bfloat16)
     )

-    with torch.inference_mode():
-        # the original model queried with text-only
-        out_text_only = model(inputs["input_ids"], None, inputs["attention_mask"])
-
+    def _run_with_and_without_image(model, use_none=False):
+        with apply_export_patches({"transformers_sdpa_mask": {}, "autocast_noop": {}}):
+            with torch.inference_mode():
+                out_no_images = model(
+                    inputs["input_ids"],
+                    None if use_none else torch.zeros_like(inputs["pixel_values"]),
+                    inputs["attention_mask"],
+                )
+                out_with_images = model(
+                    inputs["input_ids"],
+                    inputs["pixel_values"],
+                    inputs["attention_mask"],
+                )
+        return {"no_images": out_no_images.logits, "with_images": out_with_images.logits}
+
+    # Get output pre-patch
+    out_original = _run_with_and_without_image(model, use_none=True)
+
+    # set patch
     Llama4ForConditionalGeneration.forward = _forward_with_cond

-    with torch.inference_mode():
-        out_real = model(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"])
-        out_dummy = model(
-            inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"]
-        )
-        torch.testing.assert_close(out_dummy.logits, out_text_only.logits, rtol=rtol, atol=atol)
+    # Get output post-patch
+    outputs_for_comparison = {}
+    outputs_for_comparison["model_with_patch"] = _run_with_and_without_image(model)

+    # Export to GM
     gm = torch_export_to_gm(
         model,
         (inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"]),
         kwargs={},
+        patch_list=["transformers_sdpa_mask", "autocast_noop"],
     )
     move_to_device(gm, model.device)

-    with torch.inference_mode():
-        out_real_gm = gm(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"])
-        torch.testing.assert_close(out_real.logits, out_real_gm.logits, rtol=rtol, atol=atol)
-        out_dummy_gm = gm(
-            inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"]
-        )
-        torch.testing.assert_close(out_dummy.logits, out_dummy_gm.logits, rtol=rtol, atol=atol)
-        torch.testing.assert_close(out_dummy_gm.logits, out_text_only.logits, rtol=rtol, atol=atol)
-
-    assert not torch.allclose(out_real.logits, out_dummy.logits, rtol=rtol, atol=atol), (
-        "Expected outputs to differ between text only input and text+image input"
+    # Get the output post export
+    outputs_for_comparison["gm"] = _run_with_and_without_image(gm)
+
+    # Run comparisons to out_original with no patch now...
+    for comp, outs in outputs_for_comparison.items():
+        torch.testing.assert_close(
+            outs,
+            out_original,
+            rtol=rtol,
+            atol=atol,
+            msg=lambda m: f"Comparison for {comp} failed:\n{m}",
     )
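Note on the reworked test above: it now takes one pre-patch baseline (out_original), then runs the same with/without-image helper on the patched model and on the exported GraphModule, and checks every variant against that baseline with torch.testing.assert_close. A minimal, self-contained sketch of that comparison pattern, using dummy tensors instead of real model outputs:

import torch

# Sketch only: dummy tensors standing in for the {"no_images", "with_images"} logits dicts.
rtol, atol = 1e-3, 1e-3  # assumed tolerances; the real test defines its own

baseline = {"no_images": torch.zeros(2, 4), "with_images": torch.ones(2, 4)}
candidates = {
    "model_with_patch": {k: v.clone() for k, v in baseline.items()},
    "gm": {k: v.clone() for k, v in baseline.items()},
}

# assert_close accepts nested containers of tensors and a callable msg.
for name, outs in candidates.items():
    torch.testing.assert_close(
        outs, baseline, rtol=rtol, atol=atol,
        msg=lambda m: f"Comparison for {name} failed:\n{m}",
    )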
