Commit 70c57cc

Add 7B presets for Mistral (keras-team#1436)

* Refactor the checkpoints script
* Add the 7B preset for Mistral
* Update the preset version [skip ci]
* Fix the bug in Mistral preprocessor
* Fix merge artifacts
* Fix the tokenizer test [skip ci]
* Mark smallest preset test as extra_large for now [skip ci]

1 parent b8045f9 commit 70c57cc

11 files changed: +432 −369 lines

keras_nlp/models/mistral/mistral_backbone.py

Lines changed: 8 additions & 0 deletions
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
+
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
@@ -19,9 +21,11 @@
 from keras_nlp.models.mistral.mistral_layer_norm import (
     MistralLayerNormalization,
 )
+from keras_nlp.models.mistral.mistral_presets import backbone_presets
 from keras_nlp.models.mistral.mistral_transformer_decoder import (
     MistralTransformerDecoder,
 )
+from keras_nlp.utils.python_utils import classproperty


 def _mistral_kernel_initializer(stddev=0.02):
@@ -196,3 +200,7 @@ def get_config(self):
             }
         )
         return config
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
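Note: with the `presets` classproperty in place, presets can be listed and loaded by name. A minimal usage sketch (assumes the Kaggle weights referenced by `kaggle_handle` are reachable):

from keras_nlp.models import MistralBackbone

# The two presets registered in mistral_presets.py below:
print(list(MistralBackbone.presets))  # ['mistral_7b_en', 'mistral_instruct_7b_en']
# Downloads roughly 7B parameters of weights on first use.
backbone = MistralBackbone.from_preset("mistral_7b_en")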

keras_nlp/models/mistral/mistral_backbone_test.py

Lines changed: 26 additions & 0 deletions
@@ -54,3 +54,29 @@ def test_num_parameters(self):
         model = MistralBackbone(**self.init_kwargs)
         # Reference value calculated using the PyTorch model
         self.assertEqual(model.count_params(), 2704)
+
+    @pytest.mark.extra_large
+    def test_smallest_preset(self):
+        self.run_preset_test(
+            cls=MistralBackbone,
+            preset="mistral_7b_en",
+            input_data={
+                "token_ids": ops.array([[1, 1824, 349, 524, 11234, 28804]]),
+                "padding_mask": ops.ones((1, 6), dtype="int32"),
+            },
+            expected_output_shape=(1, 6, 4096),
+            # The forward pass from a preset should be stable!
+            # Reference values computed using PyTorch HF model.
+            expected_partial_output=ops.array(
+                [-1.6875, 0.5117, -1.7188, 2.3125, -0.0996]
+            ),
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in MistralBackbone.presets:
+            self.run_preset_test(
+                cls=MistralBackbone,
+                preset=preset,
+                input_data=self.input_data,
+            )

keras_nlp/models/mistral/mistral_causal_lm.py

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy

 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
@@ -20,6 +21,7 @@
 from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
     MistralCausalLMPreprocessor,
 )
+from keras_nlp.models.mistral.mistral_presets import backbone_presets
 from keras_nlp.utils.python_utils import classproperty


@@ -211,3 +213,7 @@ def next(prompt, cache, index):
             "token_ids": token_ids,
             "padding_mask": padding_mask,
         }
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
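Note: the task class reuses the same preset metadata as the backbone, so end-to-end generation is a two-liner once loaded. A hedged sketch (weights download on first call; the prompt and max_length are illustrative):

from keras_nlp.models import MistralCausalLM

causal_lm = MistralCausalLM.from_preset("mistral_7b_en")
causal_lm.generate("What is Keras?", max_length=64)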

keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py

Lines changed: 14 additions & 0 deletions
@@ -131,6 +131,20 @@ def generate_preprocess(
         x,
         sequence_length=None,
     ):
+        """Convert strings to integer token input for generation.
+
+        Similar to calling the layer for training, this method takes in strings
+        or tensor strings, tokenizes and packs the input, and computes a padding
+        mask masking all inputs not filled in with a padded value.
+
+        Unlike calling the layer for training, this method does not compute
+        labels and will never append a `tokenizer.end_token_id` to the end of
+        the sequence (as generation is expected to continue at the end of the
+        inputted prompt).
+        """
+        if not self.built:
+            self.build(None)
+
         x = convert_inputs_to_list_of_tensor_segments(x)[0]
         x = self.tokenizer(x)
         token_ids, padding_mask = self.packer(
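Note: to make the new docstring concrete, a small sketch of the generation-time path (the behavior described in the comments follows the method's contract; exact token values depend on the preset vocabulary):

preprocessor = MistralCausalLMPreprocessor.from_preset("mistral_7b_en")
x = preprocessor.generate_preprocess(["the quick brown fox"])
# x["token_ids"] starts with tokenizer.start_token_id, is padded out to the
# preprocessor's sequence_length, and has no end token appended, so
# generation can continue from the end of the prompt.
print(x["token_ids"], x["padding_mask"])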

keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py

Lines changed: 11 additions & 0 deletions
@@ -14,6 +14,8 @@

 import os

+import pytest
+
 from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
     MistralCausalLMPreprocessor,
 )
@@ -79,3 +81,12 @@ def test_generate_postprocess(self):
         preprocessor = MistralCausalLMPreprocessor(**self.init_kwargs)
         x = preprocessor.generate_postprocess(input_data)
         self.assertAllEqual(x, "the quick brown fox")
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in MistralCausalLMPreprocessor.presets:
+            self.run_preset_test(
+                cls=MistralCausalLMPreprocessor,
+                preset=preset,
+                input_data=self.input_data,
+            )

keras_nlp/models/mistral/mistral_preprocessor.py

Lines changed: 16 additions & 4 deletions
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy

 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.models.mistral.mistral_presets import backbone_presets
 from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
 from keras_nlp.models.preprocessor import Preprocessor
 from keras_nlp.utils.keras_utils import (
@@ -121,15 +123,21 @@ def __init__(
     ):
         super().__init__(**kwargs)
         self.tokenizer = tokenizer
+        self.packer = None
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+        self.sequence_length = sequence_length
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
         self.packer = StartEndPacker(
             start_value=self.tokenizer.start_token_id,
             end_value=self.tokenizer.end_token_id,
-            sequence_length=sequence_length,
+            sequence_length=self.sequence_length,
             return_padding_mask=True,
         )
-        self.add_start_token = add_start_token
-        self.add_end_token = add_end_token
-        self.sequence_length = sequence_length
+        self.built = True

     def get_config(self):
         config = super().get_config()
@@ -184,3 +192,7 @@ def sequence_length(self, value):
     @classproperty
     def tokenizer_cls(cls):
         return MistralTokenizer
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
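Note: the refactor above moves all asset-dependent state out of `__init__` and into `build()`. A minimal, hypothetical illustration of the pattern (the `LazyPacker` name and logic are not part of the library, only a sketch of the idea):

class LazyPacker:
    """Hypothetical sketch: defer asset-dependent state to build()."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.start_value = None  # not read yet; tokenizer assets may be unloaded
        self.built = False

    def build(self, input_shape):
        # When restoring a saved model, build() runs after assets are loaded,
        # so reading start_token_id here is safe; reading it in __init__ is not.
        self.start_value = self.tokenizer.start_token_id
        self.built = True

    def __call__(self, x):
        if not self.built:
            self.build(None)
        return [self.start_value, *x]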

keras_nlp/models/mistral/mistral_preprocessor_test.py

Lines changed: 11 additions & 0 deletions
@@ -14,6 +14,8 @@

 import os

+import pytest
+
 from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor
 from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
 from keras_nlp.tests.test_case import TestCase
@@ -57,3 +59,12 @@ def test_errors_for_2d_list_input(self):
         ambiguous_input = [["one", "two"], ["three", "four"]]
         with self.assertRaises(ValueError):
             preprocessor(ambiguous_input)
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in MistralPreprocessor.presets:
+            self.run_preset_test(
+                cls=MistralPreprocessor,
+                preset=preset,
+                input_data=self.input_data,
+            )
keras_nlp/models/mistral/mistral_presets.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Mistral model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "mistral_7b_en": {
+        "metadata": {
+            "description": "Mistral 7B base model",
+            "params": 7241732096,
+            "official_name": "Mistral",
+            "path": "mistral",
+            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/3",
+    },
+    "mistral_instruct_7b_en": {
+        "metadata": {
+            "description": "Mistral 7B instruct model",
+            "params": 7241732096,
+            "official_name": "Mistral",
+            "path": "mistral",
+            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/3",
+    },
+}
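Note: since `backbone_presets` is a plain dict, each Mistral class exposes it unchanged through its `presets` classproperty, and `from_preset` resolves weights via the `kaggle_handle`. For example:

from keras_nlp.models.mistral.mistral_presets import backbone_presets

for name, config in backbone_presets.items():
    print(name, config["kaggle_handle"])
# mistral_7b_en kaggle://keras/mistral/keras/mistral_7b_en/3
# mistral_instruct_7b_en kaggle://keras/mistral/keras/mistral_instruct_7b_en/3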

keras_nlp/models/mistral/mistral_tokenizer.py

Lines changed: 8 additions & 0 deletions
@@ -11,8 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
+
 from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.models.mistral.mistral_presets import backbone_presets
 from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
+from keras_nlp.utils.python_utils import classproperty


 @keras_nlp_export("keras_nlp.models.MistralTokenizer")
@@ -77,3 +81,7 @@ def set_proto(self, proto):
         else:
             self.start_token_id = None
             self.end_token_id = None
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
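Note: with presets wired into the tokenizer, the `test_smallest_preset` expectation in the test file below can be reproduced directly. A sketch (requires downloading the preset's SentencePiece assets):

from keras_nlp.models import MistralTokenizer

tokenizer = MistralTokenizer.from_preset("mistral_7b_en")
tokenizer(["The quick brown fox."])  # [[415, 2936, 9060, 285, 1142, 28723]]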

keras_nlp/models/mistral/mistral_tokenizer_test.py

Lines changed: 20 additions & 0 deletions
@@ -14,6 +14,8 @@

 import os

+import pytest
+
 from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
 from keras_nlp.tests.test_case import TestCase

@@ -44,3 +46,21 @@ def test_errors_missing_special_tokens(self):
                 self.get_test_data_dir(), "no_special_token_vocab.spm"
             )
         )
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        self.run_preset_test(
+            cls=MistralTokenizer,
+            preset="mistral_7b_en",
+            input_data=["The quick brown fox."],
+            expected_output=[[415, 2936, 9060, 285, 1142, 28723]],
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in MistralTokenizer.presets:
+            self.run_preset_test(
+                cls=MistralTokenizer,
+                preset=preset,
+                input_data=self.input_data,
+            )
