|
| 1 | +import copy |
1 | 2 | import dataclasses
|
2 | 3 | import os
|
3 | 4 | from typing import List, Optional, Tuple
|
|
7 | 8 | from transformers.modeling_utils import no_init_weights
|
8 | 9 | from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector
|
9 | 10 |
|
| 11 | +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ |
| 12 | + BaseWeightMapper |
| 13 | + |
10 | 14 | from ..._utils import nvtx_range
|
11 | 15 | from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
|
12 | 16 | register_input_processor)
|
@@ -98,13 +102,14 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]):
|
98 | 102 | dtype=torch.int32,
|
99 | 103 | device=self._device)
|
100 | 104 |
|
101 |
| - self.model_config = model_config |
| 105 | + model_config_cp = copy.deepcopy(model_config) |
| 106 | + self.model_config = model_config_cp |
102 | 107 |
|
103 |
| - llm_model_config = self.get_sub_model_config(model_config, |
| 108 | + llm_model_config = self.get_sub_model_config(model_config_cp, |
104 | 109 | "text_config")
|
105 | 110 | self.llm = Gemma3ForCausalLM(llm_model_config)
|
106 | 111 |
|
107 |
| - vision_model_config = self.get_sub_model_config(model_config, |
| 112 | + vision_model_config = self.get_sub_model_config(model_config_cp, |
108 | 113 | "vision_config")
|
109 | 114 | self.siglip_tower = SiglipVisionModel(vision_model_config,
|
110 | 115 | use_post_layernorm=True)
|
@@ -141,9 +146,9 @@ def get_sub_model_config(
|
141 | 146 | sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype
|
142 | 147 | return sub_model_config
|
143 | 148 |
|
144 |
| - def load_weights(self, weights): |
| 149 | + def load_weights(self, weights, weight_mapper: BaseWeightMapper): |
145 | 150 | llm_weights = filter_weights("language_model", weights)
|
146 |
| - self.llm.load_weights(llm_weights) |
| 151 | + self.llm.load_weights(llm_weights, weight_mapper) |
147 | 152 |
|
148 | 153 | vit_weights = filter_weights("vision_tower", weights)
|
149 | 154 | self.siglip_tower.load_weights(vit_weights)
|
|
0 commit comments