|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import copy |
14 | 15 |
|
15 | 16 | from keras_nlp.api_export import keras_nlp_export
|
16 | 17 | from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
|
| 18 | +from keras_nlp.models.mistral.mistral_presets import backbone_presets |
17 | 19 | from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
|
18 | 20 | from keras_nlp.models.preprocessor import Preprocessor
|
19 | 21 | from keras_nlp.utils.keras_utils import (
|
@@ -121,15 +123,21 @@ def __init__(
|
121 | 123 | ):
|
122 | 124 | super().__init__(**kwargs)
|
123 | 125 | self.tokenizer = tokenizer
|
| 126 | + self.packer = None |
| 127 | + self.add_start_token = add_start_token |
| 128 | + self.add_end_token = add_end_token |
| 129 | + self.sequence_length = sequence_length |
| 130 | + |
| 131 | + def build(self, input_shape): |
| 132 | + # Defer packer creation to `build()` so that we can be sure tokenizer |
| 133 | + # assets have loaded when restoring a saved model. |
124 | 134 | self.packer = StartEndPacker(
|
125 | 135 | start_value=self.tokenizer.start_token_id,
|
126 | 136 | end_value=self.tokenizer.end_token_id,
|
127 |
| - sequence_length=sequence_length, |
| 137 | + sequence_length=self.sequence_length, |
128 | 138 | return_padding_mask=True,
|
129 | 139 | )
|
130 |
| - self.add_start_token = add_start_token |
131 |
| - self.add_end_token = add_end_token |
132 |
| - self.sequence_length = sequence_length |
| 140 | + self.built = True |
133 | 141 |
|
134 | 142 | def get_config(self):
|
135 | 143 | config = super().get_config()
|
@@ -184,3 +192,7 @@ def sequence_length(self, value):
|
184 | 192 | @classproperty
|
185 | 193 | def tokenizer_cls(cls):
|
186 | 194 | return MistralTokenizer
|
| 195 | + |
| 196 | + @classproperty |
| 197 | + def presets(cls): |
| 198 | + return copy.deepcopy(backbone_presets) |
0 commit comments