Skip to content

Fix support for optional inputs in model.fit #21548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 21, 2025
Merged

Conversation

neo-alex
Copy link
Contributor

@neo-alex neo-alex commented Aug 4, 2025

Here is an example of a model with 2 inputs, the second one being optional:

class OptionalInputLayer(layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = layers.Dense(2)

    def call(self, x, y=None):
        z = x if y is None else x + y
        return self.dense(z)

    def compute_output_shape(self, x_shape):
        return x_shape

i1 = Input((2,), name="input1")
i2 = Input((2,), name="input2", optional=True)
outputs = OptionalInputLayer()(i1, i2)
model = Model({"input1": i1, "input2": i2}, outputs)

With this definition, the model can be called in Jax/TF/Torch without issue, even when input2 is None:

model({"input1": np.ones((2, 2)), "input2": None})  # WORKS
model.predict_on_batch({"input1": np.ones((2, 2)), "input2": None})  # WORKS AS WELL

It is even possible to train on a batch when input2 is None:

model.compile(loss=losses.MeanSquaredError)
model.train_on_batch(x={"input1": np.ones((2, 2)), "input2": None}, y=np.ones((2, 2)))  # WORKS

However, doing the same with model.fit API is currently failing on all backends:

# Without generator
model.fit(x={"input1": np.ones((2, 2)), "input2": None}, y=np.ones((2, 2)))  # DOESN'T WORK

# With generator
data_generator = (({"input1": np.ones((2, 2)), "input2": None}, np.ones((2, 2))) for _ in range(3))
model.fit(x=data_generator)  # DOESN'T WORK EITHER

The purpose of this PR is to fix this issue (on Jax/TF/Torch), so that the last code block above becomes possible (btw. I could add 1 or 2 unit tests along those lines to demonstrate the fix but I am unsure where would be the best place for it... model_test maybe?).

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @neo-alex, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue where model.fit in Keras fails to properly handle optional inputs provided as None, unlike model.predict_on_batch and model.train_on_batch. My changes ensure that model.fit can correctly process models with optional None inputs across TensorFlow, JAX, and PyTorch backends, aligning its behavior with other training utilities.

Highlights

  • Enabled model.fit with Optional None Inputs: The primary goal of this PR is to allow model.fit to successfully train models where certain inputs are optionally provided as None, which was previously causing failures.
  • Enhanced tree.map_structure Handling of None: I've updated the tree.map_structure utility across Keras to include a none_is_leaf parameter. Setting this to False ensures that None values are traversed as part of the structure rather than being treated as terminal leaves, which is crucial for handling optional inputs.
  • TensorFlow Backend Adaptations: For the TensorFlow backend, I've introduced explicit conversions of tf.experimental.Optional instances to None within the train_step and ensured that None values from generators are correctly converted to tf.experimental.Optional.empty(None) when constructing tf.data.Dataset objects.
  • Improved Data Adapter Robustness: Various data adapters (e.g., ArrayDataAdapter, GeneratorDataAdapter, TFDatasetAdapter, TorchDataLoaderAdapter, GrainDatasetAdapter) have been modified to correctly handle None values when calculating sample cardinalities, creating tensor specifications, or iterating through batches, preventing errors related to missing input shapes.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for optional inputs in model.fit by introducing a none_is_leaf parameter to tree.map_structure. This allows None values, which represent optional inputs, to be correctly handled across various data adapters and backends. The changes are logical and consistently applied. However, I've found a potential issue where the logic to handle TensorFlow's Optional type is missing from test_step and predict_step, which could cause problems during evaluation and prediction.

@codecov-commenter
Copy link

codecov-commenter commented Aug 4, 2025

Codecov Report

❌ Patch coverage is 72.09302% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.70%. Comparing base (7bf852c) to head (e9170d7).
⚠️ Report is 38 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/tree/dmtree_impl.py 33.33% 7 Missing and 1 partial ⚠️
...c/trainers/data_adapters/generator_data_adapter.py 0.00% 2 Missing ⚠️
...rc/trainers/data_adapters/grain_dataset_adapter.py 50.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21548      +/-   ##
==========================================
- Coverage   82.72%   82.70%   -0.03%     
==========================================
  Files         567      568       +1     
  Lines       56264    56921     +657     
  Branches     8797     8896      +99     
==========================================
+ Hits        46544    47075     +531     
- Misses       7562     7650      +88     
- Partials     2158     2196      +38     
Flag Coverage Δ
keras 82.50% <72.09%> (-0.03%) ⬇️
keras-jax 63.63% <44.18%> (-0.31%) ⬇️
keras-numpy 58.26% <39.53%> (-0.15%) ⬇️
keras-openvino 34.54% <20.93%> (-0.03%) ⬇️
keras-tensorflow 64.21% <62.79%> (-0.15%) ⬇️
keras-torch 63.78% <44.18%> (-0.21%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR!

I am unsure where would be the best place for it... model_test maybe?).

Yes, this would be the place to test. Ideally, it would test fit, predict and evaluate.

Also, ideally, this would be tested in *_data_adapter_test.py to cover all cases.

Taking a step back, is the goal to handle the case when in the dataset passed to fit "input2" is always None? Or sometimes None sometimes not None. Right now it looks like it's only supporting the latter (always None).

@@ -32,6 +32,8 @@ def get_tf_dataset(self):
from keras.src.utils.module_utils import tensorflow as tf

def convert_to_tf(x, spec):
if isinstance(spec, tf.OptionalSpec):
return x
Copy link
Collaborator

@hertschuh hertschuh Aug 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you just return tf.experimental.Optional.empty(None) here are remove lines 55-62?

Or tf.experimental.Optional.empty(None) is x is None else x?

Either way, lines 55-62 should move here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately no (this is what I tried first): indeed, an error is then raised by tree.map_structure on line 63 because batch and self._output_signature do not have the same structure (more specifically: None leaves in batch do not match None leaves in self._output_signature, which have tf.OptionalSpec instead). This is why I had to convert None leaves in batch first, in lines 55-62 - please let me know if you find a more elegant solution to this issue though (I am also not a fan of having 2 map.structure calls in a row if it is avoidable).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Erratum: sorry, I got confused in my own tests (there is actually no issue using tree.map_structure with None leaves on one structure and tf.OptionalSpec on another, as long as none_is_leaf=True - which is the default). So you are right and I simplified the logic according to your comment in this commit.

@@ -179,6 +179,7 @@ def map_structure(func, *structures):
Args:
func: A callable that accepts as many arguments as there are structures.
*structures: Arbitrarily nested structures of the same layout.
none_is_leaf: If True, None is treated as a leaf.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more details here? The name none_is_leaf is pretty unintuitive actually. Basically, say something like:

none_is_leaf=True causes func to be called on None leaves, and none_is_leaf=False means Nones are not passed to func and are returned in the output directly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I improved its docstring accordingly in this commit. By the way, I agree that the name none_is_leaf is not the most intuitive, but it is used consistently throughout the underlying optree library (e.g. here), so I kept the same one.

if not all(s is None for s in args):
raise ValueError(
"Structure mismatch: some arguments are None, others "
"are not."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issues while running map_structure can be hard to debug. Any bit of context can help.

Can you add args?

                    raise ValueError(
                        "Structure mismatch: some arguments are None, others "
                        f"are not: {args}."
                    )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, done in this commit.

def _convert_optional_to_none(self, x):
# Convert TF Optional implementations to None
return tree.map_structure(
lambda i: None if isinstance(i, tf.experimental.Optional) else i, x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you also need to do i.get_value() if i.has_value() else None? So that you support both the None and not None cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you are probably right, I will double-check (see also my reply below to your "taking a step back" comment wrt. mixing None and not None cases).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, after deeper investigation, it is not that simple to handle the "mixed" case in Tensorflow backend (when a generator sometimes provides None and sometimes a tensor for optional inputs). In particular, the problem with your suggestion is that i.has_value() can be a symbolic tensor (when code is traced by TF with the default run_eagerly=False), which cannot be evaluated as a Python boolean. A natural solution would be to replace your Python condition with something like tf.cond(i.has_value(), lambda: i.get_value(), lambda: None), unfortunately TF control flow operators don't support None as return values... I believe I found a more complex solution that does the trick to cover this Tensorflow edge case, but I think it deserves a dedicated follow-up PR and would suggest first merging this current one - please see this comment for more details. Thank you!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, sounds good.

@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
"""
from keras.src.utils.module_utils import tensorflow as tf

if keras_tensor is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this actually ever happen?

My assumption was that this would need to handle non-None inputs that have optional=True on them (this might require some changes), and then create a tf.OptionalSpec(<the actual tensorspec for the KerasTensor per the code below>).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does actually happen, even if the reason is not intuitive: your assumption makes a lot of sense (ideally we would like optional inputs to be represented by KerasTensor with optional=True like in the model), unfortunately all the code in data_adapters is independent from the model, and the data spec is solely inferred from the first batches of received data (typically here)... which seems indeed a bit brittle and prone to some "hidden" constraints for the first batches of the dataset (e.g. see this error message).

Since it is not possible to infer a proper KerasTensor just from a received None value, the trick I am using is to keep it as None (by using the newly introduced none_is_leaf=False inside get_keras_tensor_spec), which explains then that the line of code you mention is actually needed.

@neo-alex
Copy link
Contributor Author

neo-alex commented Aug 7, 2025

Thank you very much @hertschuh for your insightful review! To answer your "taking a step back" comment, for now the goal is at least to enable optional inputs in model fit/evaluate/predict when they are always or never None. Of course, the best would be to enable both (sometimes None, sometimes a tensor) but this can be a bit tricky since tensor specs in data adapters are actually inferred from the first batches (at least for some backends like Tensorflow) - see my last comment above for more technical details.

I will continue to investigate and see if a solution for mixed values (None & not None) is still possible, given some constraints like "in the first 2 batches, every optional input should include a None value and a tensor one" (similar constraints are already assumed in current data adapters, as seen here). I will come back to you shortly about this.

@hertschuh
Copy link
Collaborator

Thank you very much @hertschuh for your insightful review! To answer your "taking a step back" comment, for now the goal is at least to enable optional inputs in model fit/evaluate/predict when they are always or never None. Of course, the best would be to enable both (sometimes None, sometimes a tensor) but this can be a bit tricky since tensor specs in data adapters are actually inferred from the first batches (at least for some backends like Tensorflow) - see my last comment above for more technical details.

I will continue to investigate and see if a solution for mixed values (None & not None) is still possible, given some constraints like "in the first 2 batches, every optional input should include a None value and a tensor one" (similar constraints are already assumed in current data adapters, as seen here). I will come back to you shortly about this.

You're right, the data spec and and the model inputs are disconnected, which, as you point out, is the source of a number of shortcomings. It might not be possible to mix None and not None without this connection, which will be complex to add. Unless we do the same hack as for dynamic dimensions that you linked.

@neo-alex
Copy link
Contributor Author

Actually, the current solution in this PR already works for model fit/evaluate/predict across Jax/TF/Torch backends on single data batches as well as batches provided by a generator - even when the latter generates "mixed" optional input values (None & not None) except in TF backend where an error is raised in this very specific case. The core reason is that TF backend uses TF Dataset under the hood, which doesn't support None values, so we need to convert optional input values from None/tensor to tf.experimental.Optional in the data adapter, and back to None/tensor as expected by Keras model for optional inputs. But doing this last conversion (depending on whether the Optional has a value) in a way that is traceable by TF AutoGraph (when run_eagerly=False) is not trivial, since TF control flow operators like tf.cond also don't support returning None...

Since my latest comment, I experimented a lot and finally did find a solution that also covers this last edge case (generator with "mixed" optional input values with TF backend). But the solution is more complex than expected and I would rather add it in a dedicated follow-up PR to enable specific discussions about it.

So, if you agree @hertschuh, I would suggest merging this PR first, which already covers most cases. To that end, I just added in this commit some unit tests (within model_test.py as discussed above).

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @neo-alex , this looks good!

Looking forward to the next PR.

Also, one question, what if people do use tf.experimental.Optional if a tf.data.Dataset, but then use a different backend than TF (let's say JAX), shouldn't the conversion happening in TFDatasetAdapter also convert from tf.experimental.Optional to None (if you agree, this can be a separate PR).

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Aug 19, 2025
@hertschuh
Copy link
Collaborator

@neo-alex

It looks like it's failing on GPU. Do you understand why?

@neo-alex
Copy link
Contributor Author

@neo-alex

It looks like it's failing on GPU. Do you understand why?

Yes, it seems to be due to the XLA JIT-compilation (automatically activated on GPU but not on CPU) - when I force jit_compile=True in model.compile, I am able to reproduce the issue even on CPU. I wasn't aware but it turns out that tf.experimental.Optional is unfortunately not JIT-compatible (for instance, the first example on this official TF documentation page fails as soon as jit_compile=True is added to @tf.function) 😞

I think the trick would be to move the Optional->tensor/None conversion just outside of the one_step_on_data function, which is the part that Keras possibly JIT compiles. I believe this should work since a couple of tf.experimental.Optional objects are already successfully present in the code base outside this function (like here for instance).

I will try to push an update for this in the next couple of days and let you know.

@hertschuh
Copy link
Collaborator

I will try to push an update for this in the next couple of days and let you know.

Great, thanks!

@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Aug 21, 2025
@neo-alex
Copy link
Contributor Author

Alright, I just pushed the suggested fix for JIT compilation compatibility (I believe everything should be fine this time, even on GPU). I also made sure that the follow-up PR covering the last TF edge case is still compatible with this change - I will submit it as soon as you merge this one @hertschuh
Thanks!

@hertschuh hertschuh merged commit 19367bc into keras-team:master Aug 21, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants