From 36911b87c4f73213f8c031aff493b8db90bfeb5f Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Mon, 4 Aug 2025 18:18:04 +0200 Subject: [PATCH 01/11] Add none_is_leaf option in tree.map_structure --- keras/src/tree/dmtree_impl.py | 18 +++++++++++++++--- keras/src/tree/optree_impl.py | 6 +++--- keras/src/tree/tree_api.py | 5 +++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/keras/src/tree/dmtree_impl.py b/keras/src/tree/dmtree_impl.py index ff9964b43d74..d872a4102b59 100644 --- a/keras/src/tree/dmtree_impl.py +++ b/keras/src/tree/dmtree_impl.py @@ -194,16 +194,28 @@ def flatten_with_path(structure): return flattened -def map_structure(func, *structures): +def map_structure(func, *structures, none_is_leaf=True): if not callable(func): raise TypeError( f"`func` must be callable, got {func} of type {type(func)}" ) + map_func = func + if not none_is_leaf: + def func_skipping_none(*args): + # Check if the reference entry (first one) is None + if args[0] is None: + if not all(s is None for s in args): + raise ValueError("Structure mismatch: some arguments are None, others are not.") + return None + return func(*args) + + map_func = func_skipping_none + def func_traverse_wrapper(s): if is_nested(s): return None - ret = func(s) + ret = map_func(s) if ret is None: return dmtree.MAP_TO_NONE return ret @@ -212,7 +224,7 @@ def func_traverse_wrapper(s): return traverse(func_traverse_wrapper, structures[0]) with TypeErrorRemapping(): - return dmtree.map_structure(func, *structures) + return dmtree.map_structure(map_func, *structures) def map_structure_up_to(shallow_structure, func, *structures): diff --git a/keras/src/tree/optree_impl.py b/keras/src/tree/optree_impl.py index 3d813788e023..1134d8338048 100644 --- a/keras/src/tree/optree_impl.py +++ b/keras/src/tree/optree_impl.py @@ -93,14 +93,14 @@ def flatten_with_path(structure): return list(zip(paths, leaves)) -def map_structure(func, *structures): +def map_structure(func, *structures, none_is_leaf=True): if not structures: raise ValueError("Must provide at least one structure") # Add check for same structures, otherwise optree just maps to shallowest. def func_with_check(*args): if not all( - optree.tree_is_leaf(s, none_is_leaf=True, namespace="keras") + optree.tree_is_leaf(s, none_is_leaf=none_is_leaf, namespace="keras") for s in args ): raise ValueError("Structures don't have the same nested structure.") @@ -109,7 +109,7 @@ def func_with_check(*args): map_func = func_with_check if len(structures) > 1 else func return optree.tree_map( - map_func, *structures, none_is_leaf=True, namespace="keras" + map_func, *structures, none_is_leaf=none_is_leaf, namespace="keras" ) diff --git a/keras/src/tree/tree_api.py b/keras/src/tree/tree_api.py index a4f98f068eec..faebdb54f662 100644 --- a/keras/src/tree/tree_api.py +++ b/keras/src/tree/tree_api.py @@ -160,7 +160,7 @@ def flatten_with_path(structure): @keras_export("keras.tree.map_structure") -def map_structure(func, *structures): +def map_structure(func, *structures, none_is_leaf=True): """Maps `func` through given structures. Examples: @@ -179,6 +179,7 @@ def map_structure(func, *structures): Args: func: A callable that accepts as many arguments as there are structures. *structures: Arbitrarily nested structures of the same layout. + none_is_leaf: If True, None is treated as a leaf. Returns: A new structure with the same layout as the given ones. 
@@ -189,7 +190,7 @@ def map_structure(func, *structures): the nested structures don't match according to the rules of `assert_same_structure`. """ - return tree_impl.map_structure(func, *structures) + return tree_impl.map_structure(func, *structures, none_is_leaf=none_is_leaf) @keras_export("keras.tree.map_structure_up_to") From 12995c47570d5458ea48ad830eef04fd25317fa5 Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Mon, 4 Aug 2025 18:23:02 +0200 Subject: [PATCH 02/11] Support None for optional inputs in model.fit --- keras/src/backend/tensorflow/trainer.py | 5 +++++ .../data_adapters/array_data_adapter.py | 12 ++++++++--- .../data_adapters/data_adapter_utils.py | 20 ++++++++++++++----- .../data_adapters/generator_data_adapter.py | 10 ++++++++++ .../data_adapters/grain_dataset_adapter.py | 10 +++++++--- .../data_adapters/tf_dataset_adapter.py | 6 ++++-- .../torch_data_loader_adapter.py | 4 +++- 7 files changed, 53 insertions(+), 14 deletions(-) diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py index fa2f5770098b..529d12f99b96 100644 --- a/keras/src/backend/tensorflow/trainer.py +++ b/keras/src/backend/tensorflow/trainer.py @@ -52,6 +52,11 @@ def distribute_reduction_method(self, value): def train_step(self, data): x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + # Convert TF Optional implementations to None + x = tree.map_structure( + lambda i: None if isinstance(i, tf.experimental.Optional) else i, x + ) + # Forward pass with tf.GradientTape() as tape: if self._call_has_training_arg: diff --git a/keras/src/trainers/data_adapters/array_data_adapter.py b/keras/src/trainers/data_adapters/array_data_adapter.py index e732f28688bd..87db9aac7032 100644 --- a/keras/src/trainers/data_adapters/array_data_adapter.py +++ b/keras/src/trainers/data_adapters/array_data_adapter.py @@ -76,7 +76,9 @@ def __init__( inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight) data_adapter_utils.check_data_cardinality(inputs) - num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop() + num_samples = set( + i.shape[0] for i in tree.flatten(inputs) if i is not None + ).pop() self._num_samples = num_samples self._inputs = inputs @@ -269,7 +271,9 @@ def slice_and_convert(sliceable): x = convert_to_tensor(x) return x - return tree.map_structure(slice_and_convert, self.array) + return tree.map_structure( + slice_and_convert, self.array, none_is_leaf=False + ) def __len__(self): return len(self.array[0]) @@ -337,7 +341,9 @@ def _get_iterator(self, slice_and_convert_fn, inputs): slice_indices_and_convert_fn = functools.partial( slice_and_convert_fn, indices=indices ) - yield tree.map_structure(slice_indices_and_convert_fn, inputs) + yield tree.map_structure( + slice_indices_and_convert_fn, inputs, none_is_leaf=False + ) @property def num_batches(self): diff --git a/keras/src/trainers/data_adapters/data_adapter_utils.py b/keras/src/trainers/data_adapters/data_adapter_utils.py index 29f51dc7772c..6cad232ada98 100644 --- a/keras/src/trainers/data_adapters/data_adapter_utils.py +++ b/keras/src/trainers/data_adapters/data_adapter_utils.py @@ -101,7 +101,9 @@ def list_to_tuple(maybe_list): def check_data_cardinality(data): - num_samples = set(int(i.shape[0]) for i in tree.flatten(data)) + num_samples = set( + int(i.shape[0]) for i in tree.flatten(data) if i is not None + ) if len(num_samples) > 1: msg = ( "Data cardinality is ambiguous. 
" @@ -186,7 +188,9 @@ def get_single_tensor_spec(*tensors): else: return backend.KerasTensor(shape=shape, dtype=dtype) - return tree.map_structure(get_single_tensor_spec, *batches) + return tree.map_structure( + get_single_tensor_spec, *batches, none_is_leaf=False + ) def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True): @@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True): """ from keras.src.utils.module_utils import tensorflow as tf + if keras_tensor is None: + return tf.OptionalSpec(None) if not isinstance(keras_tensor, backend.KerasTensor): raise TypeError( f"Expected a KerasTensor, but got {keras_tensor} of type " @@ -252,7 +258,9 @@ def convert_to_jax_compatible(x): return np.asarray(x) for batch in iterable: - yield tree.map_structure(convert_to_jax_compatible, batch) + yield tree.map_structure( + convert_to_jax_compatible, batch, none_is_leaf=False + ) def get_numpy_iterator(iterable): @@ -268,7 +276,7 @@ def convert_to_numpy(x): return x for batch in iterable: - yield tree.map_structure(convert_to_numpy, batch) + yield tree.map_structure(convert_to_numpy, batch, none_is_leaf=False) def get_torch_dataloader(iterable): @@ -282,7 +290,9 @@ def __init__(self, iterable): def __iter__(self): for batch in self.iterable: - yield tree.map_structure(convert_to_tensor, batch) + yield tree.map_structure( + convert_to_tensor, batch, none_is_leaf=False + ) dataset = ConverterIterableDataset(iterable) # `batch_size=None` indicates that we should not re-batch diff --git a/keras/src/trainers/data_adapters/generator_data_adapter.py b/keras/src/trainers/data_adapters/generator_data_adapter.py index 50603e99c7d6..9285d2ef74e0 100644 --- a/keras/src/trainers/data_adapters/generator_data_adapter.py +++ b/keras/src/trainers/data_adapters/generator_data_adapter.py @@ -32,6 +32,8 @@ def get_tf_dataset(self): from keras.src.utils.module_utils import tensorflow as tf def convert_to_tf(x, spec): + if isinstance(spec, tf.OptionalSpec): + return x if data_adapter_utils.is_scipy_sparse(x): x = data_adapter_utils.scipy_sparse_to_tf_sparse(x) elif data_adapter_utils.is_jax_sparse(x): @@ -50,6 +52,14 @@ def convert_to_tf(x, spec): def get_tf_iterator(): for batch in self.generator(): + batch = tree.map_structure( + ( + lambda i: tf.experimental.Optional.empty(None) + if i is None + else i + ), + batch, + ) batch = tree.map_structure( convert_to_tf, batch, self._output_signature ) diff --git a/keras/src/trainers/data_adapters/grain_dataset_adapter.py b/keras/src/trainers/data_adapters/grain_dataset_adapter.py index 5feb7dcf1a10..472cc520b25e 100644 --- a/keras/src/trainers/data_adapters/grain_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/grain_dataset_adapter.py @@ -80,7 +80,9 @@ def convert_to_numpy(x): class ConvertToNumpy(grain.transforms.Map): def map(self, x): - return tree.map_structure(convert_to_numpy, x) + return tree.map_structure( + convert_to_numpy, x, none_is_leaf=False + ) if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)): dataset = self._dataset.map(ConvertToNumpy()) @@ -109,7 +111,9 @@ def convert_to_jax_compatible(x): class ConvertToJaxCompatible(grain.transforms.Map): def map(self, x): - return tree.map_structure(convert_to_jax_compatible, x) + return tree.map_structure( + convert_to_jax_compatible, x, none_is_leaf=False + ) if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)): dataset = self._dataset.map(ConvertToJaxCompatible()) @@ -139,7 +143,7 @@ def convert_to_tf(x): class 
ConvertToTF(grain.transforms.Map): def map(self, x): - return tree.map_structure(convert_to_tf, x) + return tree.map_structure(convert_to_tf, x, none_is_leaf=False) # `tf.data.Dataset.from_generator` does not support lists as output. # We convert lists to tuples. diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter.py b/keras/src/trainers/data_adapters/tf_dataset_adapter.py index 3a3cfeb4bb7a..492deb764c3e 100644 --- a/keras/src/trainers/data_adapters/tf_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/tf_dataset_adapter.py @@ -38,7 +38,9 @@ def get_numpy_iterator(self): from keras.src.backend.tensorflow.core import convert_to_numpy for batch in self._dataset: - yield tree.map_structure(convert_to_numpy, batch) + yield tree.map_structure( + convert_to_numpy, batch, none_is_leaf=False + ) def get_jax_iterator(self): from keras.src.backend.tensorflow.core import convert_to_numpy @@ -52,7 +54,7 @@ def convert_to_jax(x): return convert_to_numpy(x) for batch in self._dataset: - yield tree.map_structure(convert_to_jax, batch) + yield tree.map_structure(convert_to_jax, batch, none_is_leaf=False) def get_tf_dataset(self): return self._dataset diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py index 565261d0299a..f0b2f524f4dd 100644 --- a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -35,7 +35,9 @@ def get_numpy_iterator(self): for batch in self._dataloader: # shared memory using `np.asarray` yield tuple( - tree.map_structure(lambda x: np.asarray(x.cpu()), batch) + tree.map_structure( + lambda x: np.asarray(x.cpu()), batch, none_is_leaf=False + ) ) def get_jax_iterator(self): From d1f5ad663437ae65319c5794ec02015357f6a7cf Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Tue, 5 Aug 2025 10:52:59 +0200 Subject: [PATCH 03/11] Fix formatting (line too long) --- keras/src/tree/dmtree_impl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keras/src/tree/dmtree_impl.py b/keras/src/tree/dmtree_impl.py index d872a4102b59..7900c554e5a9 100644 --- a/keras/src/tree/dmtree_impl.py +++ b/keras/src/tree/dmtree_impl.py @@ -206,7 +206,10 @@ def func_skipping_none(*args): # Check if the reference entry (first one) is None if args[0] is None: if not all(s is None for s in args): - raise ValueError("Structure mismatch: some arguments are None, others are not.") + raise ValueError( + "Structure mismatch: some arguments are None, others " + "are not." 
+ ) return None return func(*args) From 0c9f605b4821b40061514d7c6680e749e1285bb1 Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Tue, 5 Aug 2025 14:54:32 +0200 Subject: [PATCH 04/11] Support None for optional inputs in model.evaluate & model.predict --- keras/src/backend/tensorflow/trainer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py index 529d12f99b96..d0a929bed90e 100644 --- a/keras/src/backend/tensorflow/trainer.py +++ b/keras/src/backend/tensorflow/trainer.py @@ -51,11 +51,7 @@ def distribute_reduction_method(self, value): def train_step(self, data): x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) - - # Convert TF Optional implementations to None - x = tree.map_structure( - lambda i: None if isinstance(i, tf.experimental.Optional) else i, x - ) + x = self._convert_optional_to_none(x) # Forward pass with tf.GradientTape() as tape: @@ -91,6 +87,7 @@ def train_step(self, data): def test_step(self, data): x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + x = self._convert_optional_to_none(x) if self._call_has_training_arg: y_pred = self(x, training=False) else: @@ -106,12 +103,19 @@ def test_step(self, data): def predict_step(self, data): x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + x = self._convert_optional_to_none(x) if self._call_has_training_arg: y_pred = self(x, training=False) else: y_pred = self(x) return y_pred + def _convert_optional_to_none(self, x): + # Convert TF Optional implementations to None + return tree.map_structure( + lambda i: None if isinstance(i, tf.experimental.Optional) else i, x + ) + def _make_function(self, step_function): @tf.autograph.experimental.do_not_convert def one_step_on_data(data): From 01015a19b36ead295c41bcf9986300e4be3c3f9a Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Tue, 5 Aug 2025 15:28:09 +0200 Subject: [PATCH 05/11] Fix formatting --- keras/src/tree/dmtree_impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras/src/tree/dmtree_impl.py b/keras/src/tree/dmtree_impl.py index 7900c554e5a9..bad4ddafe462 100644 --- a/keras/src/tree/dmtree_impl.py +++ b/keras/src/tree/dmtree_impl.py @@ -202,6 +202,7 @@ def map_structure(func, *structures, none_is_leaf=True): map_func = func if not none_is_leaf: + def func_skipping_none(*args): # Check if the reference entry (first one) is None if args[0] is None: @@ -212,7 +213,7 @@ def func_skipping_none(*args): ) return None return func(*args) - + map_func = func_skipping_none def func_traverse_wrapper(s): From 687a8658adc396f8076202dbadec27ac798786c3 Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Thu, 7 Aug 2025 11:35:57 +0200 Subject: [PATCH 06/11] Improve none_is_leaf docstring in tree.map_structure --- keras/src/tree/tree_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras/src/tree/tree_api.py b/keras/src/tree/tree_api.py index faebdb54f662..89b864333e3e 100644 --- a/keras/src/tree/tree_api.py +++ b/keras/src/tree/tree_api.py @@ -179,7 +179,9 @@ def map_structure(func, *structures, none_is_leaf=True): Args: func: A callable that accepts as many arguments as there are structures. *structures: Arbitrarily nested structures of the same layout. - none_is_leaf: If True, None is treated as a leaf. + none_is_leaf: If True, `func` will be called on `None` leaves. If False, + `None` values are not passed to `func` and are returned in the + output directly. 
Returns: A new structure with the same layout as the given ones. From abe2056370b240adf1c82c874bfe0c8f4da774df Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Thu, 7 Aug 2025 11:49:21 +0200 Subject: [PATCH 07/11] Improve error message for structure mismatch --- keras/src/tree/dmtree_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/tree/dmtree_impl.py b/keras/src/tree/dmtree_impl.py index bad4ddafe462..5e4132d419a9 100644 --- a/keras/src/tree/dmtree_impl.py +++ b/keras/src/tree/dmtree_impl.py @@ -209,7 +209,7 @@ def func_skipping_none(*args): if not all(s is None for s in args): raise ValueError( "Structure mismatch: some arguments are None, others " - "are not." + f"are not. Received arguments: {args}." ) return None return func(*args) From c5b636abe1028c9d33711d1520b475fb71522cae Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Thu, 7 Aug 2025 14:48:12 +0200 Subject: [PATCH 08/11] Simplify conversion from None to TF Optional (for TF dataset) --- .../trainers/data_adapters/generator_data_adapter.py | 12 ++---------- .../trainers/data_adapters/grain_dataset_adapter.py | 4 +++- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/keras/src/trainers/data_adapters/generator_data_adapter.py b/keras/src/trainers/data_adapters/generator_data_adapter.py index 9285d2ef74e0..186e45da93de 100644 --- a/keras/src/trainers/data_adapters/generator_data_adapter.py +++ b/keras/src/trainers/data_adapters/generator_data_adapter.py @@ -32,8 +32,8 @@ def get_tf_dataset(self): from keras.src.utils.module_utils import tensorflow as tf def convert_to_tf(x, spec): - if isinstance(spec, tf.OptionalSpec): - return x + if x is None: + return tf.experimental.Optional.empty(None) if data_adapter_utils.is_scipy_sparse(x): x = data_adapter_utils.scipy_sparse_to_tf_sparse(x) elif data_adapter_utils.is_jax_sparse(x): @@ -52,14 +52,6 @@ def convert_to_tf(x, spec): def get_tf_iterator(): for batch in self.generator(): - batch = tree.map_structure( - ( - lambda i: tf.experimental.Optional.empty(None) - if i is None - else i - ), - batch, - ) batch = tree.map_structure( convert_to_tf, batch, self._output_signature ) diff --git a/keras/src/trainers/data_adapters/grain_dataset_adapter.py b/keras/src/trainers/data_adapters/grain_dataset_adapter.py index 472cc520b25e..af356257ef1d 100644 --- a/keras/src/trainers/data_adapters/grain_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/grain_dataset_adapter.py @@ -135,6 +135,8 @@ def map(self, x): def get_tf_dataset(self): def convert_to_tf(x): + if x is None: + return tf.experimental.Optional.empty(None) if data_adapter_utils.is_scipy_sparse(x): x = data_adapter_utils.scipy_sparse_to_tf_sparse(x) elif data_adapter_utils.is_jax_sparse(x): @@ -143,7 +145,7 @@ def convert_to_tf(x): class ConvertToTF(grain.transforms.Map): def map(self, x): - return tree.map_structure(convert_to_tf, x, none_is_leaf=False) + return tree.map_structure(convert_to_tf, x) # `tf.data.Dataset.from_generator` does not support lists as output. # We convert lists to tuples. 
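Note: to make the `none_is_leaf` semantics introduced in PATCH 01 (and documented in PATCH 06 above) concrete, here is a minimal sketch against the public `keras.tree` API. The structure and lambdas are illustrative only, not taken from the patches:

    import keras

    data = {"x1": 1, "x2": None}

    # Default (none_is_leaf=True): None is a leaf, so `func` receives it
    # and must handle it explicitly.
    keras.tree.map_structure(lambda v: 0 if v is None else v + 1, data)
    # -> {"x1": 2, "x2": 0}

    # none_is_leaf=False: None entries are never passed to `func`; they are
    # carried through to the output structure unchanged.
    keras.tree.map_structure(lambda v: v + 1, data, none_is_leaf=False)
    # -> {"x1": 2, "x2": None}

With multiple structures and `none_is_leaf=False`, a position that is None in one structure but not in the others fails the structure check, matching the mismatch ValueError added to the dm-tree backend above.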
From f77a4cc9100c2c6024e2e42ea9c5f65de04238c8 Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Fri, 15 Aug 2025 22:12:46 +0200 Subject: [PATCH 09/11] Add tests for model fit/evaluate/predict with optional inputs --- keras/src/models/model_test.py | 49 ++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 6ed7d3c6543e..157ae3431f68 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -157,6 +157,23 @@ def call(self, x): return model +def _get_model_optional_inputs(): + class OptionalInputLayer(layers.Layer): + def __init__(self): + super().__init__() + self.dense = layers.Dense(2) + + def call(self, a, b=None): + x = a if b is None else a + b + return self.dense(x) + + x1 = Input((2,), name="x1") + x2 = Input((2,), name="x2", optional=True) + y = OptionalInputLayer()(x1, x2) + model = Model({"x1": x1, "x2": x2}, y) + return model + + def _get_variable_value_by_path(variables, path): for v in variables: if v.path == path: @@ -1219,6 +1236,38 @@ def test_functional_deeply_nested_outputs_struct_losses(self): ) self.assertListEqual(hist_keys, ref_keys) + @parameterized.named_parameters( + ("optional_none", True), ("optional_tensor", False) + ) + def test_functional_optional_inputs(self, is_optional_none): + model = _get_model_optional_inputs() + x1 = np.ones((2, 2)) + x2 = None if is_optional_none else np.ones((2, 2)) + y_true = np.ones((2, 2)) + + model.compile(loss=losses.MeanSquaredError) + model.fit(x={"x1": x1, "x2": x2}, y=y_true) + model.evaluate(x={"x1": x1, "x2": x2}, y=y_true) + model.predict(x={"x1": x1, "x2": x2}) + + @parameterized.named_parameters( + ("optional_none", True), ("optional_tensor", False) + ) + def test_functional_optional_inputs_generator(self, is_optional_none): + model = _get_model_optional_inputs() + x1 = np.ones((2, 2)) + x2 = None if is_optional_none else np.ones((2, 2)) + y_true = np.ones((2, 2)) + + def data_generator(with_y=True): + for _ in range(4): + yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ()) + + model.compile(loss=losses.MeanSquaredError) + model.fit(data_generator()) + model.evaluate(data_generator()) + model.predict(data_generator(with_y=False)) + def test_export_error(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = _get_model() From a57996c5aa2f7a19e319a567d94025a6e3d1797c Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Sat, 16 Aug 2025 10:33:09 +0200 Subject: [PATCH 10/11] Improve model.compile params in tests --- keras/src/models/model_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 157ae3431f68..f4d8850a5302 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1245,7 +1245,7 @@ def test_functional_optional_inputs(self, is_optional_none): x2 = None if is_optional_none else np.ones((2, 2)) y_true = np.ones((2, 2)) - model.compile(loss=losses.MeanSquaredError) + model.compile(loss="mse", optimizer="adam") model.fit(x={"x1": x1, "x2": x2}, y=y_true) model.evaluate(x={"x1": x1, "x2": x2}, y=y_true) model.predict(x={"x1": x1, "x2": x2}) @@ -1263,7 +1263,7 @@ def data_generator(with_y=True): for _ in range(4): yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ()) - model.compile(loss=losses.MeanSquaredError) + model.compile(loss="mse", optimizer="adam") model.fit(data_generator()) model.evaluate(data_generator()) 
model.predict(data_generator(with_y=False)) From e9170d7024d5769e4e45d4148b1e4d4e79cd2817 Mon Sep 17 00:00:00 2001 From: Alexandre Arnold Date: Thu, 21 Aug 2025 15:02:42 +0200 Subject: [PATCH 11/11] Enable JIT compilation --- keras/src/backend/tensorflow/trainer.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py index d0a929bed90e..ae49e7f38534 100644 --- a/keras/src/backend/tensorflow/trainer.py +++ b/keras/src/backend/tensorflow/trainer.py @@ -1,4 +1,5 @@ import contextlib +import functools import warnings import numpy as np @@ -51,7 +52,6 @@ def distribute_reduction_method(self, value): def train_step(self, data): x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) - x = self._convert_optional_to_none(x) # Forward pass with tf.GradientTape() as tape: @@ -87,7 +87,6 @@ def train_step(self, data): def test_step(self, data): x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) - x = self._convert_optional_to_none(x) if self._call_has_training_arg: y_pred = self(x, training=False) else: @@ -103,18 +102,26 @@ def test_step(self, data): def predict_step(self, data): x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) - x = self._convert_optional_to_none(x) if self._call_has_training_arg: y_pred = self(x, training=False) else: y_pred = self(x) return y_pred - def _convert_optional_to_none(self, x): - # Convert TF Optional implementations to None - return tree.map_structure( - lambda i: None if isinstance(i, tf.experimental.Optional) else i, x - ) + def _autoconvert_optionals(self, step_func): + # Wrapper converting (nested) TF Optional in input data to None + @functools.wraps(step_func) + def wrapper(data): + converted_data = tree.map_structure( + lambda i: ( + None if isinstance(i, tf.experimental.Optional) else i + ), + data, + ) + result = step_func(converted_data) + return result + + return wrapper def _make_function(self, step_function): @tf.autograph.experimental.do_not_convert @@ -134,6 +141,7 @@ def one_step_on_data(data): reduce_retracing=True, jit_compile=self.jit_compile, ) + one_step_on_data = self._autoconvert_optionals(one_step_on_data) @tf.autograph.experimental.do_not_convert def multi_step_on_iterator(iterator): @@ -262,6 +270,7 @@ def one_step_on_data(data): one_step_on_data = tf.function( one_step_on_data, reduce_retracing=True, jit_compile=True ) + one_step_on_data = self._autoconvert_optionals(one_step_on_data) @tf.autograph.experimental.do_not_convert def one_step_on_data_distributed(data):
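
Note: end to end, the series encodes optional inputs as `tf.experimental.Optional` so they can flow through `tf.data`, then decodes them back to `None` just before the model is called. Below is a minimal sketch of that round trip, mirroring the adapter code (PATCH 02/08) and the trainer wrapper (PATCH 11); the `batch` dict is a hypothetical example, not from the patches:

    import tensorflow as tf

    batch = {"x1": tf.ones((2, 2)), "x2": None}

    # Adapter side: None becomes an empty Optional so the element survives
    # tf.data.Dataset.from_generator (see convert_to_tf above).
    encoded = {
        k: tf.experimental.Optional.empty(None) if v is None else v
        for k, v in batch.items()
    }

    # Trainer side (_autoconvert_optionals): empty Optionals are mapped back
    # to None before the forward pass, so layers see a plain missing input.
    decoded = {
        k: None if isinstance(v, tf.experimental.Optional) else v
        for k, v in encoded.items()
    }
    assert decoded["x2"] is None

Keeping the decode step inside the wrapper around the tf.function (rather than inside the traced step itself) is what lets JIT compilation stay enabled: the compiled step only ever sees tensors or plain None, never Optional variants.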