|
4 | 4 | import keras
|
5 | 5 | import numpy as np
|
6 | 6 | import pytest
|
| 7 | +from absl.testing import parameterized |
7 | 8 |
|
8 | 9 | from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
|
9 | 10 | from keras_hub.src.models.causal_lm import CausalLM
|
@@ -107,6 +108,100 @@ def test_summary_without_preprocessor(self):
|
107 | 108 | model.summary(print_fn=lambda x, line_break=False: summary.append(x))
|
108 | 109 | self.assertNotRegex("\n".join(summary), "Preprocessor:")
|
109 | 110 |
|
@pytest.mark.large
@parameterized.named_parameters(
    {
        "testcase_name": "load_with_quantized_weights",
        "load_weights": True,
        "dtype_override": None,
        "expected_dtype": "int8",
    },
    {
        "testcase_name": "override_dtype_without_loading_weights",
        "load_weights": False,
        "dtype_override": "float32",
        "expected_dtype": "float32",
    },
)
def test_quantized_preset_loading_and_saving(
    self, load_weights, dtype_override, expected_dtype
):
    """Round-trip an int8-quantized preset and check restored kernel dtypes.

    The test quantizes the `bert_tiny_en_uncased` text classifier with
    `mode="int8"`, saves it with `save_to_preset`, validates the files and
    task config written to disk, and then restores the task with the
    parameterized arguments.

    Args:
        load_weights: Forwarded to `TextClassifier.from_preset` as
            `load_weights`.
        dtype_override: Forwarded to `TextClassifier.from_preset` as
            `dtype`; `None` in the case that loads the quantized weights.
        expected_dtype: The dtype every inspected `Dense` kernel of the
            restored task must report.
    """
    # Create, quantize, and save the model preset.
    save_dir = self.get_temp_dir()
    task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
    task.quantize(mode="int8")
    task.save_to_preset(save_dir)

    # Verify that all necessary files were created. Naming the file in the
    # failure message makes it obvious which artifact is missing.
    path = pathlib.Path(save_dir)
    for file_name in (
        CONFIG_FILE,
        MODEL_WEIGHTS_FILE,
        METADATA_FILE,
        TASK_CONFIG_FILE,
        TASK_WEIGHTS_FILE,
    ):
        self.assertTrue(
            (path / file_name).exists(),
            f"Expected preset file '{file_name}' was not saved.",
        )

    # Verify the contents of the task config file.
    task_config = load_json(save_dir, TASK_CONFIG_FILE)
    self.assertNotIn("build_config", task_config)
    self.assertNotIn("compile_config", task_config)
    self.assertIn("backbone", task_config["config"])
    self.assertIn("preprocessor", task_config["config"])
    self.assertEqual(BertTextClassifier, check_config_class(task_config))

    # Restore the task from the preset using parameterized arguments.
    restored_task = TextClassifier.from_preset(
        save_dir,
        num_classes=2,
        load_weights=load_weights,
        dtype=dtype_override,
    )

    # Collect every `Dense` layer except the one named "logits"; these are
    # the layers whose kernel dtype is checked below.
    dense_layers = [
        layer
        for layer in restored_task._flatten_layers()
        if isinstance(layer, keras.layers.Dense) and layer.name != "logits"
    ]
    # Guard against a vacuous pass: with no matching layers the dtype
    # assertions below would never run and the test would prove nothing.
    self.assertGreater(
        len(dense_layers),
        0,
        "No Dense layers were found to check the kernel dtype of.",
    )

    # Check that the layers have the expected data type.
    for layer in dense_layers:
        self.assertEqual(
            layer.kernel.dtype,
            expected_dtype,
            f"Layer '{layer.name}' kernel "
            f"should have dtype '{expected_dtype}'",
        )

    # Ensure inference runs without errors.
    data = ["the quick brown fox.", "the slow brown fox."]
    _ = restored_task.predict(data)
| 172 | + |
@pytest.mark.large
def test_load_quantized_preset_with_dtype_override(self):
    """Check that a float32 dtype override on a quantized preset raises.

    The test quantizes the `bert_tiny_en_uncased` text classifier with
    `mode="int8"`, saves it with `save_to_preset`, validates the files and
    task config written to disk, and then asserts that restoring the
    preset with `dtype="float32"` raises `ValueError`.
    """
    save_dir = self.get_temp_dir()
    task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
    task.quantize(mode="int8")
    task.save_to_preset(save_dir)

    # Check existence of files. Naming the file in the failure message
    # makes it obvious which artifact is missing.
    path = pathlib.Path(save_dir)
    for file_name in (
        CONFIG_FILE,
        MODEL_WEIGHTS_FILE,
        METADATA_FILE,
        TASK_CONFIG_FILE,
        TASK_WEIGHTS_FILE,
    ):
        self.assertTrue(
            (path / file_name).exists(),
            f"Expected preset file '{file_name}' was not saved.",
        )

    # Check the task config (`task.json`). `assertIn`/`assertNotIn` report
    # the offending key and container on failure, unlike `assertTrue`.
    task_config = load_json(save_dir, TASK_CONFIG_FILE)
    self.assertNotIn("build_config", task_config)
    self.assertNotIn("compile_config", task_config)
    self.assertIn("backbone", task_config["config"])
    self.assertIn("preprocessor", task_config["config"])

    # Check the preset directory task class.
    self.assertEqual(BertTextClassifier, check_config_class(task_config))

    # Loading the model in full-precision should cause an error during
    # initialization. The serialized quantized layers include additional
    # quantization specific weights (kernel_scale, etc.) which the
    # full-precision layer is not aware of and can't handle.
    with self.assertRaises(ValueError):
        TextClassifier.from_preset(save_dir, num_classes=2, dtype="float32")
| 204 | + |
110 | 205 | @pytest.mark.large
|
111 | 206 | def test_save_to_preset(self):
|
112 | 207 | save_dir = self.get_temp_dir()
|
|
0 commit comments