@@ -21,6 +21,9 @@ def __init__(self, model):
 
     def forward(self, pixel_values):
         vision_embeds = self.model.extract_feature(pixel_values)
+        # Reshape from [num_patches, 256, hidden_dim] -> [1, num_patches*256, hidden_dim]
+        # to enable prefill chunking for num_patches > 1
+        vision_embeds = vision_embeds.reshape(1, -1, vision_embeds.shape[-1])
         return vision_embeds
 
 
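For context, a minimal sketch of what the added reshape does, with illustrative sizes (7 patches, 256 tokens per patch, and a hidden size of 4096 are assumptions, not values from this diff):

    import torch

    # Toy sizes: 7 patches, 256 tokens per patch, hidden size 4096 (all assumed).
    vision_embeds = torch.randn(7, 256, 4096)

    # Fold the patch dimension into the sequence dimension so the language model
    # receives one [1, num_patches*256, hidden_dim] tensor that can be split into
    # prefill chunks along dim 1.
    flat = vision_embeds.reshape(1, -1, vision_embeds.shape[-1])
    assert flat.shape == (1, 7 * 256, 4096)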
@@ -35,14 +38,22 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
+        input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
-        selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
+        # TODO: Find a better way to decide which token value to use
+        image_context_token = (
+            constants.INTERN_3_5_IMG_CONTEXT_TOKEN
+            if "Qwen3" in self.config.architectures[0]
+            else constants.INTERN_IMG_CONTEXT_TOKEN
+        )
+        selected = image_input_ids == image_context_token
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
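The gather/scatter above (unchanged apart from the token-id selection) places the k-th vision embedding at the k-th image-context placeholder, and the new reshape restores the [B, N, C] layout the language model expects. A self-contained sketch with toy sizes and a made-up placeholder id:

    import torch

    IMG_CONTEXT = 99                                  # hypothetical placeholder token id
    C = 8
    input_ids = torch.tensor([[1, 99, 99, 2]])        # B=1, N=4, two image placeholders
    input_embeds = torch.randn(1, 4, C)
    vision_embeds = torch.randn(1, 2, C)              # two vision tokens to scatter in
    image_idx = torch.tensor([[0]])

    B, N, _ = input_embeds.shape
    flat_embeds = input_embeds.reshape(B * N, C)
    flat_ids = input_ids.reshape(B * N)

    selected = flat_ids == IMG_CONTEXT
    indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1     # running index of placeholders
    indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
    indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)

    # Gather one vision embedding per placeholder, then overwrite only those positions.
    image_features = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
    merged = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features, flat_embeds)
    merged = merged.reshape(B, N, C)

    assert torch.equal(merged[0, 1], vision_embeds[0, 0])
    assert torch.equal(merged[0, 2], vision_embeds[0, 1])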
@@ -84,12 +95,13 @@ def get_specializations(
             raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
 
         per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
-        vision_size = int(num_patches * per_patch_embed_size)
+        vision_size = int(batch_size * num_patches * per_patch_embed_size)
         vision = [
             {
                 "batch_size": batch_size,
                 "num_patches": num_patches,
                 "img_size": img_size,
+                "batched_num_patches": batch_size * num_patches,
             }
         ]
         lang = [
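Worked arithmetic for the updated vision_size, assuming the usual InternVL values of patch_size=14 and downsample_ratio=0.5 (not shown in this hunk):

    img_size = 448
    patch_size = 14          # assumed vision_config.patch_size
    downsample_ratio = 0.5   # assumed config.downsample_ratio
    per_patch_embed_size = (img_size // patch_size * downsample_ratio) ** 2   # (32 * 0.5) ** 2 = 256.0
    batch_size, num_patches = 2, 7
    vision_size = int(batch_size * num_patches * per_patch_embed_size)        # 2 * 7 * 256 = 3584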
@@ -126,8 +138,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "vision_size"}
-        vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
+        lang_dynamic_axes["vision_embeds"] = {1: "vision_size"}
+        vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
@@ -182,16 +194,16 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
-            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
-            computed_feature_size,
+            1,
+            computed_feature_size * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             self.language_model.config.hidden_size,
         )
         inputs_shapes["position_ids"] = (
             constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
         inputs_shapes["pixel_values"] = (
-            constants.INTERN_NUM_PATCHES,
+            constants.INTERN_NUM_PATCHES * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.INTERN_NUM_CHANNELS,
             img_size,
             img_size,
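Together with the dynamic axes in the previous hunk, the dummy shapes now follow one pattern: the batch is folded into the patch/sequence axes. A hypothetical helper sketching that contract (the helper name, 256 tokens per patch, and the example values below are illustrative, not part of the diff):

    def dummy_vision_shapes(batch_size: int, num_patches: int, hidden_size: int, img_size: int = 448):
        """Post-change shape contract: vision_embeds keeps a fixed leading dim of 1,
        so only its sequence axis ("vision_size") stays dynamic, while pixel_values
        are batched along the new "batched_num_patches" axis."""
        pixel_values = (batch_size * num_patches, 3, img_size, img_size)
        vision_embeds = (1, batch_size * num_patches * 256, hidden_size)
        return pixel_values, vision_embeds

    # Example: dummy_vision_shapes(1, 13, 4096) -> ((13, 3, 448, 448), (1, 3328, 4096))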
@@ -237,14 +249,22 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val
         vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
+        input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
-        selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
+        # TODO: Find a better way to decide which token value to use
+        image_context_token = (
+            constants.INTERN_3_5_IMG_CONTEXT_TOKEN
+            if "Qwen3" in self.config.architectures[0]
+            else constants.INTERN_IMG_CONTEXT_TOKEN
+        )
+        selected = image_input_ids == image_context_token
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )