-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Fix support for optional inputs in model.fit #21548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @neo-alex, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a critical issue where model.fit in Keras fails to properly handle optional inputs provided as None, unlike model.predict_on_batch and model.train_on_batch. My changes ensure that model.fit can correctly process models with optional None inputs across TensorFlow, JAX, and PyTorch backends, aligning its behavior with other training utilities.
Highlights
- Enabled model.fit with Optional None Inputs: The primary goal of this PR is to allow model.fit to successfully train models where certain inputs are optionally provided as None, which was previously causing failures.
- Enhanced tree.map_structure Handling of None: I've updated the tree.map_structure utility across Keras to include a none_is_leaf parameter. Setting this to False ensures that None values are traversed as part of the structure rather than being treated as terminal leaves, which is crucial for handling optional inputs.
- TensorFlow Backend Adaptations: For the TensorFlow backend, I've introduced explicit conversions of tf.experimental.Optional instances to None within the train_step and ensured that None values from generators are correctly converted to tf.experimental.Optional.empty(None) when constructing tf.data.Dataset objects.
- Improved Data Adapter Robustness: Various data adapters (e.g., ArrayDataAdapter, GeneratorDataAdapter, TFDatasetAdapter, TorchDataLoaderAdapter, GrainDatasetAdapter) have been modified to correctly handle None values when calculating sample cardinalities, creating tensor specifications, or iterating through batches, preventing errors related to missing input shapes.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command>
or @gemini-code-assist <command>
. Below is a summary of the supported commands.
Feature | Command | Description |
---|---|---|
Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/
folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request adds support for optional inputs in model.fit
by introducing a none_is_leaf
parameter to tree.map_structure
. This allows None
values, which represent optional inputs, to be correctly handled across various data adapters and backends. The changes are logical and consistently applied. However, I've found a potential issue where the logic to handle TensorFlow's Optional
type is missing from test_step
and predict_step
, which could cause problems during evaluation and prediction.
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #21548 +/- ##
==========================================
- Coverage 82.72% 82.70% -0.03%
==========================================
Files 567 568 +1
Lines 56264 56921 +657
Branches 8797 8896 +99
==========================================
+ Hits 46544 47075 +531
- Misses 7562 7650 +88
- Partials 2158 2196 +38
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR!
I am unsure where would be the best place for it... model_test maybe?).
Yes, this would be the place to test. Ideally, it would test fit
, predict
and evaluate
.
Also, ideally, this would be tested in *_data_adapter_test.py
to cover all cases.
Taking a step back, is the goal to handle the case when in the dataset passed to fit
"input2" is always None
? Or sometimes None
sometimes not None
. Right now it looks like it's only supporting the latter (always None
).
@@ -32,6 +32,8 @@ def get_tf_dataset(self): | |||
from keras.src.utils.module_utils import tensorflow as tf | |||
|
|||
def convert_to_tf(x, spec): | |||
if isinstance(spec, tf.OptionalSpec): | |||
return x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't you just return tf.experimental.Optional.empty(None)
here are remove lines 55-62?
Or tf.experimental.Optional.empty(None) is x is None else x
?
Either way, lines 55-62 should move here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately no (this is what I tried first): indeed, an error is then raised by tree.map_structure
on line 63 because batch
and self._output_signature
do not have the same structure (more specifically: None leaves in batch
do not match None leaves in self._output_signature
, which have tf.OptionalSpec
instead). This is why I had to convert None leaves in batch
first, in lines 55-62 - please let me know if you find a more elegant solution to this issue though (I am also not a fan of having 2 map.structure
calls in a row if it is avoidable).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Erratum: sorry, I got confused in my own tests (there is actually no issue using tree.map_structure
with None leaves on one structure and tf.OptionalSpec on another, as long as none_is_leaf=True
- which is the default). So you are right and I simplified the logic according to your comment in this commit.
keras/src/tree/tree_api.py
Outdated
@@ -179,6 +179,7 @@ def map_structure(func, *structures): | |||
Args: | |||
func: A callable that accepts as many arguments as there are structures. | |||
*structures: Arbitrarily nested structures of the same layout. | |||
none_is_leaf: If True, None is treated as a leaf. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add more details here? The name none_is_leaf
is pretty unintuitive actually. Basically, say something like:
none_is_leaf=True
causes func
to be called on None
leaves, and none_is_leaf=False
means None
s are not passed to func
and are returned in the output directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I improved its docstring accordingly in this commit. By the way, I agree that the name none_is_leaf
is not the most intuitive, but it is used consistently throughout the underlying optree library (e.g. here), so I kept the same one.
keras/src/tree/dmtree_impl.py
Outdated
if not all(s is None for s in args): | ||
raise ValueError( | ||
"Structure mismatch: some arguments are None, others " | ||
"are not." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issues while running map_structure
can be hard to debug. Any bit of context can help.
Can you add args
?
raise ValueError(
"Structure mismatch: some arguments are None, others "
f"are not: {args}."
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, done in this commit.
def _convert_optional_to_none(self, x): | ||
# Convert TF Optional implementations to None | ||
return tree.map_structure( | ||
lambda i: None if isinstance(i, tf.experimental.Optional) else i, x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you also need to do i.get_value() if i.has_value() else None
? So that you support both the None
and not None
cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes you are probably right, I will double-check (see also my reply below to your "taking a step back" comment wrt. mixing None
and not None
cases).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, after deeper investigation, it is not that simple to handle the "mixed" case in Tensorflow backend (when a generator sometimes provides None
and sometimes a tensor for optional inputs). In particular, the problem with your suggestion is that i.has_value()
can be a symbolic tensor (when code is traced by TF with the default run_eagerly=False
), which cannot be evaluated as a Python boolean. A natural solution would be to replace your Python condition with something like tf.cond(i.has_value(), lambda: i.get_value(), lambda: None)
, unfortunately TF control flow operators don't support None
as return values... I believe I found a more complex solution that does the trick to cover this Tensorflow edge case, but I think it deserves a dedicated follow-up PR and would suggest first merging this current one - please see this comment for more details. Thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, sounds good.
@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True): | |||
""" | |||
from keras.src.utils.module_utils import tensorflow as tf | |||
|
|||
if keras_tensor is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this actually ever happen?
My assumption was that this would need to handle non-None inputs that have optional=True
on them (this might require some changes), and then create a tf.OptionalSpec(<the actual tensorspec for the KerasTensor per the code below>)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does actually happen, even if the reason is not intuitive: your assumption makes a lot of sense (ideally we would like optional inputs to be represented by KerasTensor
with optional=True
like in the model), unfortunately all the code in data_adapters is independent from the model, and the data spec is solely inferred from the first batches of received data (typically here)... which seems indeed a bit brittle and prone to some "hidden" constraints for the first batches of the dataset (e.g. see this error message).
Since it is not possible to infer a proper KerasTensor
just from a received None
value, the trick I am using is to keep it as None
(by using the newly introduced none_is_leaf=False
inside get_keras_tensor_spec
), which explains then that the line of code you mention is actually needed.
Thank you very much @hertschuh for your insightful review! To answer your "taking a step back" comment, for now the goal is at least to enable optional inputs in model fit/evaluate/predict when they are always or never I will continue to investigate and see if a solution for mixed values ( |
You're right, the data spec and and the model inputs are disconnected, which, as you point out, is the source of a number of shortcomings. It might not be possible to mix |
Actually, the current solution in this PR already works for model fit/evaluate/predict across Jax/TF/Torch backends on single data batches as well as batches provided by a generator - even when the latter generates "mixed" optional input values ( Since my latest comment, I experimented a lot and finally did find a solution that also covers this last edge case (generator with "mixed" optional input values with TF backend). But the solution is more complex than expected and I would rather add it in a dedicated follow-up PR to enable specific discussions about it. So, if you agree @hertschuh, I would suggest merging this PR first, which already covers most cases. To that end, I just added in this commit some unit tests (within model_test.py as discussed above). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @neo-alex , this looks good!
Looking forward to the next PR.
Also, one question, what if people do use tf.experimental.Optional
if a tf.data.Dataset
, but then use a different backend than TF (let's say JAX), shouldn't the conversion happening in TFDatasetAdapter
also convert from tf.experimental.Optional
to None
(if you agree, this can be a separate PR).
It looks like it's failing on GPU. Do you understand why? |
Yes, it seems to be due to the XLA JIT-compilation (automatically activated on GPU but not on CPU) - when I force I think the trick would be to move the Optional->tensor/None conversion just outside of the one_step_on_data function, which is the part that Keras possibly JIT compiles. I believe this should work since a couple of I will try to push an update for this in the next couple of days and let you know. |
Great, thanks! |
Alright, I just pushed the suggested fix for JIT compilation compatibility (I believe everything should be fine this time, even on GPU). I also made sure that the follow-up PR covering the last TF edge case is still compatible with this change - I will submit it as soon as you merge this one @hertschuh |
Here is an example of a model with 2 inputs, the second one being optional:
With this definition, the model can be called in Jax/TF/Torch without issue, even when input2 is None:
It is even possible to train on a batch when input2 is None:
However, doing the same with model.fit API is currently failing on all backends:
The purpose of this PR is to fix this issue (on Jax/TF/Torch), so that the last code block above becomes possible (btw. I could add 1 or 2 unit tests along those lines to demonstrate the fix but I am unsure where would be the best place for it... model_test maybe?).