Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit dad7fcc

Browse files
authored
Merge pull request #40 from keras-team/redesign
Remove reliance on Keras submodule imports.
2 parents b9d1424 + 3075487 commit dad7fcc

File tree

6 files changed

+45
-90
lines changed

6 files changed

+45
-90
lines changed

.travis.yml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,12 @@ matrix:
55
include:
66
- python: 2.7
77
env: TEST_MODE=PEP8
8-
- python: 2.7
9-
env: TEST_MODE=INTEGRATION_TESTS
108
- python: 2.7
119
env: KERAS_HEAD=true
1210
- python: 3.6
1311
env: KERAS_HEAD=true
1412
- python: 2.7
1513
- python: 3.6
16-
- python: 3.6
17-
env: TEST_MODE=INTEGRATION_TESTS
1814
install:
1915
# code below is taken from http://conda.pydata.org/docs/travis.html
2016
# We do this conditionally because it saves us some downloading if the
@@ -57,8 +53,6 @@ install:
5753
script:
5854
- if [[ "$TEST_MODE" == "PEP8" ]]; then
5955
PYTHONPATH=$PWD:$PYTHONPATH py.test --pep8 -m pep8 -n0;
60-
elif [[ "$TEST_MODE" == "INTEGRATION_TESTS" ]]; then
61-
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/integration_test.py;
6256
else
63-
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/ --cov-config .coveragerc --cov=keras_preprocessing tests/ --ignore=tests/integration_test.py;
57+
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/ --cov-config .coveragerc --cov=keras_preprocessing tests/;
6458
fi

keras_preprocessing/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99

1010

1111
def set_keras_submodules(backend, utils):
12+
# Deprecated, will be removed in the future.
1213
global _KERAS_BACKEND
1314
global _KERAS_UTILS
1415
_KERAS_BACKEND = backend
1516
_KERAS_UTILS = utils
1617

1718

1819
def get_keras_submodule(name):
20+
# Deprecated, will be removed in the future.
1921
if name not in {'backend', 'utils'}:
2022
raise ImportError(
2123
'Can only retrieve "backend" and "utils". '
@@ -36,3 +38,5 @@ def get_keras_submodule(name):
3638
return _KERAS_BACKEND
3739
elif name == 'utils':
3840
return _KERAS_UTILS
41+
42+
__version__ = '1.0.3'

keras_preprocessing/image.py

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@
1313
import multiprocessing.pool
1414
from functools import partial
1515

16-
from . import get_keras_submodule
17-
18-
backend = get_keras_submodule('backend')
19-
keras_utils = get_keras_submodule('utils')
20-
2116
try:
2217
from PIL import ImageEnhance
2318
from PIL import Image as pil_image
@@ -349,7 +344,7 @@ def flip_axis(x, axis):
349344
return x
350345

351346

352-
def array_to_img(x, data_format=None, scale=True):
347+
def array_to_img(x, data_format='channels_last', scale=True, dtype='float32'):
353348
"""Converts a 3D Numpy array to a PIL Image instance.
354349
355350
# Arguments
@@ -358,6 +353,7 @@ def array_to_img(x, data_format=None, scale=True):
358353
either "channels_first" or "channels_last".
359354
scale: Whether to rescale image values
360355
to be within `[0, 255]`.
356+
dtype: Dtype to use.
361357
362358
# Returns
363359
A PIL Image instance.
@@ -369,13 +365,11 @@ def array_to_img(x, data_format=None, scale=True):
369365
if pil_image is None:
370366
raise ImportError('Could not import PIL.Image. '
371367
'The use of `array_to_img` requires PIL.')
372-
x = np.asarray(x, dtype=backend.floatx())
368+
x = np.asarray(x, dtype=dtype)
373369
if x.ndim != 3:
374370
raise ValueError('Expected image array to have rank 3 (single image). '
375371
'Got array with shape: %s' % (x.shape,))
376372

377-
if data_format is None:
378-
data_format = backend.image_data_format()
379373
if data_format not in {'channels_first', 'channels_last'}:
380374
raise ValueError('Invalid data_format: %s' % data_format)
381375

@@ -403,28 +397,27 @@ def array_to_img(x, data_format=None, scale=True):
403397
raise ValueError('Unsupported channel number: %s' % (x.shape[2],))
404398

405399

406-
def img_to_array(img, data_format=None):
400+
def img_to_array(img, data_format='channels_last', dtype='float32'):
407401
"""Converts a PIL Image instance to a Numpy array.
408402
409403
# Arguments
410404
img: PIL Image instance.
411405
data_format: Image data format,
412406
either "channels_first" or "channels_last".
407+
dtype: Dtype to use for the returned array.
413408
414409
# Returns
415410
A 3D Numpy array.
416411
417412
# Raises
418413
ValueError: if invalid `img` or `data_format` is passed.
419414
"""
420-
if data_format is None:
421-
data_format = backend.image_data_format()
422415
if data_format not in {'channels_first', 'channels_last'}:
423416
raise ValueError('Unknown data_format: %s' % data_format)
424417
# Numpy array x has format (height, width, channel)
425418
# or (channel, height, width)
426419
# but original PIL image has format (width, height, channel)
427-
x = np.asarray(img, dtype=backend.floatx())
420+
x = np.asarray(img, dtype=dtype)
428421
if len(x.shape) == 3:
429422
if data_format == 'channels_first':
430423
x = x.transpose(2, 0, 1)
@@ -440,9 +433,10 @@ def img_to_array(img, data_format=None):
440433

441434
def save_img(path,
442435
x,
443-
data_format=None,
436+
data_format='channels_last',
444437
file_format=None,
445-
scale=True, **kwargs):
438+
scale=True,
439+
**kwargs):
446440
"""Saves an image stored as a Numpy array to a path or file object.
447441
448442
# Arguments
@@ -602,6 +596,7 @@ class ImageDataGenerator(object):
602596
If you never set it, then it will be "channels_last".
603597
validation_split: Float. Fraction of images reserved for validation
604598
(strictly between 0 and 1).
599+
dtype: Dtype to use for the generated arrays.
605600
606601
# Examples
607602
Example of using `.flow(x, y)`:
@@ -728,10 +723,9 @@ def __init__(self,
728723
vertical_flip=False,
729724
rescale=None,
730725
preprocessing_function=None,
731-
data_format=None,
732-
validation_split=0.0):
733-
if data_format is None:
734-
data_format = backend.image_data_format()
726+
data_format='channels_last',
727+
validation_split=0.0,
728+
dtype='float32'):
735729
self.featurewise_center = featurewise_center
736730
self.samplewise_center = samplewise_center
737731
self.featurewise_std_normalization = featurewise_std_normalization
@@ -751,6 +745,7 @@ def __init__(self,
751745
self.vertical_flip = vertical_flip
752746
self.rescale = rescale
753747
self.preprocessing_function = preprocessing_function
748+
self.dtype = dtype
754749

755750
if data_format not in {'channels_last', 'channels_first'}:
756751
raise ValueError(
@@ -983,7 +978,7 @@ def standardize(self, x):
983978
if self.samplewise_center:
984979
x -= np.mean(x, keepdims=True)
985980
if self.samplewise_std_normalization:
986-
x /= (np.std(x, keepdims=True) + backend.epsilon())
981+
x /= (np.std(x, keepdims=True) + 1e-6)
987982

988983
if self.featurewise_center:
989984
if self.mean is not None:
@@ -995,7 +990,7 @@ def standardize(self, x):
995990
'first by calling `.fit(numpy_data)`.')
996991
if self.featurewise_std_normalization:
997992
if self.std is not None:
998-
x /= (self.std + backend.epsilon())
993+
x /= (self.std + 1e-6)
999994
else:
1000995
warnings.warn('This ImageDataGenerator specifies '
1001996
'`featurewise_std_normalization`, '
@@ -1202,7 +1197,7 @@ def fit(self, x,
12021197
this is how many augmentation passes over the data to use.
12031198
seed: Int (default: None). Random seed.
12041199
"""
1205-
x = np.asarray(x, dtype=backend.floatx())
1200+
x = np.asarray(x, dtype=self.dtype)
12061201
if x.ndim != 4:
12071202
raise ValueError('Input to `.fit()` should have rank 4. '
12081203
'Got array with shape: ' + str(x.shape))
@@ -1225,7 +1220,7 @@ def fit(self, x,
12251220
if augment:
12261221
ax = np.zeros(
12271222
tuple([rounds * x.shape[0]] + list(x.shape)[1:]),
1228-
dtype=backend.floatx())
1223+
dtype=self.dtype)
12291224
for r in range(rounds):
12301225
for i in range(x.shape[0]):
12311226
ax[i + r * x.shape[0]] = self.random_transform(x[i])
@@ -1243,7 +1238,7 @@ def fit(self, x,
12431238
broadcast_shape = [1, 1, 1]
12441239
broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
12451240
self.std = np.reshape(self.std, broadcast_shape)
1246-
x /= (self.std + backend.epsilon())
1241+
x /= (self.std + 1e-6)
12471242

12481243
if self.zca_whitening:
12491244
if scipy is None:
@@ -1257,7 +1252,7 @@ def fit(self, x,
12571252
self.principal_components = (u * s_inv).dot(u.T)
12581253

12591254

1260-
class Iterator(keras_utils.Sequence):
1255+
class Iterator(object):
12611256
"""Base class for image data iterators.
12621257
12631258
Every `Iterator` must implement the `_get_batches_of_transformed_samples`
@@ -1375,13 +1370,15 @@ class NumpyArrayIterator(Iterator):
13751370
(if `save_to_dir` is set).
13761371
subset: Subset of data (`"training"` or `"validation"`) if
13771372
validation_split is set in ImageDataGenerator.
1373+
dtype: Dtype to use for the generated arrays.
13781374
"""
13791375

13801376
def __init__(self, x, y, image_data_generator,
13811377
batch_size=32, shuffle=False, sample_weight=None,
1382-
seed=None, data_format=None,
1378+
seed=None, data_format='channels_last',
13831379
save_to_dir=None, save_prefix='', save_format='png',
1384-
subset=None):
1380+
subset=None, dtype='float32'):
1381+
self.dtype = dtype
13851382
if (type(x) is tuple) or (type(x) is list):
13861383
if type(x[1]) is not list:
13871384
x_misc = [np.asarray(x[1])]
@@ -1423,9 +1420,7 @@ def __init__(self, x, y, image_data_generator,
14231420
x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
14241421
if y is not None:
14251422
y = y[split_idx:]
1426-
if data_format is None:
1427-
data_format = backend.image_data_format()
1428-
self.x = np.asarray(x, dtype=backend.floatx())
1423+
self.x = np.asarray(x, dtype=self.dtype)
14291424
self.x_misc = x_misc
14301425
if self.x.ndim != 4:
14311426
raise ValueError('Input data in `NumpyArrayIterator` '
@@ -1461,12 +1456,12 @@ def __init__(self, x, y, image_data_generator,
14611456

14621457
def _get_batches_of_transformed_samples(self, index_array):
14631458
batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]),
1464-
dtype=backend.floatx())
1459+
dtype=self.dtype)
14651460
for i, j in enumerate(index_array):
14661461
x = self.x[j]
14671462
params = self.image_data_generator.get_random_transform(x.shape)
14681463
x = self.image_data_generator.apply_transform(
1469-
x.astype(backend.floatx()), params)
1464+
x.astype(self.dtype), params)
14701465
x = self.image_data_generator.standardize(x)
14711466
batch_x[i] = x
14721467

@@ -1654,19 +1649,19 @@ class DirectoryIterator(Iterator):
16541649
If PIL version 1.1.3 or newer is installed, "lanczos" is also
16551650
supported. If PIL version 3.4.0 or newer is installed, "box" and
16561651
"hamming" are also supported. By default, "nearest" is used.
1652+
dtype: Dtype to use for generated arrays.
16571653
"""
16581654

16591655
def __init__(self, directory, image_data_generator,
16601656
target_size=(256, 256), color_mode='rgb',
16611657
classes=None, class_mode='categorical',
16621658
batch_size=32, shuffle=True, seed=None,
1663-
data_format=None,
1659+
data_format='channels_last',
16641660
save_to_dir=None, save_prefix='', save_format='png',
16651661
follow_links=False,
16661662
subset=None,
1667-
interpolation='nearest'):
1668-
if data_format is None:
1669-
data_format = backend.image_data_format()
1663+
interpolation='nearest',
1664+
dtype='float32'):
16701665
self.directory = directory
16711666
self.image_data_generator = image_data_generator
16721667
self.target_size = tuple(target_size)
@@ -1702,6 +1697,7 @@ def __init__(self, directory, image_data_generator,
17021697
self.save_prefix = save_prefix
17031698
self.save_format = save_format
17041699
self.interpolation = interpolation
1700+
self.dtype = dtype
17051701

17061702
if subset is not None:
17071703
validation_split = self.image_data_generator._validation_split
@@ -1769,7 +1765,7 @@ def __init__(self, directory, image_data_generator,
17691765
def _get_batches_of_transformed_samples(self, index_array):
17701766
batch_x = np.zeros(
17711767
(len(index_array),) + self.image_shape,
1772-
dtype=backend.floatx())
1768+
dtype=self.dtype)
17731769
# build batch of image data
17741770
for i, j in enumerate(index_array):
17751771
fname = self.filenames[j]
@@ -1802,11 +1798,11 @@ def _get_batches_of_transformed_samples(self, index_array):
18021798
elif self.class_mode == 'sparse':
18031799
batch_y = self.classes[index_array]
18041800
elif self.class_mode == 'binary':
1805-
batch_y = self.classes[index_array].astype(backend.floatx())
1801+
batch_y = self.classes[index_array].astype(self.dtype)
18061802
elif self.class_mode == 'categorical':
18071803
batch_y = np.zeros(
18081804
(len(batch_x), self.num_classes),
1809-
dtype=backend.floatx())
1805+
dtype=self.dtype)
18101806
for i, label in enumerate(self.classes[index_array]):
18111807
batch_y[i, label] = 1.
18121808
else:

keras_preprocessing/sequence.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010
import json
1111
from six.moves import range
1212

13-
from . import get_keras_submodule
14-
15-
keras_utils = get_keras_submodule('utils')
16-
1713

1814
def pad_sequences(sequences, maxlen=None, dtype='int32',
1915
padding='pre', truncating='pre', value=0.):
@@ -251,7 +247,7 @@ def _remove_long_seq(maxlen, seq, label):
251247
return new_seq, new_label
252248

253249

254-
class TimeseriesGenerator(keras_utils.Sequence):
250+
class TimeseriesGenerator(object):
255251
"""Utility class for generating batches of temporal data.
256252
257253
This class takes in a sequence of data-points gathered at
@@ -325,7 +321,8 @@ def __init__(self, data, targets, length,
325321

326322
if len(data) != len(targets):
327323
raise ValueError('Data and targets have to be' +
328-
' of same length. Data length is {}'.format(len(data)) +
324+
' of same length. '
325+
'Data length is {}'.format(len(data)) +
329326
' while target length is {}'.format(len(targets)))
330327

331328
self.data = data

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
'''
2222

2323
setup(name='Keras_Preprocessing',
24-
version='1.0.2',
24+
version='1.0.3',
2525
description='Easy data preprocessing and data augmentation '
2626
'for deep learning models',
2727
long_description=long_description,
2828
author='Keras Team',
2929
url='https://github.com/keras-team/keras-preprocessing',
3030
download_url='https://github.com/keras-team/'
31-
'keras-preprocessing/tarball/1.0.2',
31+
'keras-preprocessing/tarball/1.0.3',
3232
license='MIT',
3333
install_requires=['keras>=2.1.6',
3434
'numpy>=1.9.1',

0 commit comments

Comments
 (0)