Skip to content

Commit 19367bc

Browse files
authored
Fix support for optional inputs in model.fit (#21548)
* Add none_is_leaf option in tree.map_structure * Support None for optional inputs in model.fit * Fix formatting (line too long) * Support None for optional inputs in model.evaluate & model.predict * Fix formatting * Improve none_is_leaf docstring in tree.map_structure * Improve error message for structure mismatch * Simplify conversion from None to TF Optional (for TF dataset) * Add tests for model fit/evaluate/predict with optional inputs * Improve model.compile params in tests * Enable JIT compilation
1 parent bf85450 commit 19367bc

File tree

11 files changed

+135
-21
lines changed

11 files changed

+135
-21
lines changed

keras/src/backend/tensorflow/trainer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import functools
23
import warnings
34

45
import numpy as np
@@ -107,6 +108,21 @@ def predict_step(self, data):
107108
y_pred = self(x)
108109
return y_pred
109110

111+
def _autoconvert_optionals(self, step_func):
112+
# Wrapper converting (nested) TF Optional in input data to None
113+
@functools.wraps(step_func)
114+
def wrapper(data):
115+
converted_data = tree.map_structure(
116+
lambda i: (
117+
None if isinstance(i, tf.experimental.Optional) else i
118+
),
119+
data,
120+
)
121+
result = step_func(converted_data)
122+
return result
123+
124+
return wrapper
125+
110126
def _make_function(self, step_function):
111127
@tf.autograph.experimental.do_not_convert
112128
def one_step_on_data(data):
@@ -125,6 +141,7 @@ def one_step_on_data(data):
125141
reduce_retracing=True,
126142
jit_compile=self.jit_compile,
127143
)
144+
one_step_on_data = self._autoconvert_optionals(one_step_on_data)
128145

129146
@tf.autograph.experimental.do_not_convert
130147
def multi_step_on_iterator(iterator):
@@ -253,6 +270,7 @@ def one_step_on_data(data):
253270
one_step_on_data = tf.function(
254271
one_step_on_data, reduce_retracing=True, jit_compile=True
255272
)
273+
one_step_on_data = self._autoconvert_optionals(one_step_on_data)
256274

257275
@tf.autograph.experimental.do_not_convert
258276
def one_step_on_data_distributed(data):

keras/src/models/model_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,23 @@ def call(self, x):
160160
return model
161161

162162

163+
def _get_model_optional_inputs():
164+
class OptionalInputLayer(layers.Layer):
165+
def __init__(self):
166+
super().__init__()
167+
self.dense = layers.Dense(2)
168+
169+
def call(self, a, b=None):
170+
x = a if b is None else a + b
171+
return self.dense(x)
172+
173+
x1 = Input((2,), name="x1")
174+
x2 = Input((2,), name="x2", optional=True)
175+
y = OptionalInputLayer()(x1, x2)
176+
model = Model({"x1": x1, "x2": x2}, y)
177+
return model
178+
179+
163180
def _get_variable_value_by_path(variables, path):
164181
for v in variables:
165182
if v.path == path:
@@ -1222,6 +1239,38 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
12221239
)
12231240
self.assertListEqual(hist_keys, ref_keys)
12241241

1242+
@parameterized.named_parameters(
1243+
("optional_none", True), ("optional_tensor", False)
1244+
)
1245+
def test_functional_optional_inputs(self, is_optional_none):
1246+
model = _get_model_optional_inputs()
1247+
x1 = np.ones((2, 2))
1248+
x2 = None if is_optional_none else np.ones((2, 2))
1249+
y_true = np.ones((2, 2))
1250+
1251+
model.compile(loss="mse", optimizer="adam")
1252+
model.fit(x={"x1": x1, "x2": x2}, y=y_true)
1253+
model.evaluate(x={"x1": x1, "x2": x2}, y=y_true)
1254+
model.predict(x={"x1": x1, "x2": x2})
1255+
1256+
@parameterized.named_parameters(
1257+
("optional_none", True), ("optional_tensor", False)
1258+
)
1259+
def test_functional_optional_inputs_generator(self, is_optional_none):
1260+
model = _get_model_optional_inputs()
1261+
x1 = np.ones((2, 2))
1262+
x2 = None if is_optional_none else np.ones((2, 2))
1263+
y_true = np.ones((2, 2))
1264+
1265+
def data_generator(with_y=True):
1266+
for _ in range(4):
1267+
yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ())
1268+
1269+
model.compile(loss="mse", optimizer="adam")
1270+
model.fit(data_generator())
1271+
model.evaluate(data_generator())
1272+
model.predict(data_generator(with_y=False))
1273+
12251274
def test_export_error(self):
12261275
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
12271276
model = _get_model()

keras/src/trainers/data_adapters/array_data_adapter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def __init__(
7676
inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)
7777

7878
data_adapter_utils.check_data_cardinality(inputs)
79-
num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop()
79+
num_samples = set(
80+
i.shape[0] for i in tree.flatten(inputs) if i is not None
81+
).pop()
8082
self._num_samples = num_samples
8183
self._inputs = inputs
8284

@@ -269,7 +271,9 @@ def slice_and_convert(sliceable):
269271
x = convert_to_tensor(x)
270272
return x
271273

272-
return tree.map_structure(slice_and_convert, self.array)
274+
return tree.map_structure(
275+
slice_and_convert, self.array, none_is_leaf=False
276+
)
273277

274278
def __len__(self):
275279
return len(self.array[0])
@@ -337,7 +341,9 @@ def _get_iterator(self, slice_and_convert_fn, inputs):
337341
slice_indices_and_convert_fn = functools.partial(
338342
slice_and_convert_fn, indices=indices
339343
)
340-
yield tree.map_structure(slice_indices_and_convert_fn, inputs)
344+
yield tree.map_structure(
345+
slice_indices_and_convert_fn, inputs, none_is_leaf=False
346+
)
341347

342348
@property
343349
def num_batches(self):

keras/src/trainers/data_adapters/data_adapter_utils.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def list_to_tuple(maybe_list):
101101

102102

103103
def check_data_cardinality(data):
104-
num_samples = set(int(i.shape[0]) for i in tree.flatten(data))
104+
num_samples = set(
105+
int(i.shape[0]) for i in tree.flatten(data) if i is not None
106+
)
105107
if len(num_samples) > 1:
106108
msg = (
107109
"Data cardinality is ambiguous. "
@@ -186,7 +188,9 @@ def get_single_tensor_spec(*tensors):
186188
else:
187189
return backend.KerasTensor(shape=shape, dtype=dtype)
188190

189-
return tree.map_structure(get_single_tensor_spec, *batches)
191+
return tree.map_structure(
192+
get_single_tensor_spec, *batches, none_is_leaf=False
193+
)
190194

191195

192196
def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
199203
"""
200204
from keras.src.utils.module_utils import tensorflow as tf
201205

206+
if keras_tensor is None:
207+
return tf.OptionalSpec(None)
202208
if not isinstance(keras_tensor, backend.KerasTensor):
203209
raise TypeError(
204210
f"Expected a KerasTensor, but got {keras_tensor} of type "
@@ -252,7 +258,9 @@ def convert_to_jax_compatible(x):
252258
return np.asarray(x)
253259

254260
for batch in iterable:
255-
yield tree.map_structure(convert_to_jax_compatible, batch)
261+
yield tree.map_structure(
262+
convert_to_jax_compatible, batch, none_is_leaf=False
263+
)
256264

257265

258266
def get_numpy_iterator(iterable):
@@ -268,7 +276,7 @@ def convert_to_numpy(x):
268276
return x
269277

270278
for batch in iterable:
271-
yield tree.map_structure(convert_to_numpy, batch)
279+
yield tree.map_structure(convert_to_numpy, batch, none_is_leaf=False)
272280

273281

274282
def get_torch_dataloader(iterable):
@@ -282,7 +290,9 @@ def __init__(self, iterable):
282290

283291
def __iter__(self):
284292
for batch in self.iterable:
285-
yield tree.map_structure(convert_to_tensor, batch)
293+
yield tree.map_structure(
294+
convert_to_tensor, batch, none_is_leaf=False
295+
)
286296

287297
dataset = ConverterIterableDataset(iterable)
288298
# `batch_size=None` indicates that we should not re-batch

keras/src/trainers/data_adapters/generator_data_adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def get_tf_dataset(self):
3232
from keras.src.utils.module_utils import tensorflow as tf
3333

3434
def convert_to_tf(x, spec):
35+
if x is None:
36+
return tf.experimental.Optional.empty(None)
3537
if data_adapter_utils.is_scipy_sparse(x):
3638
x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
3739
elif data_adapter_utils.is_jax_sparse(x):

keras/src/trainers/data_adapters/grain_dataset_adapter.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def convert_to_numpy(x):
8080

8181
class ConvertToNumpy(grain.transforms.Map):
8282
def map(self, x):
83-
return tree.map_structure(convert_to_numpy, x)
83+
return tree.map_structure(
84+
convert_to_numpy, x, none_is_leaf=False
85+
)
8486

8587
if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
8688
dataset = self._dataset.map(ConvertToNumpy())
@@ -109,7 +111,9 @@ def convert_to_jax_compatible(x):
109111

110112
class ConvertToJaxCompatible(grain.transforms.Map):
111113
def map(self, x):
112-
return tree.map_structure(convert_to_jax_compatible, x)
114+
return tree.map_structure(
115+
convert_to_jax_compatible, x, none_is_leaf=False
116+
)
113117

114118
if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):
115119
dataset = self._dataset.map(ConvertToJaxCompatible())
@@ -131,6 +135,8 @@ def map(self, x):
131135

132136
def get_tf_dataset(self):
133137
def convert_to_tf(x):
138+
if x is None:
139+
return tf.experimental.Optional.empty(None)
134140
if data_adapter_utils.is_scipy_sparse(x):
135141
x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
136142
elif data_adapter_utils.is_jax_sparse(x):

keras/src/trainers/data_adapters/tf_dataset_adapter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def get_numpy_iterator(self):
3838
from keras.src.backend.tensorflow.core import convert_to_numpy
3939

4040
for batch in self._dataset:
41-
yield tree.map_structure(convert_to_numpy, batch)
41+
yield tree.map_structure(
42+
convert_to_numpy, batch, none_is_leaf=False
43+
)
4244

4345
def get_jax_iterator(self):
4446
from keras.src.backend.tensorflow.core import convert_to_numpy
@@ -52,7 +54,7 @@ def convert_to_jax(x):
5254
return convert_to_numpy(x)
5355

5456
for batch in self._dataset:
55-
yield tree.map_structure(convert_to_jax, batch)
57+
yield tree.map_structure(convert_to_jax, batch, none_is_leaf=False)
5658

5759
def get_tf_dataset(self):
5860
return self._dataset

keras/src/trainers/data_adapters/torch_data_loader_adapter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def get_numpy_iterator(self):
3535
for batch in self._dataloader:
3636
# shared memory using `np.asarray`
3737
yield tuple(
38-
tree.map_structure(lambda x: np.asarray(x.cpu()), batch)
38+
tree.map_structure(
39+
lambda x: np.asarray(x.cpu()), batch, none_is_leaf=False
40+
)
3941
)
4042

4143
def get_jax_iterator(self):

keras/src/tree/dmtree_impl.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,32 @@ def flatten_with_path(structure):
194194
return flattened
195195

196196

197-
def map_structure(func, *structures):
197+
def map_structure(func, *structures, none_is_leaf=True):
198198
if not callable(func):
199199
raise TypeError(
200200
f"`func` must be callable, got {func} of type {type(func)}"
201201
)
202202

203+
map_func = func
204+
if not none_is_leaf:
205+
206+
def func_skipping_none(*args):
207+
# Check if the reference entry (first one) is None
208+
if args[0] is None:
209+
if not all(s is None for s in args):
210+
raise ValueError(
211+
"Structure mismatch: some arguments are None, others "
212+
f"are not. Received arguments: {args}."
213+
)
214+
return None
215+
return func(*args)
216+
217+
map_func = func_skipping_none
218+
203219
def func_traverse_wrapper(s):
204220
if is_nested(s):
205221
return None
206-
ret = func(s)
222+
ret = map_func(s)
207223
if ret is None:
208224
return dmtree.MAP_TO_NONE
209225
return ret
@@ -212,7 +228,7 @@ def func_traverse_wrapper(s):
212228
return traverse(func_traverse_wrapper, structures[0])
213229

214230
with TypeErrorRemapping():
215-
return dmtree.map_structure(func, *structures)
231+
return dmtree.map_structure(map_func, *structures)
216232

217233

218234
def map_structure_up_to(shallow_structure, func, *structures):

keras/src/tree/optree_impl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,14 @@ def flatten_with_path(structure):
9393
return list(zip(paths, leaves))
9494

9595

96-
def map_structure(func, *structures):
96+
def map_structure(func, *structures, none_is_leaf=True):
9797
if not structures:
9898
raise ValueError("Must provide at least one structure")
9999

100100
# Add check for same structures, otherwise optree just maps to shallowest.
101101
def func_with_check(*args):
102102
if not all(
103-
optree.tree_is_leaf(s, none_is_leaf=True, namespace="keras")
103+
optree.tree_is_leaf(s, none_is_leaf=none_is_leaf, namespace="keras")
104104
for s in args
105105
):
106106
raise ValueError("Structures don't have the same nested structure.")
@@ -109,7 +109,7 @@ def func_with_check(*args):
109109
map_func = func_with_check if len(structures) > 1 else func
110110

111111
return optree.tree_map(
112-
map_func, *structures, none_is_leaf=True, namespace="keras"
112+
map_func, *structures, none_is_leaf=none_is_leaf, namespace="keras"
113113
)
114114

115115

0 commit comments

Comments
 (0)