Commit 504afd2

wip for llama4 patch fix
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent: 28c666e

File tree: 3 files changed (+46, -9 lines)

tensorrt_llm/_torch/auto_deploy/llm.py

Lines changed: 5 additions & 1 deletion
@@ -66,7 +66,11 @@ def __call__(
             # TODO: can we avoid the extra tolist() here eventually?
             token_ids = all_args.pop("input_ids")
             assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
-            return token_ids[0].tolist(), {"multimodal_data": all_args} if all_args else None
+            if all_args:
+                extra_processed_inputs = {"multimodal_data": all_args}
+            else:
+                extra_processed_inputs = None
+            return token_ids[0].tolist(), extra_processed_inputs
         else:
             token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
             return token_ids, None
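
Note: the deleted one-liner and the new explicit if/else are behaviorally identical. Python's conditional expression binds tighter than the tuple comma, so the old return already produced a 2-tuple whose second element was conditional; the rewrite only makes that precedence explicit. A minimal, self-contained sketch of the pattern, with a hypothetical stand-in function name (not part of the codebase):

import torch

def split_tokenizer_outputs(all_args: dict):
    """Pop input_ids and bundle any leftover processor outputs as multimodal data."""
    token_ids = all_args.pop("input_ids")
    assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
    if all_args:
        extra_processed_inputs = {"multimodal_data": all_args}
    else:
        extra_processed_inputs = None
    return token_ids[0].tolist(), extra_processed_inputs

# Text-only input: nothing is left after popping input_ids, so extras is None.
ids, extras = split_tokenizer_outputs({"input_ids": torch.tensor([[1, 2, 3]])})
assert extras is None

# Multimodal input: leftover keys (e.g. pixel_values) travel as multimodal_data.
ids, extras = split_tokenizer_outputs(
    {"input_ids": torch.tensor([[1, 2, 3]]), "pixel_values": torch.zeros(1, 3, 4, 4)}
)
assert "pixel_values" in extras["multimodal_data"]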

tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py

Lines changed: 9 additions & 8 deletions
@@ -66,14 +66,15 @@ def _apply(
         cm.info.set_example_sequence(**factory.get_example_inputs())

         # export the model to a graph module
-        gm = torch_export_to_gm(
-            model,
-            args=cm.args,
-            dynamic_shapes=cm.dynamic_shapes,
-            clone=self.config.clone_state_dict,
-            strict=self.config.strict,
-            patch_list=self.config.patch_list,
-        )
+        if False:
+            gm = torch_export_to_gm(
+                model,
+                args=cm.args,
+                dynamic_shapes=cm.dynamic_shapes,
+                clone=self.config.clone_state_dict,
+                strict=self.config.strict,
+                patch_list=self.config.patch_list,
+            )

         # this is a clean graph by definition since it was just exported
         info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
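
Note: the new `if False:` guard temporarily bypasses the export call while the llama4 patch is reworked (this commit is explicitly WIP). For reference, torch_export_to_gm presumably wraps torch.export under the hood; a rough sketch of what the guarded call does, using vanilla torch.export on a toy module (the real wrapper additionally applies patch_list and can clone the state dict, which are TRT-LLM-specific knobs):

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

model = ToyModel()
example_args = (torch.randn(2, 8, 16),)

# Let the batch dimension vary, loosely mirroring dynamic_shapes=cm.dynamic_shapes.
batch = torch.export.Dim("batch", min=1, max=8)
exported = torch.export.export(
    model,
    example_args,
    dynamic_shapes=({0: batch},),
    strict=False,  # non-strict export, analogous to strict=self.config.strict
)
gm = exported.module()  # a torch.fx.GraphModule, i.e. the "gm" the transform produces
print(gm.graph)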

tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py

Lines changed: 32 additions & 0 deletions
@@ -101,6 +101,38 @@ def _vision_branch(inputs_embeds, pixel_values, input_ids):

         return inputs_embeds.view(original_inputs_embeds_shape)

+    def _vision_branch2(inputs_embeds, pixel_values, input_ids):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=vision_feature_layer,
+            vision_feature_select_strategy=vision_feature_select_strategy,
+            image_sizes=None,
+        )
+        original_inputs_embeds_shape = inputs_embeds.shape
+
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+
+        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+        final_mask = special_image_mask.to(inputs_embeds.device)
+        inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
+
+        final_mask_1d = final_mask[..., 0].reshape(-1)
+        # num_tokens_to_fill = final_mask_1d.sum()
+
+        # This condition statement breaks torch.export:
+        # TODO: sanity check on the inputs for this
+        # if num_tokens_to_fill != projected_vision_flat.size(0):
+        #     raise ValueError(
+        #         f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
+        #         f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
+        #     )
+
+        expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
+        inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat)
+
+        return inputs_embeds.view(original_inputs_embeds_shape)
+
     def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
         return inputs_embeds

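
Note: the commented-out size check in _vision_branch2 is data-dependent control flow (its outcome depends on tensor values, not shapes), which torch.export cannot trace; the branch instead relies on masked_scatter_, which implicitly requires the source to supply at least as many elements as the mask selects. An isolated, runnable sketch of that scatter pattern with toy sizes and hypothetical values:

import torch

B, S, H = 1, 6, 4
image_token_index = 99  # stand-in for self.config.image_token_index

inputs_embeds = torch.zeros(B, S, H)
input_ids = torch.tensor([[5, 99, 99, 7, 99, 8]])  # three image-token slots
projected_vision_flat = torch.ones(3, H)  # one projected feature row per slot

original_inputs_embeds_shape = inputs_embeds.shape
inputs_embeds = inputs_embeds.view(-1, H)  # flatten (B, S, H) -> (B*S, H)

# Mark the rows holding image placeholder tokens, then broadcast to full width.
final_mask_1d = (input_ids == image_token_index).reshape(-1)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, H)

# In-place scatter: source elements fill the masked positions in row-major order.
inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat)

out = inputs_embeds.view(original_inputs_embeds_shape)
assert out[0, 1].eq(1).all() and out[0, 0].eq(0).all()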
