1313import multiprocessing .pool
1414from functools import partial
1515
16- from . import get_keras_submodule
17-
18- backend = get_keras_submodule ('backend' )
19- keras_utils = get_keras_submodule ('utils' )
20-
2116try :
2217 from PIL import ImageEnhance
2318 from PIL import Image as pil_image
@@ -349,7 +344,7 @@ def flip_axis(x, axis):
349344 return x
350345
351346
352- def array_to_img (x , data_format = None , scale = True ):
347+ def array_to_img (x , data_format = 'channels_last' , scale = True , dtype = 'float32' ):
353348 """Converts a 3D Numpy array to a PIL Image instance.
354349
355350 # Arguments
@@ -358,6 +353,7 @@ def array_to_img(x, data_format=None, scale=True):
358353 either "channels_first" or "channels_last".
359354 scale: Whether to rescale image values
360355 to be within `[0, 255]`.
356+ dtype: Dtype to use.
361357
362358 # Returns
363359 A PIL Image instance.
@@ -369,13 +365,11 @@ def array_to_img(x, data_format=None, scale=True):
369365 if pil_image is None :
370366 raise ImportError ('Could not import PIL.Image. '
371367 'The use of `array_to_img` requires PIL.' )
372- x = np .asarray (x , dtype = backend . floatx () )
368+ x = np .asarray (x , dtype = dtype )
373369 if x .ndim != 3 :
374370 raise ValueError ('Expected image array to have rank 3 (single image). '
375371 'Got array with shape: %s' % (x .shape ,))
376372
377- if data_format is None :
378- data_format = backend .image_data_format ()
379373 if data_format not in {'channels_first' , 'channels_last' }:
380374 raise ValueError ('Invalid data_format: %s' % data_format )
381375
@@ -403,28 +397,27 @@ def array_to_img(x, data_format=None, scale=True):
403397 raise ValueError ('Unsupported channel number: %s' % (x .shape [2 ],))
404398
405399
406- def img_to_array (img , data_format = None ):
400+ def img_to_array (img , data_format = 'channels_last' , dtype = 'float32' ):
407401 """Converts a PIL Image instance to a Numpy array.
408402
409403 # Arguments
410404 img: PIL Image instance.
411405 data_format: Image data format,
412406 either "channels_first" or "channels_last".
407+ dtype: Dtype to use for the returned array.
413408
414409 # Returns
415410 A 3D Numpy array.
416411
417412 # Raises
418413 ValueError: if invalid `img` or `data_format` is passed.
419414 """
420- if data_format is None :
421- data_format = backend .image_data_format ()
422415 if data_format not in {'channels_first' , 'channels_last' }:
423416 raise ValueError ('Unknown data_format: %s' % data_format )
424417 # Numpy array x has format (height, width, channel)
425418 # or (channel, height, width)
426419 # but original PIL image has format (width, height, channel)
427- x = np .asarray (img , dtype = backend . floatx () )
420+ x = np .asarray (img , dtype = dtype )
428421 if len (x .shape ) == 3 :
429422 if data_format == 'channels_first' :
430423 x = x .transpose (2 , 0 , 1 )
@@ -440,9 +433,10 @@ def img_to_array(img, data_format=None):
440433
441434def save_img (path ,
442435 x ,
443- data_format = None ,
436+ data_format = 'channels_last' ,
444437 file_format = None ,
445- scale = True , ** kwargs ):
438+ scale = True ,
439+ ** kwargs ):
446440 """Saves an image stored as a Numpy array to a path or file object.
447441
448442 # Arguments
@@ -602,6 +596,7 @@ class ImageDataGenerator(object):
602596 If you never set it, then it will be "channels_last".
603597 validation_split: Float. Fraction of images reserved for validation
604598 (strictly between 0 and 1).
599+ dtype: Dtype to use for the generated arrays.
605600
606601 # Examples
607602 Example of using `.flow(x, y)`:
@@ -728,10 +723,9 @@ def __init__(self,
728723 vertical_flip = False ,
729724 rescale = None ,
730725 preprocessing_function = None ,
731- data_format = None ,
732- validation_split = 0.0 ):
733- if data_format is None :
734- data_format = backend .image_data_format ()
726+ data_format = 'channels_last' ,
727+ validation_split = 0.0 ,
728+ dtype = 'float32' ):
735729 self .featurewise_center = featurewise_center
736730 self .samplewise_center = samplewise_center
737731 self .featurewise_std_normalization = featurewise_std_normalization
@@ -751,6 +745,7 @@ def __init__(self,
751745 self .vertical_flip = vertical_flip
752746 self .rescale = rescale
753747 self .preprocessing_function = preprocessing_function
748+ self .dtype = dtype
754749
755750 if data_format not in {'channels_last' , 'channels_first' }:
756751 raise ValueError (
@@ -983,7 +978,7 @@ def standardize(self, x):
983978 if self .samplewise_center :
984979 x -= np .mean (x , keepdims = True )
985980 if self .samplewise_std_normalization :
986- x /= (np .std (x , keepdims = True ) + backend . epsilon () )
981+ x /= (np .std (x , keepdims = True ) + 1e-6 )
987982
988983 if self .featurewise_center :
989984 if self .mean is not None :
@@ -995,7 +990,7 @@ def standardize(self, x):
995990 'first by calling `.fit(numpy_data)`.' )
996991 if self .featurewise_std_normalization :
997992 if self .std is not None :
998- x /= (self .std + backend . epsilon () )
993+ x /= (self .std + 1e-6 )
999994 else :
1000995 warnings .warn ('This ImageDataGenerator specifies '
1001996 '`featurewise_std_normalization`, '
@@ -1202,7 +1197,7 @@ def fit(self, x,
12021197 this is how many augmentation passes over the data to use.
12031198 seed: Int (default: None). Random seed.
12041199 """
1205- x = np .asarray (x , dtype = backend . floatx () )
1200+ x = np .asarray (x , dtype = self . dtype )
12061201 if x .ndim != 4 :
12071202 raise ValueError ('Input to `.fit()` should have rank 4. '
12081203 'Got array with shape: ' + str (x .shape ))
@@ -1225,7 +1220,7 @@ def fit(self, x,
12251220 if augment :
12261221 ax = np .zeros (
12271222 tuple ([rounds * x .shape [0 ]] + list (x .shape )[1 :]),
1228- dtype = backend . floatx () )
1223+ dtype = self . dtype )
12291224 for r in range (rounds ):
12301225 for i in range (x .shape [0 ]):
12311226 ax [i + r * x .shape [0 ]] = self .random_transform (x [i ])
@@ -1243,7 +1238,7 @@ def fit(self, x,
12431238 broadcast_shape = [1 , 1 , 1 ]
12441239 broadcast_shape [self .channel_axis - 1 ] = x .shape [self .channel_axis ]
12451240 self .std = np .reshape (self .std , broadcast_shape )
1246- x /= (self .std + backend . epsilon () )
1241+ x /= (self .std + 1e-6 )
12471242
12481243 if self .zca_whitening :
12491244 if scipy is None :
@@ -1257,7 +1252,7 @@ def fit(self, x,
12571252 self .principal_components = (u * s_inv ).dot (u .T )
12581253
12591254
1260- class Iterator (keras_utils . Sequence ):
1255+ class Iterator (object ):
12611256 """Base class for image data iterators.
12621257
12631258 Every `Iterator` must implement the `_get_batches_of_transformed_samples`
@@ -1375,13 +1370,15 @@ class NumpyArrayIterator(Iterator):
13751370 (if `save_to_dir` is set).
13761371 subset: Subset of data (`"training"` or `"validation"`) if
13771372 validation_split is set in ImageDataGenerator.
1373+ dtype: Dtype to use for the generated arrays.
13781374 """
13791375
13801376 def __init__ (self , x , y , image_data_generator ,
13811377 batch_size = 32 , shuffle = False , sample_weight = None ,
1382- seed = None , data_format = None ,
1378+ seed = None , data_format = 'channels_last' ,
13831379 save_to_dir = None , save_prefix = '' , save_format = 'png' ,
1384- subset = None ):
1380+ subset = None , dtype = 'float32' ):
1381+ self .dtype = dtype
13851382 if (type (x ) is tuple ) or (type (x ) is list ):
13861383 if type (x [1 ]) is not list :
13871384 x_misc = [np .asarray (x [1 ])]
@@ -1423,9 +1420,7 @@ def __init__(self, x, y, image_data_generator,
14231420 x_misc = [np .asarray (xx [split_idx :]) for xx in x_misc ]
14241421 if y is not None :
14251422 y = y [split_idx :]
1426- if data_format is None :
1427- data_format = backend .image_data_format ()
1428- self .x = np .asarray (x , dtype = backend .floatx ())
1423+ self .x = np .asarray (x , dtype = self .dtype )
14291424 self .x_misc = x_misc
14301425 if self .x .ndim != 4 :
14311426 raise ValueError ('Input data in `NumpyArrayIterator` '
@@ -1461,12 +1456,12 @@ def __init__(self, x, y, image_data_generator,
14611456
14621457 def _get_batches_of_transformed_samples (self , index_array ):
14631458 batch_x = np .zeros (tuple ([len (index_array )] + list (self .x .shape )[1 :]),
1464- dtype = backend . floatx () )
1459+ dtype = self . dtype )
14651460 for i , j in enumerate (index_array ):
14661461 x = self .x [j ]
14671462 params = self .image_data_generator .get_random_transform (x .shape )
14681463 x = self .image_data_generator .apply_transform (
1469- x .astype (backend . floatx () ), params )
1464+ x .astype (self . dtype ), params )
14701465 x = self .image_data_generator .standardize (x )
14711466 batch_x [i ] = x
14721467
@@ -1654,19 +1649,19 @@ class DirectoryIterator(Iterator):
16541649 If PIL version 1.1.3 or newer is installed, "lanczos" is also
16551650 supported. If PIL version 3.4.0 or newer is installed, "box" and
16561651 "hamming" are also supported. By default, "nearest" is used.
1652+ dtype: Dtype to use for generated arrays.
16571653 """
16581654
16591655 def __init__ (self , directory , image_data_generator ,
16601656 target_size = (256 , 256 ), color_mode = 'rgb' ,
16611657 classes = None , class_mode = 'categorical' ,
16621658 batch_size = 32 , shuffle = True , seed = None ,
1663- data_format = None ,
1659+ data_format = 'channels_last' ,
16641660 save_to_dir = None , save_prefix = '' , save_format = 'png' ,
16651661 follow_links = False ,
16661662 subset = None ,
1667- interpolation = 'nearest' ):
1668- if data_format is None :
1669- data_format = backend .image_data_format ()
1663+ interpolation = 'nearest' ,
1664+ dtype = 'float32' ):
16701665 self .directory = directory
16711666 self .image_data_generator = image_data_generator
16721667 self .target_size = tuple (target_size )
@@ -1702,6 +1697,7 @@ def __init__(self, directory, image_data_generator,
17021697 self .save_prefix = save_prefix
17031698 self .save_format = save_format
17041699 self .interpolation = interpolation
1700+ self .dtype = dtype
17051701
17061702 if subset is not None :
17071703 validation_split = self .image_data_generator ._validation_split
@@ -1769,7 +1765,7 @@ def __init__(self, directory, image_data_generator,
17691765 def _get_batches_of_transformed_samples (self , index_array ):
17701766 batch_x = np .zeros (
17711767 (len (index_array ),) + self .image_shape ,
1772- dtype = backend . floatx () )
1768+ dtype = self . dtype )
17731769 # build batch of image data
17741770 for i , j in enumerate (index_array ):
17751771 fname = self .filenames [j ]
@@ -1802,11 +1798,11 @@ def _get_batches_of_transformed_samples(self, index_array):
18021798 elif self .class_mode == 'sparse' :
18031799 batch_y = self .classes [index_array ]
18041800 elif self .class_mode == 'binary' :
1805- batch_y = self .classes [index_array ].astype (backend . floatx () )
1801+ batch_y = self .classes [index_array ].astype (self . dtype )
18061802 elif self .class_mode == 'categorical' :
18071803 batch_y = np .zeros (
18081804 (len (batch_x ), self .num_classes ),
1809- dtype = backend . floatx () )
1805+ dtype = self . dtype )
18101806 for i , label in enumerate (self .classes [index_array ]):
18111807 batch_y [i , label ] = 1.
18121808 else :
0 commit comments