diff --git a/keras/src/legacy/saving/legacy_h5_format.py b/keras/src/legacy/saving/legacy_h5_format.py
index 5b919f80e7c..17cdf314ffb 100644
--- a/keras/src/legacy/saving/legacy_h5_format.py
+++ b/keras/src/legacy/saving/legacy_h5_format.py
@@ -318,12 +318,14 @@ def save_attributes_to_hdf5_group(group, name, data):
         group.attrs[name] = data
 
 
-def load_weights_from_hdf5_group(f, model):
+def load_weights_from_hdf5_group(f, model, skip_mismatch=False):
     """Implements topological (order-based) weight loading.
 
     Args:
         f: A pointer to a HDF5 group.
         model: Model instance.
+        skip_mismatch: Boolean, whether to skip loading of weights
+            where there is a mismatch in the shape of the weights.
 
     Raises:
         ValueError: in case of mismatch between provided layers
@@ -379,6 +381,7 @@ def load_weights_from_hdf5_group(f, model):
             layer,
             symbolic_weights,
             weight_values,
+            skip_mismatch=skip_mismatch,
             name=f"layer #{k} (named {layer.name})",
         )
 
@@ -403,6 +406,7 @@ def load_weights_from_hdf5_group(f, model):
             model,
             symbolic_weights,
             weight_values,
+            skip_mismatch=skip_mismatch,
             name="top-level model",
         )
 
diff --git a/keras/src/saving/saving_api.py b/keras/src/saving/saving_api.py
index 1762a37ed7a..a5b182c4df7 100644
--- a/keras/src/saving/saving_api.py
+++ b/keras/src/saving/saving_api.py
@@ -249,18 +249,35 @@ def save_weights(
 @keras_export("keras.saving.load_weights")
 def load_weights(model, filepath, skip_mismatch=False, **kwargs):
     filepath_str = str(filepath)
+
+    # Get the legacy kwargs.
+    objects_to_skip = kwargs.pop("objects_to_skip", None)
+    by_name = kwargs.pop("by_name", None)
+    if kwargs:
+        raise ValueError(f"Invalid keyword arguments: {kwargs}")
+
     if filepath_str.endswith(".keras"):
-        if kwargs:
-            raise ValueError(f"Invalid keyword arguments: {kwargs}")
+        if objects_to_skip is not None:
+            raise ValueError(
+                "`objects_to_skip` only supports loading '.weights.h5' files. "
+                f"Received: {filepath}"
+            )
+        if by_name is not None:
+            raise ValueError(
+                "`by_name` only supports loading legacy '.h5' or '.hdf5' "
+                f"files. Received: {filepath}"
+            )
         saving_lib.load_weights_only(
             model, filepath, skip_mismatch=skip_mismatch
         )
     elif filepath_str.endswith(".weights.h5") or filepath_str.endswith(
         ".weights.json"
     ):
-        objects_to_skip = kwargs.pop("objects_to_skip", None)
-        if kwargs:
-            raise ValueError(f"Invalid keyword arguments: {kwargs}")
+        if by_name is not None:
+            raise ValueError(
+                "`by_name` only supports loading legacy '.h5' or '.hdf5' "
+                f"files. Received: {filepath}"
+            )
         saving_lib.load_weights_only(
             model,
             filepath,
@@ -268,13 +285,15 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
             objects_to_skip=objects_to_skip,
         )
     elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
-        by_name = kwargs.pop("by_name", False)
-        if kwargs:
-            raise ValueError(f"Invalid keyword arguments: {kwargs}")
         if not h5py:
             raise ImportError(
                 "Loading a H5 file requires `h5py` to be installed."
             )
+        if objects_to_skip is not None:
+            raise ValueError(
+                "`objects_to_skip` only supports loading '.weights.h5' files. "
+                f"Received: {filepath}"
+            )
         with h5py.File(filepath, "r") as f:
             if "layer_names" not in f.attrs and "model_weights" in f:
                 f = f["model_weights"]
@@ -283,7 +302,9 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
                     f, model, skip_mismatch
                 )
             else:
-                legacy_h5_format.load_weights_from_hdf5_group(f, model)
+                legacy_h5_format.load_weights_from_hdf5_group(
+                    f, model, skip_mismatch
+                )
     else:
         raise ValueError(
             f"File format not supported: filepath={filepath}. "
diff --git a/keras/src/saving/saving_api_test.py b/keras/src/saving/saving_api_test.py
index 5466f3077ef..638528eaac7 100644
--- a/keras/src/saving/saving_api_test.py
+++ b/keras/src/saving/saving_api_test.py
@@ -7,6 +7,7 @@
 from absl.testing import parameterized
 
 from keras.src import layers
+from keras.src.legacy.saving.legacy_h5_format import save_model_to_hdf5
 from keras.src.models import Sequential
 from keras.src.saving import saving_api
 from keras.src.testing import test_case
@@ -53,7 +54,18 @@ def test_save_h5_format(self):
         """Test saving model in h5 format."""
         model = self.get_model()
         filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5")
-        saving_api.save_model(model, filepath_h5)
+
+        # Verify the warning.
+        with mock.patch.object(logging, "warning") as mock_warn:
+            saving_api.save_model(model, filepath_h5)
+            mock_warn.assert_called_once_with(
+                "You are saving your model as an HDF5 file via "
+                "`model.save()` or `keras.saving.save_model(model)`. "
+                "This file format is considered legacy. "
+                "We recommend using instead the native Keras format, "
+                "e.g. `model.save('my_model.keras')` or "
+                "`keras.saving.save_model(model, 'my_model.keras')`. "
+            )
         self.assertTrue(os.path.exists(filepath_h5))
         os.remove(filepath_h5)
 
@@ -203,18 +215,36 @@ def get_model(self, dtype=None):
 
     @parameterized.named_parameters(
         named_product(
+            save_format=["keras", "weights.h5", "h5"],
             source_dtype=["float64", "float32", "float16", "bfloat16"],
             dest_dtype=["float64", "float32", "float16", "bfloat16"],
         )
     )
-    def test_load_keras_weights(self, source_dtype, dest_dtype):
+    def test_load_weights(self, save_format, source_dtype, dest_dtype):
         """Test loading keras weights."""
         src_model = self.get_model(dtype=source_dtype)
-        filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
-        src_model.save_weights(filepath)
-        src_weights = src_model.get_weights()
+        if save_format == "keras":
+            filepath = os.path.join(self.get_temp_dir(), "test_weights.keras")
+            src_model.save(filepath)
+        elif save_format == "weights.h5":
+            filepath = os.path.join(
+                self.get_temp_dir(), "test_weights.weights.h5"
+            )
+            src_model.save_weights(filepath)
+        elif save_format == "h5":
+            if "bfloat16" in (source_dtype, dest_dtype):
+                self.skipTest(
+                    "bfloat16 dtype is not supported in legacy h5 format."
+                )
+            filepath = os.path.join(self.get_temp_dir(), "test_weights.h5")
+            save_model_to_hdf5(src_model, filepath)
+        else:
+            raise ValueError(f"Unsupported save format: {save_format}")
+
         dest_model = self.get_model(dtype=dest_dtype)
         dest_model.load_weights(filepath)
+
+        src_weights = src_model.get_weights()
         dest_weights = dest_model.get_weights()
         for orig, loaded in zip(src_weights, dest_weights):
             self.assertAllClose(
@@ -224,13 +254,42 @@ def test_load_keras_weights(self, source_dtype, dest_dtype):
                 rtol=0.01,
             )
 
-    def test_load_h5_weights_by_name(self):
-        """Test loading h5 weights by name."""
-        model = self.get_model()
-        filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
-        model.save_weights(filepath)
-        with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"):
-            model.load_weights(filepath, by_name=True)
+    def test_load_weights_invalid_kwargs(self):
+        """Test that unsupported kwargs raise for each file format."""
+        src_model = self.get_model()
+        keras_filepath = os.path.join(self.get_temp_dir(), "test_weights.keras")
+        weight_h5_filepath = os.path.join(
+            self.get_temp_dir(), "test_weights.weights.h5"
+        )
+        legacy_h5_filepath = os.path.join(
+            self.get_temp_dir(), "test_weights.h5"
+        )
+        src_model.save(keras_filepath)
+        src_model.save_weights(weight_h5_filepath)
+        save_model_to_hdf5(src_model, legacy_h5_filepath)
+
+        dest_model = self.get_model()
+        # Test keras file.
+        with self.assertRaisesRegex(
+            ValueError, r"only supports loading '.weights.h5' files."
+        ):
+            dest_model.load_weights(keras_filepath, objects_to_skip=[])
+        with self.assertRaisesRegex(
+            ValueError, r"only supports loading legacy '.h5' or '.hdf5' files."
+        ):
+            dest_model.load_weights(keras_filepath, by_name=True)
+        with self.assertRaisesRegex(ValueError, r"Invalid keyword arguments"):
+            dest_model.load_weights(keras_filepath, bad_kwarg=None)
+        # Test weights.h5 file.
+        with self.assertRaisesRegex(
+            ValueError, r"only supports loading legacy '.h5' or '.hdf5' files."
+        ):
+            dest_model.load_weights(weight_h5_filepath, by_name=True)
+        # Test h5 file.
+        with self.assertRaisesRegex(
+            ValueError, r"only supports loading '.weights.h5' files."
+        ):
+            dest_model.load_weights(legacy_h5_filepath, objects_to_skip=[])
 
     def test_load_weights_invalid_extension(self):
         """Test loading weights with unsupported extension."""
@@ -251,29 +310,3 @@ def test_load_sharded_weights(self):
         dest_weights = dest_model.get_weights()
         for orig, loaded in zip(src_weights, dest_weights):
             self.assertAllClose(orig, loaded)
-
-
-class SaveModelTestsWarning(test_case.TestCase):
-    def get_model(self):
-        return Sequential(
-            [
-                layers.Dense(5, input_shape=(3,)),
-                layers.Softmax(),
-            ]
-        )
-
-    def test_h5_deprecation_warning(self):
-        """Test deprecation warning for h5 format."""
-        model = self.get_model()
-        filepath = os.path.join(self.get_temp_dir(), "test_model.h5")
-
-        with mock.patch.object(logging, "warning") as mock_warn:
-            saving_api.save_model(model, filepath)
-            mock_warn.assert_called_once_with(
-                "You are saving your model as an HDF5 file via "
-                "`model.save()` or `keras.saving.save_model(model)`. "
-                "This file format is considered legacy. "
-                "We recommend using instead the native Keras format, "
-                "e.g. `model.save('my_model.keras')` or "
-                "`keras.saving.save_model(model, 'my_model.keras')`. "
-            )