
Commit 3bfa89f

Support Loading Quantized Models with from_preset() (#2367)
This change resolves an issue with loading quantized models from presets. Previously, the model's serialized `DTypePolicyMap` was not correctly passed to the backbone during loading, which caused failures during initialization of quantized layers.

The fix introduces a new `_resolve_dtype` utility function that determines the correct `dtype` for the model based on the following rules:

1. User-specified `dtype`: If a user explicitly provides a `dtype` in the `from_preset` call (e.g., `from_preset("bert_tiny_en_uncased", num_classes=2, dtype="float32")`), that value is used.
2. Float type casting: If no user `dtype` is provided and the saved `dtype` is a floating-point type (e.g., "float32"), the model will be loaded using the current Keras default `dtype` policy. This allows for safe casting between different floating-point precisions.
3. `DTypePolicyMap`: If no user `dtype` is provided and the saved `dtype` is a complex object (like a `DTypePolicyMap` for quantization), the saved type is used as is. This ensures that quantization configurations are preserved during loading.
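For illustration, a minimal sketch of how the three rules surface in user code, assuming the `keras_hub` task API exercised in the tests below (the local quantized preset path is hypothetical):

    import keras
    import keras_hub

    # Rule 1: an explicit dtype always wins.
    task = keras_hub.models.TextClassifier.from_preset(
        "bert_tiny_en_uncased", num_classes=2, dtype="float32"
    )

    # Rule 2: no dtype given and the preset stores a plain float policy,
    # so weights are cast to the current Keras default dtype policy.
    keras.config.set_dtype_policy("bfloat16")
    task = keras_hub.models.TextClassifier.from_preset(
        "bert_tiny_en_uncased", num_classes=2
    )

    # Rule 3: no dtype given and the preset stores a DTypePolicyMap (for
    # example, an int8-quantized model saved with save_to_preset()), so the
    # quantized per-layer policies are restored verbatim.
    task = keras_hub.models.TextClassifier.from_preset(
        "./bert_int8_preset", num_classes=2  # hypothetical local preset dir
    )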
1 parent ec906a3 commit 3bfa89f

File tree

keras_hub/src/models/backbone.py
keras_hub/src/models/task_test.py
keras_hub/src/utils/preset_utils.py
keras_hub/src/utils/tensor_utils.py
keras_hub/src/utils/tensor_utils_test.py

5 files changed: +201 -16 lines changed


keras_hub/src/models/backbone.py

Lines changed: 10 additions & 15 deletions
@@ -91,21 +91,16 @@ def get_config(self):
         }
 
         # Add quantization support by utilizing `DTypePolicyMap`
-        try:
-            if isinstance(
-                self.dtype_policy, keras.dtype_policies.DTypePolicyMap
-            ):
-                config.update({"dtype": self.dtype_policy})
-            else:
-                policy_map = keras.dtype_policies.DTypePolicyMap()
-                for layer in self._flatten_layers():
-                    if layer.quantization_mode is not None:
-                        policy_map[layer.path] = layer.dtype_policy
-                if len(policy_map) > 0:
-                    config.update({"dtype": policy_map})
-        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
-        except AttributeError:
-            pass
+        dtype = self.dtype_policy
+        if not isinstance(dtype, keras.dtype_policies.DTypePolicyMap):
+            policy_map = keras.dtype_policies.DTypePolicyMap()
+            for layer in self._flatten_layers():
+                if layer.quantization_mode is not None:
+                    policy_map[layer.path] = layer.dtype_policy
+            if len(policy_map) > 0:
+                dtype = policy_map
+
+        config.update({"dtype": keras.dtype_policies.serialize(dtype)})
         return config
 
     @classmethod
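For context, a hedged sketch of what the new serialization path stores: per-layer quantized policies are gathered into a `DTypePolicyMap` and written out with `keras.dtype_policies.serialize` (assumes Keras >= 3.2; the layer path and policy name here are illustrative):

    import keras

    # Mimic what get_config() does when a layer was quantized: record its
    # policy under the layer's path.
    policy_map = keras.dtype_policies.DTypePolicyMap()
    policy_map["token_embedding"] = keras.dtype_policies.get("int8_from_float32")

    # serialize() produces a JSON-friendly dict that ends up under the
    # "dtype" key of the saved config and can be round-tripped.
    serialized = keras.dtype_policies.serialize(policy_map)
    restored = keras.dtype_policies.deserialize(serialized)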

keras_hub/src/models/task_test.py

Lines changed: 95 additions & 0 deletions
@@ -4,6 +4,7 @@
 import keras
 import numpy as np
 import pytest
+from absl.testing import parameterized
 
 from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
 from keras_hub.src.models.causal_lm import CausalLM
@@ -107,6 +108,100 @@ def test_summary_without_preprocessor(self):
         model.summary(print_fn=lambda x, line_break=False: summary.append(x))
         self.assertNotRegex("\n".join(summary), "Preprocessor:")
 
+    @pytest.mark.large
+    @parameterized.named_parameters(
+        {
+            "testcase_name": "load_with_quantized_weights",
+            "load_weights": True,
+            "dtype_override": None,
+            "expected_dtype": "int8",
+        },
+        {
+            "testcase_name": "override_dtype_without_loading_weights",
+            "load_weights": False,
+            "dtype_override": "float32",
+            "expected_dtype": "float32",
+        },
+    )
+    def test_quantized_preset_loading_and_saving(
+        self, load_weights, dtype_override, expected_dtype
+    ):
+        # Create, quantize, and save the model preset.
+        save_dir = self.get_temp_dir()
+        task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
+        task.quantize(mode="int8")
+        task.save_to_preset(save_dir)
+
+        # Verify that all necessary files were created.
+        path = pathlib.Path(save_dir)
+        self.assertTrue(os.path.exists(path / CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
+        self.assertTrue(os.path.exists(path / METADATA_FILE))
+        self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))
+
+        # Verify the contents of the task config file.
+        task_config = load_json(save_dir, TASK_CONFIG_FILE)
+        self.assertNotIn("build_config", task_config)
+        self.assertNotIn("compile_config", task_config)
+        self.assertIn("backbone", task_config["config"])
+        self.assertIn("preprocessor", task_config["config"])
+        self.assertEqual(BertTextClassifier, check_config_class(task_config))
+
+        # Restore the task from the preset using parameterized arguments.
+        restored_task = TextClassifier.from_preset(
+            save_dir,
+            num_classes=2,
+            load_weights=load_weights,
+            dtype=dtype_override,
+        )
+
+        # Check that the layers have the expected data type.
+        for layer in restored_task._flatten_layers():
+            if isinstance(layer, keras.layers.Dense) and layer.name != "logits":
+                self.assertEqual(
+                    layer.kernel.dtype,
+                    expected_dtype,
+                    f"Layer '{layer.name}' kernel "
+                    f"should have dtype '{expected_dtype}'",
+                )
+
+        # Ensure inference runs without errors.
+        data = ["the quick brown fox.", "the slow brown fox."]
+        _ = restored_task.predict(data)
+
+    @pytest.mark.large
+    def test_load_quantized_preset_with_dtype_override(self):
+        save_dir = self.get_temp_dir()
+        task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
+        task.quantize(mode="int8")
+        task.save_to_preset(save_dir)
+
+        # Check existence of files.
+        path = pathlib.Path(save_dir)
+        self.assertTrue(os.path.exists(path / CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
+        self.assertTrue(os.path.exists(path / METADATA_FILE))
+        self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))
+
+        # Check the task config (`task.json`).
+        task_config = load_json(save_dir, TASK_CONFIG_FILE)
+        self.assertTrue("build_config" not in task_config)
+        self.assertTrue("compile_config" not in task_config)
+        self.assertTrue("backbone" in task_config["config"])
+        self.assertTrue("preprocessor" in task_config["config"])
+
+        # Check the preset directory task class.
+        self.assertEqual(BertTextClassifier, check_config_class(task_config))
+
+        # Loading the model in full precision should cause an error during
+        # initialization. The serialized quantized layers include additional
+        # quantization-specific weights (kernel_scale, etc.) which the
+        # full-precision layer is not aware of and cannot handle.
+        with self.assertRaises(ValueError):
+            TextClassifier.from_preset(save_dir, num_classes=2, dtype="float32")
+
     @pytest.mark.large
     def test_save_to_preset(self):
         save_dir = self.get_temp_dir()
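The `ValueError` expected in the second test comes from how int8 quantization changes a layer's variable set; a minimal standalone sketch of that effect, assuming Keras 3's built-in int8 quantization (the model here is illustrative, not a keras-hub preset):

    import keras

    inputs = keras.Input((8,))
    outputs = keras.layers.Dense(4, name="dense")(inputs)
    model = keras.Model(inputs, outputs)
    model.quantize("int8")

    # After quantization, the Dense layer tracks an extra `kernel_scale`
    # variable next to its int8 kernel, so its saved weights no longer line
    # up with what a plain float32 Dense layer expects at load time.
    print([v.path for v in model.weights])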

keras_hub/src/utils/preset_utils.py

Lines changed: 49 additions & 0 deletions
@@ -10,6 +10,7 @@
 from absl import logging
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.utils import tensor_utils
 from keras_hub.src.utils.keras_utils import print_msg
 from keras_hub.src.utils.keras_utils import sharded_weights_available
 from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
@@ -687,6 +688,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             )
         # We found a `task.json` with a complete config for our class.
         # Forward backbone args.
+        kwargs["dtype"] = self._resolve_dtype(self.config, kwargs)
         backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
         if "backbone" in task_config["config"]:
             backbone_config = task_config["config"]["backbone"]["config"]
@@ -708,6 +710,53 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             self._load_backbone_weights(task.backbone)
         return task
 
+    def _resolve_dtype(self, config, kwargs):
+        """Resolves the model's dtype based on the provided config and kwargs.
+
+        The data type is resolved based on the following priority:
+        1. If a user-specified dtype is passed, use that.
+        2. If no user-specified dtype is passed and the saved dtype is
+        castable to the current Keras default dtype, convert weights on load
+        (float type to float type).
+        3. If no user-specified dtype is passed and the saved dtype is not
+        castable to the current default dtype (quantized dtypes), load the
+        saved types verbatim.
+
+        Args:
+            config: dict. The model configuration.
+            kwargs: dict. Additional keyword arguments, potentially including
+                `dtype`.
+
+        Returns:
+            str, dict, or DTypePolicy. The resolved dtype.
+        """
+        # 1. If a user specified dtype is passed, use that.
+        if "dtype" in kwargs and kwargs["dtype"] is not None:
+            return kwargs["dtype"]
+
+        saved_dtype = config.get("config", {}).get("dtype")
+
+        # If there's no saved dtype, we don't need to do anything.
+        if saved_dtype is None:
+            return None
+
+        # 2. Check whether the saved dtype is a simple float type.
+        policy_name = saved_dtype.get("config", {}).get("name")
+        if policy_name and tensor_utils.is_float_dtype(policy_name):
+            # If the saved dtype is a float, we can safely cast to the default
+            # backend float type.
+            if policy_name != keras.config.dtype_policy().name:
+                logging.info(
+                    f"Converting weights saved as {policy_name} "
+                    "to the current Keras dtype policy "
+                    f"{keras.config.dtype_policy()}"
+                )
+            return keras.config.dtype_policy()
+        else:
+            # 3. Otherwise, the dtype is a complex object (e.g. a
+            # DTypePolicyMap for quantization), and should be used as is.
+            return saved_dtype
+
     def load_preprocessor(
         self, cls, config_file=PREPROCESSOR_CONFIG_FILE, **kwargs
     ):
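For intuition on the branch above, a hedged sketch of the serialized form `_resolve_dtype` inspects for an ordinary float policy; a quantized preset instead stores a serialized `DTypePolicyMap`, whose nested name is not a float dtype, so it falls through to case 3:

    import keras

    # Serialize a plain float policy the way a non-quantized preset stores it.
    policy = keras.dtype_policies.get("bfloat16")
    saved_dtype = keras.dtype_policies.serialize(policy)

    # _resolve_dtype() reads the nested policy name and checks it with
    # is_float_dtype(); "bfloat16" is a float, so loading proceeds under the
    # current keras.config.dtype_policy() (case 2).
    policy_name = saved_dtype.get("config", {}).get("name")
    print(policy_name)  # expected: "bfloat16"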

keras_hub/src/utils/tensor_utils.py

Lines changed: 23 additions & 1 deletion
@@ -310,7 +310,29 @@ def is_tensor_type(x):
 
 
 def is_float_dtype(dtype):
-    return "float" in keras.backend.standardize_dtype(dtype)
+    """
+    Checks if a dtype is a float type by using a regex.
+
+    This function standardizes the input dtype and then uses a regular
+    expression to perform an exact match. It identifies standard floats,
+    bfloats, and mixed-precision float types.
+
+    For example:
+    - `is_float_dtype("float32")` returns `True`.
+    - `is_float_dtype("bfloat16")` returns `True`.
+    - `is_float_dtype("mixed_float16")` returns `True`.
+    - `is_float_dtype("int8")` returns `False`.
+    - `is_float_dtype("int8_from_float32")` returns `False`.
+
+    Args:
+        dtype: str, DTypePolicy. The data type to check.
+
+    Returns:
+        bool: `True` if the dtype is a floating-point type, `False` otherwise.
+    """
+    pattern = re.compile(r"^(mixed_)?(b)?float[0-9]*$")
+    standardized_dtype = keras.backend.standardize_dtype(dtype)
+    return pattern.match(standardized_dtype) is not None
 
 
 def is_int_dtype(dtype):
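A quick sketch of why the anchored regex matters compared to a bare substring test (a simplified illustration that skips the `standardize_dtype` step; the quantized policy name comes from the docstring example above):

    import re

    pattern = re.compile(r"^(mixed_)?(b)?float[0-9]*$")

    # A substring test treats a quantized policy name as float-like.
    print("float" in "int8_from_float32")            # True
    # The anchored regex only accepts genuine float dtypes and policies.
    print(bool(pattern.match("int8_from_float32")))  # False
    print(bool(pattern.match("float32")))            # True
    print(bool(pattern.match("mixed_bfloat16")))     # True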

keras_hub/src/utils/tensor_utils_test.py

Lines changed: 24 additions & 0 deletions
@@ -8,6 +8,7 @@
 from keras_hub.src.utils.tensor_utils import convert_preprocessing_inputs
 from keras_hub.src.utils.tensor_utils import convert_preprocessing_outputs
 from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import is_float_dtype
 from keras_hub.src.utils.tensor_utils import is_tensor_type
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 from keras_hub.src.utils.tensor_utils import target_gather
@@ -304,3 +305,26 @@ def test_target_gather_invalid_rank(self):
         indices = np.array([0, 1], dtype="int32")
         with self.assertRaisesRegex(ValueError, "larger than 3"):
             _ = target_gather(targets, indices)
+
+
+class IsFloatDtypeTest(TestCase):
+    def test_float_dtypes_return_true(self):
+        float_dtypes = [
+            "float16",
+            "float32",
+            "float64",
+            "bfloat16",
+        ]
+        for dtype in float_dtypes:
+            self.assertTrue(is_float_dtype(dtype))
+
+    def test_non_float_dtypes_return_false(self):
+        non_float_dtypes = [
+            "int8",
+            "int32",
+            "uint8",
+            "bool",
+            "string",
+        ]
+        for dtype in non_float_dtypes:
+            self.assertFalse(is_float_dtype(dtype))

0 commit comments
