
Commit cdb5f76 (merge, 2 parents: 05b1f0e + 652e525)

Merge branch 'stablelm' of github.com:Bond099/keras-hub into stablelm

File tree: 8 files changed (+51, -37 lines)

keras_hub/src/layers/modeling/transformer_encoder.py

Lines changed: 6 additions & 3 deletions

@@ -16,9 +16,12 @@ class TransformerEncoder(keras.layers.Layer):
     paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
     can instantiate multiple instances of this class to stack up an encoder.

-    This layer will correctly compute an attention mask from an implicit
-    Keras padding mask (for example, by passing `mask_zero=True` to a
-    `keras.layers.Embedding` layer). See the Masking and Padding
+    This layer will compute an attention mask, prioritizing explicitly provided
+    masks (a `padding_mask` or a custom `attention_mask`) over an implicit Keras
+    padding mask (for example, by passing `mask_zero=True` to a
+    `keras.layers.Embedding` layer). If both a `padding_mask` and an
+    `attention_mask` are provided, they will be combined to determine the final
+    mask. See the Masking and Padding
     [guide](https://keras.io/guides/understanding_masking_and_padding/)
     for more details.

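A minimal sketch of the masking behavior described in the updated docstring (shapes and values are illustrative, not taken from the commit):

import numpy as np
import keras
from keras_hub.layers import TransformerEncoder

token_ids = np.array([[5, 8, 2, 0, 0]])  # trailing zeros are padding
embedding = keras.layers.Embedding(input_dim=100, output_dim=16, mask_zero=True)
x = embedding(token_ids)  # carries an implicit Keras padding mask

encoder = TransformerEncoder(intermediate_dim=32, num_heads=2)

# Implicit mask only: padded positions are ignored automatically.
y_implicit = encoder(x)

# An explicit `padding_mask` takes priority over the implicit one.
y_explicit = encoder(x, padding_mask=np.array([[1, 1, 1, 1, 0]]))
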
keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ def call(self, image_embeddings, text_embeddings, vision_indices):
         to_add = ops.multiply(
             keras.ops.arange(batch_size, dtype="int32"), seq_length
         )
-        to_add = ops.expand_dims(to_add, axis=-1)
+        to_add = ops.cast(ops.expand_dims(to_add, axis=-1), "int32")
         vision_indices = ops.add(vision_indices, to_add)

         # indices should be of shape `(num_updates, 1)`. `num_updates` is

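The added `ops.cast` presumably guards against a dtype mismatch when the per-batch offset is added to the integer `vision_indices`. A minimal sketch of the pattern (shapes and values are illustrative):

from keras import ops

batch_size, seq_length = 2, 4
vision_indices = ops.convert_to_tensor([[1, 2], [0, 3]], dtype="int32")

# Per-batch offsets, kept in int32 so the add below stays in int32
# regardless of a backend's default integer dtype.
to_add = ops.multiply(ops.arange(batch_size, dtype="int32"), seq_length)
to_add = ops.cast(ops.expand_dims(to_add, axis=-1), "int32")
flat_indices = ops.add(vision_indices, to_add)  # shape (2, 2), dtype int32
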
keras_hub/src/models/mistral/mistral_presets.py

Lines changed: 17 additions & 1 deletion

@@ -10,6 +10,14 @@
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/8",
     },
+    "mistral_0.3_7b_en": {
+        "metadata": {
+            "description": "Mistral 7B base version 0.3 model",
+            "params": 7248023552,
+            "path": "mistral",
+        },
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.3_7b_en/1",
+    },
     "mistral_instruct_7b_en": {
         "metadata": {
             "description": "Mistral 7B instruct model",

@@ -20,10 +28,18 @@
     },
     "mistral_0.2_instruct_7b_en": {
         "metadata": {
-            "description": "Mistral 7B instruct Version 0.2 model",
+            "description": "Mistral 7B instruct version 0.2 model",
             "params": 7241732096,
             "path": "mistral",
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/3",
     },
+    "mistral_0.3_instruct_7b_en": {
+        "metadata": {
+            "description": "Mistral 7B instruct version 0.3 model",
+            "params": 7248023552,
+            "path": "mistral",
+        },
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.3_instruct_7b_en/1",
+    },
 }

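Once these presets are published on Kaggle, they should be loadable by name; a hypothetical usage sketch (the prompt and generation length are illustrative):

import keras_hub

causal_lm = keras_hub.models.MistralCausalLM.from_preset(
    "mistral_0.3_instruct_7b_en"
)
print(causal_lm.generate("What is Keras?", max_length=64))
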
keras_hub/src/models/mixtral/mixtral_presets.py

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
             "params": 46702792704,
             "path": "mixtral",
         },
-        "kaggle_handle": "kaggle://keras/mixtral/keras/mixtral_8_7b_en/3",
+        "kaggle_handle": "kaggle://keras/mixtral/keras/mixtral_8_7b_en/4",
     },
     "mixtral_8_instruct_7b_en": {
         "metadata": {

@@ -21,6 +21,6 @@
             "params": 46702792704,
             "path": "mixtral",
         },
-        "kaggle_handle": "kaggle://keras/mixtral/keras/mixtral_8_instruct_7b_en/3",
+        "kaggle_handle": "kaggle://keras/mixtral/keras/mixtral_8_instruct_7b_en/4",
     },
 }

keras_hub/src/models/qwen3/qwen3_attention.py

Lines changed: 1 addition & 1 deletion

@@ -299,7 +299,7 @@ def _compute_attention(
             attention_scores,
             ops.cast(self._inv_norm_factor, self.compute_dtype),
         )
-        if not self.sliding_window_size:
+        if self.sliding_window_size:
            attention_mask = self._mask_sliding_window(
                attention_mask,
                cache_update_index=cache_update_index

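The flipped condition applies the sliding-window mask only when a window size is actually configured; a minimal sketch of the corrected guard pattern (names simplified, not the layer's real internals):

def maybe_apply_sliding_window(attention_mask, sliding_window_size, mask_fn):
    # Narrow the mask only when a sliding window is configured; a falsy
    # size (None or 0) leaves the full attention mask untouched.
    if sliding_window_size:
        attention_mask = mask_fn(attention_mask, sliding_window_size)
    return attention_mask
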
keras_hub/src/models/qwen_moe/qwen_moe_presets.py

Lines changed: 1 addition & 1 deletion

@@ -10,6 +10,6 @@
             "params": 14315784192,
             "path": "qwen-1.5-moe",
         },
-        "kaggle_handle": "kaggle://keras/qwen-1.5-moe/Keras/qwen1.5_moe_2.7b_en/3",
+        "kaggle_handle": "kaggle://keras/qwen-1.5-moe/Keras/qwen1.5_moe_2.7b_en/4",
     },
 }

requirements-torch-cuda.txt

Lines changed: 2 additions & 2 deletions

@@ -4,8 +4,8 @@ tensorflow-text~=2.18

 # Torch with cuda support.
 --extra-index-url https://download.pytorch.org/whl/cu126
-torch==2.6.0+cu126
-torchvision==0.21.0+cu126
+torch==2.7.0+cu126
+torchvision==0.22.0+cu126

 # Jax cpu-only version.
 jax[cpu]

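A quick sanity check that the upgraded wheels are what actually got installed (assumes a machine with a working CUDA 12.6 setup):

import torch

print(torch.__version__)          # expected to report 2.7.0+cu126
print(torch.cuda.is_available())  # True if the CUDA build is usable
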
tools/checkpoint_conversion/convert_mistral_checkpoints.py

Lines changed: 21 additions & 26 deletions

@@ -5,22 +5,23 @@
 import traceback

 import numpy as np
-import requests
 from absl import app
 from absl import flags
 from keras import ops
 from transformers import AutoTokenizer
 from transformers import MistralForCausalLM

 from keras_hub.models import MistralBackbone
+from keras_hub.models import MistralCausalLM
 from keras_hub.models import MistralCausalLMPreprocessor
 from keras_hub.models import MistralTokenizer
-from keras_hub.utils.preset_utils import save_to_preset

 PRESET_MAP = {
     "mistral_7b_en": "mistralai/Mistral-7B-v0.1",
+    "mistral_0.3_7b_en": "mistralai/Mistral-7B-v0.3",
     "mistral_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.1",
     "mistral_0.2_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.2",
+    "mistral_0.3_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.3",
 }

 FLAGS = flags.FLAGS

@@ -236,49 +237,43 @@ def main(_):
             rope_max_wavelength=hf_model.config.rope_theta,
             dtype="float32",
         )
-        keras_hub_model = MistralBackbone(**backbone_kwargs)
+        keras_hub_backbone = MistralBackbone(**backbone_kwargs)

-        # === Download the tokenizer from Huggingface model card ===
-        spm_path = (
-            f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
-        )
-        response = requests.get(spm_path)
-        if not response.ok:
-            raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
-        tokenizer_path = os.path.join(temp_dir, "vocabulary.spm")
-        with open(tokenizer_path, "wb") as tokenizer_file:
-            tokenizer_file.write(response.content)
-        keras_hub_tokenizer = MistralTokenizer(tokenizer_path)
+        keras_hub_tokenizer = MistralTokenizer.from_preset(f"hf://{hf_preset}")
         print("\n-> Keras 3 model and tokenizer loaded.")

         # === Port the weights ===
-        convert_checkpoints(keras_hub_model, hf_model)
+        convert_checkpoints(keras_hub_backbone, hf_model)
         print("\n-> Weight transfer done.")

         # === Check that the models and tokenizers outputs match ===
         test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
-        test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer)
+        test_model(
+            keras_hub_backbone, keras_hub_tokenizer, hf_model, hf_tokenizer
+        )
         print("\n-> Tests passed!")

         # === Save the model weights in float32 format ===
-        keras_hub_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
+        keras_hub_backbone.save_weights(
+            os.path.join(temp_dir, "model.weights.h5")
+        )
         print("\n-> Saved the model weights in float32")

-        del keras_hub_model, hf_model
+        del keras_hub_backbone, hf_model
         gc.collect()

         # === Save the weights again in float16 ===
         backbone_kwargs["dtype"] = "float16"
-        keras_hub_model = MistralBackbone(**backbone_kwargs)
-        keras_hub_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
-        save_to_preset(keras_hub_model, preset)
+        keras_hub_backbone = MistralBackbone(**backbone_kwargs)
+        keras_hub_backbone.load_weights(
+            os.path.join(temp_dir, "model.weights.h5")
+        )
+
+        preprocessor = MistralCausalLMPreprocessor(keras_hub_tokenizer)
+        keras_hub_model = MistralCausalLM(keras_hub_backbone, preprocessor)
+        keras_hub_model.save_to_preset(f"./{preset}")
         print("\n-> Saved the model preset in float16")

-        # === Save the tokenizer ===
-        save_to_preset(
-            keras_hub_tokenizer, preset, config_filename="tokenizer.json"
-        )
-        print("\n-> Saved the tokenizer")
     finally:
         shutil.rmtree(temp_dir)

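The rewritten script loads the tokenizer through KerasHub's `hf://` preset handles rather than downloading `tokenizer.model` manually; a minimal sketch of that pattern (the repo id is illustrative):

from keras_hub.models import MistralTokenizer

# Pulls tokenizer assets directly from the Hugging Face repo, replacing the
# earlier manual requests.get() download of tokenizer.model.
tokenizer = MistralTokenizer.from_preset("hf://mistralai/Mistral-7B-v0.3")
token_ids = tokenizer("The quick brown fox")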