from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast
from utils.llm_data import llm_models_root

-from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device

@@ -101,6 +101,42 @@ def _vision_branch(inputs_embeds, pixel_values, input_ids):
        return inputs_embeds.view(original_inputs_embeds_shape)

+    def _vision_branch2(inputs_embeds, pixel_values, input_ids):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=vision_feature_layer,
+            vision_feature_select_strategy=vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat).to(
+            inputs_embeds.device, inputs_embeds.dtype
+        )
+        # NOTE: get_placeholder_mask is not supported by torch.export due to numel check ###########
+        # special_image_mask = self.get_placeholder_mask(
+        #     input_ids, inputs_embeds=inputs_embeds, image_features=projected_vision_flat
+        # )
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(
+                    self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
+                )
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = (
+            special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        )
+        ### END OF get_placeholder_mask ############################################################
+
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)
+
+        return inputs_embeds
+
    def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
        return inputs_embeds
@@ -109,7 +145,7 @@ def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
    inputs_embeds = torch.cond(
        has_image,
-        _vision_branch,
+        _vision_branch2,
        _no_vision_branch,
        (inputs_embeds, pixel_values, input_ids),
    )
@@ -132,7 +168,10 @@ def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
    loss = None
    if labels is not None:
+        # Shift so that tokens < n predict n
        if attention_mask is not None:
+            # we use the input attention mask to shift the logits and labels, because it is 2D.
+            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
            shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
            shift_logits = logits[..., :-1, :][
                shift_attention_mask.to(logits.device) != 0
@@ -141,6 +180,7 @@ def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
        else:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
@@ -210,35 +250,49 @@ def test_build_run_llama4_vlm():
        .to(torch.bfloat16)
    )

-    with torch.inference_mode():
-        # the original model queried with text-only
-        out_text_only = model(inputs["input_ids"], None, inputs["attention_mask"])
-
+    def _run_with_and_without_image(model, use_none=False):
+        with apply_export_patches({"transformers_sdpa_mask": {}, "autocast_noop": {}}):
+            with torch.inference_mode():
+                out_no_images = model(
+                    inputs["input_ids"],
+                    None if use_none else torch.zeros_like(inputs["pixel_values"]),
+                    inputs["attention_mask"],
+                )
+                out_with_images = model(
+                    inputs["input_ids"],
+                    inputs["pixel_values"],
+                    inputs["attention_mask"],
+                )
+        return {"no_images": out_no_images.logits, "with_images": out_with_images.logits}
+
+    # Get output pre-patch
+    out_original = _run_with_and_without_image(model, use_none=True)
+
+    # set patch
    Llama4ForConditionalGeneration.forward = _forward_with_cond

-    with torch.inference_mode():
-        out_real = model(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"])
-        out_dummy = model(
-            inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"]
-        )
-    torch.testing.assert_close(out_dummy.logits, out_text_only.logits, rtol=rtol, atol=atol)
+    # Get output post-patch
+    outputs_for_comparison = {}
+    outputs_for_comparison["model_with_patch"] = _run_with_and_without_image(model)

+    # Export to GM
    gm = torch_export_to_gm(
        model,
        (inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"]),
        kwargs={},
+        patch_list=["transformers_sdpa_mask", "autocast_noop"],
    )
    move_to_device(gm, model.device)

-    with torch.inference_mode():
-        out_real_gm = gm(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"])
-    torch.testing.assert_close(out_real.logits, out_real_gm.logits, rtol=rtol, atol=atol)
-    out_dummy_gm = gm(
-        inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"]
-    )
-    torch.testing.assert_close(out_dummy.logits, out_dummy_gm.logits, rtol=rtol, atol=atol)
-    torch.testing.assert_close(out_dummy_gm.logits, out_text_only.logits, rtol=rtol, atol=atol)
-
-    assert not torch.allclose(out_real.logits, out_dummy.logits, rtol=rtol, atol=atol), (
-        "Expected outputs to differ between text only input and text+image input"
+    # Get the output post export
+    outputs_for_comparison["gm"] = _run_with_and_without_image(gm)
+
+    # Run comparisons to out_original with no patch now...
+    for comp, outs in outputs_for_comparison.items():
+        torch.testing.assert_close(
+            outs,
+            out_original,
+            rtol=rtol,
+            atol=atol,
+            msg=lambda m: f"Comparison for {comp} failed:\n{m}",
    )