Skip to content

Commit a254a80

Browse files
johncalesp authored and NVShreyas committed
[Issue 6193] Fix gemma3vl weight loader (NVIDIA#6233)
Signed-off-by: John Calderon <[email protected]>
Signed-off-by: Shreyas Misra <[email protected]>
1 parent 49b3b6c commit a254a80

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

tensorrt_llm/_torch/models/checkpoints/hf/gemma3_weight_mapper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
@register_mapper("HF", "Gemma3ForCausalLM")
9+
@register_mapper("HF", "Gemma3ForConditionalGeneration")
910
class Gemma3HfWeightMapper(HfWeightMapper):
1011

1112
def should_skip_module(self, module_name: str) -> bool:

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import dataclasses
23
import os
34
from typing import List, Optional, Tuple
@@ -7,6 +8,9 @@
78
from transformers.modeling_utils import no_init_weights
89
from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector
910

11+
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
12+
BaseWeightMapper
13+
1014
from ..._utils import nvtx_range
1115
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
1216
register_input_processor)
@@ -98,13 +102,14 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]):
98102
dtype=torch.int32,
99103
device=self._device)
100104

101-
self.model_config = model_config
105+
model_config_cp = copy.deepcopy(model_config)
106+
self.model_config = model_config_cp
102107

103-
llm_model_config = self.get_sub_model_config(model_config,
108+
llm_model_config = self.get_sub_model_config(model_config_cp,
104109
"text_config")
105110
self.llm = Gemma3ForCausalLM(llm_model_config)
106111

107-
vision_model_config = self.get_sub_model_config(model_config,
112+
vision_model_config = self.get_sub_model_config(model_config_cp,
108113
"vision_config")
109114
self.siglip_tower = SiglipVisionModel(vision_model_config,
110115
use_post_layernorm=True)
@@ -141,9 +146,9 @@ def get_sub_model_config(
141146
sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype
142147
return sub_model_config
143148

144-
def load_weights(self, weights):
149+
def load_weights(self, weights, weight_mapper: BaseWeightMapper):
145150
llm_weights = filter_weights("language_model", weights)
146-
self.llm.load_weights(llm_weights)
151+
self.llm.load_weights(llm_weights, weight_mapper)
147152

148153
vit_weights = filter_weights("vision_tower", weights)
149154
self.siglip_tower.load_weights(vit_weights)

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ l0_h100:
7575
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
7676
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
7777
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
78+
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
7879
- condition:
7980
ranges:
8081
system_gpu_count:
@@ -193,7 +194,6 @@ l0_h100:
193194
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
194195
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
195196
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
196-
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
197197
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
198198
- condition:
199199
ranges:

0 commit comments

Comments
 (0)