
Commit 1272415

Added support for the InternVL_3_5 series of VLMs. (#566)
Added support for Qwen3ForCausalLM models, tested on the Qwen3-0.6B model for CI runs. Updated the InternVL modeling script to allow proper prefill chunking of vision embeds when more than one patch is needed. Tested the InternVL_3_5_1B model with 1 layer and with full layers via CI. --------- Signed-off-by: quic-dhirajku <[email protected]>
1 parent a9e404a commit 1272415

File tree

9 files changed: +551, -199 lines changed


QEfficient/transformers/models/internvl/modeling_internvl.py

Lines changed: 28 additions & 8 deletions
@@ -21,6 +21,9 @@ def __init__(self, model):
 
     def forward(self, pixel_values):
         vision_embeds = self.model.extract_feature(pixel_values)
+        # Reshape from [num_patches, 256, hidden_dim] -> [1, num_patches*256, hidden_dim]
+        # to enable prefill chunking for num_patches > 1
+        vision_embeds = vision_embeds.reshape(1, -1, vision_embeds.shape[-1])
         return vision_embeds
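For context, a minimal sketch of what the new reshape does (shapes are illustrative: 7 patches, 256 tokens per patch; the hidden size of 1024 is an assumption, not taken from the config):

    import torch

    # Hypothetical vision-encoder output: [num_patches, 256, hidden_dim]
    vision_embeds = torch.randn(7, 256, 1024)

    # Fold all patch tokens into one sequence axis so prefill can be chunked
    # along a single dimension, exactly as the new forward() does.
    vision_embeds = vision_embeds.reshape(1, -1, vision_embeds.shape[-1])
    print(vision_embeds.shape)  # torch.Size([1, 1792, 1024])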

@@ -35,14 +38,22 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
+        input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
-        selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
+        # TODO: Find a better way to decide which token value to use
+        image_context_token = (
+            constants.INTERN_3_5_IMG_CONTEXT_TOKEN
+            if "Qwen3" in self.config.architectures[0]
+            else constants.INTERN_IMG_CONTEXT_TOKEN
+        )
+        selected = image_input_ids == image_context_token
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
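The cumsum-based indexing above gives the k-th image-context token in the flattened sequence the k-th row of the vision features, with no data-dependent control flow that would break ONNX export. A self-contained toy run of the same pattern (the token id and shapes are placeholders, not the real constants):

    import torch

    IMG_CONTEXT = 999  # placeholder; the real id comes from constants

    # Batch 1, six tokens, positions 1 and 2 are image-context tokens.
    input_ids = torch.tensor([[101, IMG_CONTEXT, IMG_CONTEXT, 102, 103, 104]])
    C = 8
    input_embeds = torch.zeros(6, C)          # flattened text embeddings (B*N, C)
    vision_embeds = torch.ones(1, 2, C)       # two vision feature rows
    image_idx = torch.tensor([[0]])

    selected = input_ids.reshape(-1) == IMG_CONTEXT
    indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1   # running count of image tokens
    indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
    indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
    image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
    out = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
    # out[0, 1] and out[0, 2] are vision rows; every other position keeps its text embedding.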
@@ -84,12 +95,13 @@ def get_specializations(
             raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
 
         per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
-        vision_size = int(num_patches * per_patch_embed_size)
+        vision_size = int(batch_size * num_patches * per_patch_embed_size)
         vision = [
             {
                 "batch_size": batch_size,
                 "num_patches": num_patches,
                 "img_size": img_size,
+                "batched_num_patches": batch_size * num_patches,
             }
         ]
         lang = [
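Plugging in typical InternVL config values (patch_size 14 and downsample_ratio 0.5 are assumptions from the model family, not this diff) shows why vision_size must now scale with batch_size:

    img_size, patch_size, downsample_ratio = 448, 14, 0.5
    per_patch_embed_size = (img_size // patch_size * downsample_ratio) ** 2   # (32 * 0.5)**2 = 256.0
    batch_size, num_patches = 2, 13
    vision_size = int(batch_size * num_patches * per_patch_embed_size)        # 6656, not 3328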
@@ -126,8 +138,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "vision_size"}
-        vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
+        lang_dynamic_axes["vision_embeds"] = {1: "vision_size"}
+        vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
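These dicts become the dynamic_axes mapping ultimately handed to torch.onnx.export: each input name maps {axis index: symbolic name}, and any axis not listed is frozen at its example value, which is how vision_embeds gets its batch dimension pinned to 1 here. A sketch of the call shape (model and example_inputs are placeholders, not this repo's actual export path):

    import torch

    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "position_ids": {0: "batch_size", 1: "seq_len"},
        "vision_embeds": {1: "vision_size"},          # axis 0 omitted -> static 1
        "pixel_values": {0: "batched_num_patches", 2: "img_size", 3: "img_size"},
    }
    # torch.onnx.export(model, example_inputs, "model.onnx",
    #                   input_names=list(dynamic_axes), dynamic_axes=dynamic_axes)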
@@ -182,16 +194,16 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
-            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
-            computed_feature_size,
+            1,
+            computed_feature_size * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             self.language_model.config.hidden_size,
         )
         inputs_shapes["position_ids"] = (
             constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
         inputs_shapes["pixel_values"] = (
-            constants.INTERN_NUM_PATCHES,
+            constants.INTERN_NUM_PATCHES * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.INTERN_NUM_CHANNELS,
             img_size,
             img_size,
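With the export constants filled in (the values below are illustrative assumptions: batch 2, seq len 32, 13 patches, 3 channels, hidden size 2048, 256 features per patch), the new shapes fold the batch into axis 0 of pixel_values and axis 1 of vision_embeds:

    import torch

    BATCH, SEQ, PATCHES, CH, IMG, HIDDEN = 2, 32, 13, 3, 448, 2048
    FEAT = PATCHES * 256   # stand-in for computed_feature_size

    dummy = {
        "input_ids": torch.zeros((BATCH, SEQ), dtype=torch.int64),
        "position_ids": torch.zeros((BATCH, SEQ), dtype=torch.int64),
        "vision_embeds": torch.zeros((1, FEAT * BATCH, HIDDEN)),      # batch folded into the sequence axis
        "pixel_values": torch.zeros((PATCHES * BATCH, CH, IMG, IMG)), # batch folded into the patch axis
    }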
@@ -237,14 +249,22 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values):
         vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
+        input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
-        selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
+        # TODO: Find a better way to decide which token value to use
+        image_context_token = (
+            constants.INTERN_3_5_IMG_CONTEXT_TOKEN
+            if "Qwen3" in self.config.architectures[0]
+            else constants.INTERN_IMG_CONTEXT_TOKEN
+        )
+        selected = image_input_ids == image_context_token
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 20 additions & 0 deletions
@@ -139,6 +139,13 @@
     Qwen2Model,
     Qwen2RMSNorm,
 )
+from transformers.models.qwen3.modeling_qwen3 import (
+    Qwen3Attention,
+    Qwen3DecoderLayer,
+    Qwen3ForCausalLM,
+    Qwen3Model,
+    Qwen3RMSNorm,
+)
 from transformers.models.qwen3_moe.modeling_qwen3_moe import (
     Qwen3MoeAttention,
     Qwen3MoeDecoderLayer,
@@ -318,6 +325,12 @@
     QEffQwen2ForCausalLM,
     QEffQwen2Model,
 )
+from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
+    QEffQwen3Attention,
+    QEffQwen3DecoderLayer,
+    QEffQwen3ForCausalLM,
+    QEffQwen3Model,
+)
 from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
     QEffQwen3MoeAttention,
     QEffQwen3MoeDecoderLayer,
@@ -358,6 +371,7 @@ class CustomOpsTransform(ModuleMappingTransform):
         MixtralRMSNorm: CustomRMSNormAIC,
         Phi3RMSNorm: CustomRMSNormAIC,
         Qwen2RMSNorm: CustomRMSNormAIC,
+        Qwen3RMSNorm: CustomRMSNormAIC,
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
         GraniteMoeRMSNorm: CustomRMSNormAIC,
@@ -486,6 +500,11 @@ class KVCacheTransform(ModuleMappingTransform):
         Qwen2DecoderLayer: QEffQwen2DecoderLayer,
         Qwen2Model: QEffQwen2Model,
         Qwen2ForCausalLM: QEffQwen2ForCausalLM,
+        # Qwen3
+        Qwen3Attention: QEffQwen3Attention,
+        Qwen3DecoderLayer: QEffQwen3DecoderLayer,
+        Qwen3Model: QEffQwen3Model,
+        Qwen3ForCausalLM: QEffQwen3ForCausalLM,
         # Starcoder2
         Starcoder2Attention: QEffStarcoder2Attention,
         Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
@@ -532,6 +551,7 @@ class SpDTransform:
         # Llama
         QEffLlamaForCausalLM,
         QEffQwen2ForCausalLM,
+        QEffQwen3ForCausalLM,
     }
 
     @classmethod
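For readers new to these transforms: adding the Qwen3 entries is sufficient because a ModuleMappingTransform swaps each matched module's class in place, keeping the loaded weights. A minimal sketch of that general pattern (not the repo's exact implementation):

    import torch.nn as nn

    def apply_mapping(model: nn.Module, mapping: dict) -> bool:
        """Swap each matched module's class for its QEff counterpart, keeping weights."""
        transformed = False
        for module in model.modules():
            if replacement := mapping.get(type(module)):
                module.__class__ = replacement   # same parameters, new forward()
                transformed = True
        return transformed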
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
