@@ -5,22 +5,23 @@
 import traceback

 import numpy as np

-import requests
 from absl import app
 from absl import flags
 from keras import ops
 from transformers import AutoTokenizer
 from transformers import MistralForCausalLM

 from keras_hub.models import MistralBackbone
+from keras_hub.models import MistralCausalLM
 from keras_hub.models import MistralCausalLMPreprocessor
 from keras_hub.models import MistralTokenizer
-from keras_hub.utils.preset_utils import save_to_preset

 PRESET_MAP = {
     "mistral_7b_en": "mistralai/Mistral-7B-v0.1",
+    "mistral_0.3_7b_en": "mistralai/Mistral-7B-v0.3",
     "mistral_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.1",
     "mistral_0.2_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.2",
+    "mistral_0.3_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.3",
 }

 FLAGS = flags.FLAGS
@@ -236,49 +237,43 @@ def main(_):
             rope_max_wavelength=hf_model.config.rope_theta,
             dtype="float32",
         )
-        keras_hub_model = MistralBackbone(**backbone_kwargs)
+        keras_hub_backbone = MistralBackbone(**backbone_kwargs)

-        # === Download the tokenizer from Huggingface model card ===
-        spm_path = (
-            f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
-        )
-        response = requests.get(spm_path)
-        if not response.ok:
-            raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
-        tokenizer_path = os.path.join(temp_dir, "vocabulary.spm")
-        with open(tokenizer_path, "wb") as tokenizer_file:
-            tokenizer_file.write(response.content)
-        keras_hub_tokenizer = MistralTokenizer(tokenizer_path)
+        keras_hub_tokenizer = MistralTokenizer.from_preset(f"hf://{hf_preset}")
         print("\n-> Keras 3 model and tokenizer loaded.")

         # === Port the weights ===
-        convert_checkpoints(keras_hub_model, hf_model)
+        convert_checkpoints(keras_hub_backbone, hf_model)
         print("\n-> Weight transfer done.")

         # === Check that the models and tokenizers outputs match ===
         test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
-        test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer)
+        test_model(
+            keras_hub_backbone, keras_hub_tokenizer, hf_model, hf_tokenizer
+        )
         print("\n-> Tests passed!")

         # === Save the model weights in float32 format ===
-        keras_hub_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
+        keras_hub_backbone.save_weights(
+            os.path.join(temp_dir, "model.weights.h5")
+        )
         print("\n-> Saved the model weights in float32")

-        del keras_hub_model, hf_model
+        del keras_hub_backbone, hf_model
         gc.collect()

         # === Save the weights again in float16 ===
         backbone_kwargs["dtype"] = "float16"
-        keras_hub_model = MistralBackbone(**backbone_kwargs)
-        keras_hub_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
-        save_to_preset(keras_hub_model, preset)
+        keras_hub_backbone = MistralBackbone(**backbone_kwargs)
+        keras_hub_backbone.load_weights(
+            os.path.join(temp_dir, "model.weights.h5")
+        )
+
+        preprocessor = MistralCausalLMPreprocessor(keras_hub_tokenizer)
+        keras_hub_model = MistralCausalLM(keras_hub_backbone, preprocessor)
+        keras_hub_model.save_to_preset(f"./{preset}")
         print("\n-> Saved the model preset in float16")

-        # === Save the tokenizer ===
-        save_to_preset(
-            keras_hub_tokenizer, preset, config_filename="tokenizer.json"
-        )
-        print("\n-> Saved the tokenizer")
     finally:
         shutil.rmtree(temp_dir)
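With this change the tokenizer is fetched through keras_hub's preset loader using the hf:// scheme rather than a hand-rolled requests download, which is why the requests import and the manual tokenizer.model handling are gone. A minimal sketch of that loader in isolation, assuming network access to the Hugging Face Hub (the repo name below is just one of the PRESET_MAP entries, used for illustration):

from keras_hub.models import MistralTokenizer

# "hf://<repo>" tells keras_hub to resolve the preset directly on the
# Hugging Face Hub, so no manual tokenizer.model download is needed.
tokenizer = MistralTokenizer.from_preset("hf://mistralai/Mistral-7B-v0.1")
print(tokenizer("What is Keras?"))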
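And because the float16 weights are now saved through a full MistralCausalLM (backbone plus preprocessor) via save_to_preset, the resulting directory reloads as a generative model in one call. A quick round-trip sketch, assuming the script was run with the mistral_7b_en preset so the output landed in ./mistral_7b_en:

from keras_hub.models import MistralCausalLM

# Load the float16 preset written by the conversion script and run a
# short generation as a smoke test.
causal_lm = MistralCausalLM.from_preset("./mistral_7b_en")
print(causal_lm.generate("What is Keras?", max_length=64))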