diff --git a/keras_preprocessing/image/image_data_generator.py b/keras_preprocessing/image/image_data_generator.py index 5a44d31b..8f9a454d 100644 --- a/keras_preprocessing/image/image_data_generator.py +++ b/keras_preprocessing/image/image_data_generator.py @@ -427,7 +427,8 @@ def flow(self, save_to_dir=save_to_dir, save_prefix=save_prefix, save_format=save_format, - subset=subset + subset=subset, + dtype=x.dtype ) def flow_from_directory(self, diff --git a/keras_preprocessing/image/numpy_array_iterator.py b/keras_preprocessing/image/numpy_array_iterator.py index f03434ba..2a45bf16 100644 --- a/keras_preprocessing/image/numpy_array_iterator.py +++ b/keras_preprocessing/image/numpy_array_iterator.py @@ -109,7 +109,7 @@ def __init__(self, if y is not None: y = y[split_idx:] - self.x = np.asarray(x, dtype=self.dtype) + self.x = np.asanyarray(x, dtype=self.dtype) self.x_misc = x_misc if self.x.ndim != 4: raise ValueError('Input data in `NumpyArrayIterator` '