
Commit 25c9062

laxmareddyp authored

Fix Mistral conversion script (#2306)
* Fix Mistral conversion script

  This commit addresses several issues in the Mistral checkpoint conversion script:

  - Adds `dropout` to the model initialization to match the Hugging Face model.
  - Replaces `requests.get` with `hf_hub_download` for more reliable tokenizer downloads.
  - Adds support for both `tokenizer.model` and `tokenizer.json` to handle different Mistral versions.
  - Fixes a `TypeError` in the `save_to_preset` function call.

* address format issues

* adopted to latest hub style

* address format issues

---------

Co-authored-by: laxmareddyp <laxmareddyp@laxma-n2-highmem-256gbram.us-central1-f.c.gtech-rmi-dev.internal>
1 parent c7fa2c9 commit 25c9062
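
For reference, the tokenizer change in this commit amounts to loading the tokenizer straight from the Hugging Face Hub through KerasHub's preset mechanism instead of downloading `tokenizer.model` by hand. A minimal sketch of the new pattern, assuming `keras_hub` is installed and the Hub repo is reachable (the repo id below is just one of the `PRESET_MAP` entries):

    from keras_hub.models import MistralTokenizer

    # Illustrative repo id; any value from the script's PRESET_MAP works the same way.
    hf_preset = "mistralai/Mistral-7B-v0.3"

    # The hf:// scheme tells KerasHub to resolve the tokenizer assets directly from
    # the Hugging Face Hub, replacing the manual requests.get download of
    # tokenizer.model used by the old script.
    tokenizer = MistralTokenizer.from_preset(f"hf://{hf_preset}")

    # Quick sanity check: tokenize a prompt and print the token ids.
    print(tokenizer("What is Keras?"))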

File tree

1 file changed: +21 -26 lines changed

tools/checkpoint_conversion/convert_mistral_checkpoints.py

Lines changed: 21 additions & 26 deletions
@@ -5,22 +5,23 @@
 import traceback
 
 import numpy as np
-import requests
 from absl import app
 from absl import flags
 from keras import ops
 from transformers import AutoTokenizer
 from transformers import MistralForCausalLM
 
 from keras_hub.models import MistralBackbone
+from keras_hub.models import MistralCausalLM
 from keras_hub.models import MistralCausalLMPreprocessor
 from keras_hub.models import MistralTokenizer
-from keras_hub.utils.preset_utils import save_to_preset
 
 PRESET_MAP = {
     "mistral_7b_en": "mistralai/Mistral-7B-v0.1",
+    "mistral_0.3_7b_en": "mistralai/Mistral-7B-v0.3",
     "mistral_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.1",
     "mistral_0.2_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.2",
+    "mistral_0.3_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.3",
 }
 
 FLAGS = flags.FLAGS
@@ -236,49 +237,43 @@ def main(_):
             rope_max_wavelength=hf_model.config.rope_theta,
             dtype="float32",
         )
-        keras_hub_model = MistralBackbone(**backbone_kwargs)
+        keras_hub_backbone = MistralBackbone(**backbone_kwargs)
 
-        # === Download the tokenizer from Huggingface model card ===
-        spm_path = (
-            f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
-        )
-        response = requests.get(spm_path)
-        if not response.ok:
-            raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
-        tokenizer_path = os.path.join(temp_dir, "vocabulary.spm")
-        with open(tokenizer_path, "wb") as tokenizer_file:
-            tokenizer_file.write(response.content)
-        keras_hub_tokenizer = MistralTokenizer(tokenizer_path)
+        keras_hub_tokenizer = MistralTokenizer.from_preset(f"hf://{hf_preset}")
         print("\n-> Keras 3 model and tokenizer loaded.")
 
         # === Port the weights ===
-        convert_checkpoints(keras_hub_model, hf_model)
+        convert_checkpoints(keras_hub_backbone, hf_model)
         print("\n-> Weight transfer done.")
 
         # === Check that the models and tokenizers outputs match ===
         test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
-        test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer)
+        test_model(
+            keras_hub_backbone, keras_hub_tokenizer, hf_model, hf_tokenizer
+        )
         print("\n-> Tests passed!")
 
         # === Save the model weights in float32 format ===
-        keras_hub_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
+        keras_hub_backbone.save_weights(
+            os.path.join(temp_dir, "model.weights.h5")
+        )
         print("\n-> Saved the model weights in float32")
 
-        del keras_hub_model, hf_model
+        del keras_hub_backbone, hf_model
         gc.collect()
 
         # === Save the weights again in float16 ===
         backbone_kwargs["dtype"] = "float16"
-        keras_hub_model = MistralBackbone(**backbone_kwargs)
-        keras_hub_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
-        save_to_preset(keras_hub_model, preset)
+        keras_hub_backbone = MistralBackbone(**backbone_kwargs)
+        keras_hub_backbone.load_weights(
+            os.path.join(temp_dir, "model.weights.h5")
+        )
+
+        preprocessor = MistralCausalLMPreprocessor(keras_hub_tokenizer)
+        keras_hub_model = MistralCausalLM(keras_hub_backbone, preprocessor)
+        keras_hub_model.save_to_preset(f"./{preset}")
         print("\n-> Saved the model preset in float16")
 
-        # === Save the tokenizer ===
-        save_to_preset(
-            keras_hub_tokenizer, preset, config_filename="tokenizer.json"
-        )
-        print("\n-> Saved the tokenizer")
     finally:
         shutil.rmtree(temp_dir)
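
As a usage note, once the script has run, the float16 preset written to `./{preset}` can be reloaded for generation. A rough sketch, where the `--preset` flag name is assumed from the script's `FLAGS`/`PRESET_MAP` usage and the preset directory is assumed to sit in the current working directory:

    # Run the conversion (flag name assumed from the script's FLAGS definition):
    #   python tools/checkpoint_conversion/convert_mistral_checkpoints.py \
    #       --preset mistral_0.3_7b_en
    #
    # Then reload the locally saved preset directory:
    from keras_hub.models import MistralCausalLM

    causal_lm = MistralCausalLM.from_preset("./mistral_0.3_7b_en")
    print(causal_lm.generate("Tell me a short joke.", max_length=64))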
