Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit e714b6b

Browse files
committed
Gracefully handle missing scipy module.
1 parent e002ebd commit e714b6b

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

keras_preprocessing/image.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
import numpy as np
88
import re
9-
from scipy import linalg
10-
import scipy.ndimage as ndi
119
from six.moves import range
1210
import os
1311
import threading
@@ -25,6 +23,12 @@
2523
from PIL import Image as pil_image
2624
except ImportError:
2725
pil_image = None
26+
ImageEnhance = None
27+
28+
try:
29+
import scipy
30+
except ImportError:
31+
scipy = None
2832

2933
if pil_image is not None:
3034
_PIL_INTERPOLATION_METHODS = {
@@ -208,6 +212,9 @@ def apply_brightness_shift(x, brightness):
208212
# Raises
209213
ValueError if `brightness_range` isn't a tuple.
210214
"""
215+
if ImageEnhance is None:
216+
raise ImportError('Using brightness shifts requires PIL. '
217+
'Install PIL or Pillow.')
211218
x = array_to_img(x)
212219
x = imgenhancer_Brightness = ImageEnhance.Brightness(x)
213220
x = imgenhancer_Brightness.enhance(brightness)
@@ -272,6 +279,9 @@ def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
272279
# Returns
273280
The transformed version of the input.
274281
"""
282+
if scipy is None:
283+
raise ImportError('Image transformations require SciPy. '
284+
'Install SciPy.')
275285
transform_matrix = None
276286
if theta != 0:
277287
theta = np.deg2rad(theta)
@@ -316,7 +326,7 @@ def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
316326
final_affine_matrix = transform_matrix[:2, :2]
317327
final_offset = transform_matrix[:2, 2]
318328

319-
channel_images = [ndi.interpolation.affine_transform(
329+
channel_images = [scipy.ndimage.interpolation.affine_transform(
320330
x_channel,
321331
final_affine_matrix,
322332
final_offset,
@@ -1230,10 +1240,13 @@ def fit(self, x,
12301240
x /= (self.std + backend.epsilon())
12311241

12321242
if self.zca_whitening:
1243+
if scipy is None:
1244+
raise ImportError('Using zca_whitening requires SciPy. '
1245+
'Install SciPy.')
12331246
flat_x = np.reshape(
12341247
x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
12351248
sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
1236-
u, s, _ = linalg.svd(sigma)
1249+
u, s, _ = scipy.linalg.svd(sigma)
12371250
s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon)
12381251
self.principal_components = (u * s_inv).dot(u.T)
12391252

0 commit comments

Comments
 (0)