Skip to content
This repository was archived by the owner on Mar 3, 2024. It is now read-only.

Commit f8bb7ab

Browse files
committed
Add workaround to fix keras loading
1 parent 9ee4726 commit f8bb7ab

File tree

4 files changed

+32
-27
lines changed

4 files changed

+32
-27
lines changed

keras_bert/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
from .util import *
66
from .datasets import *
77

8-
__version__ = '0.87.0'
8+
__version__ = '0.88.0'

keras_bert/datasets/pretrained.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import shutil
66
from collections import namedtuple
77
from keras_bert.backend import keras
8+
from tensorflow.keras.utils import get_file
89

910
__all__ = ['PretrainedInfo', 'PretrainedList', 'get_pretrained']
1011

@@ -35,7 +36,7 @@ def get_pretrained(info):
3536
path = info
3637
if isinstance(info, PretrainedInfo):
3738
path = info.url
38-
path = keras.utils.get_file(fname=os.path.split(path)[-1], origin=path, extract=True)
39+
path = get_file(fname=os.path.split(path)[-1], origin=path, extract=True)
3940
base_part, file_part = os.path.split(path)
4041
file_part = file_part.split('.')[0]
4142
if isinstance(info, PretrainedInfo):

tests/optimizers/test_warmup.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
import tempfile
33
from unittest import TestCase
4+
45
import numpy as np
5-
from keras_bert.backend import keras, TF_KERAS
6+
7+
from keras_bert.backend import keras
68
from keras_bert import AdamWarmup
79

810

@@ -31,7 +33,10 @@ def _test_fit(self, optmizer):
3133

3234
model_path = os.path.join(tempfile.gettempdir(), 'keras_warmup_%f.h5' % np.random.random())
3335
model.save(model_path)
34-
model = keras.models.load_model(model_path, custom_objects={'AdamWarmup': AdamWarmup})
36+
37+
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
38+
with CustomObjectScope({'AdamWarmup': AdamWarmup}): # Workaround for incorrect global variable used in keras
39+
model = keras.models.load_model(model_path, custom_objects={'AdamWarmup': AdamWarmup})
3540

3641
results = model.predict(x).argmax(axis=-1)
3742
diff = np.sum(np.abs(y - results))
@@ -78,17 +83,10 @@ def test_fit_embed(self):
7883

7984
x = np.random.randint(0, 5, (1024, 15))
8085
y = (x[:, 1] > 2).astype('int32')
81-
model.fit(x, y, epochs=10)
86+
model.fit(x, y, epochs=10, verbose=1)
8287

8388
model_path = os.path.join(tempfile.gettempdir(), 'test_warmup_%f.h5' % np.random.random())
8489
model.save(model_path)
85-
keras.models.load_model(model_path, custom_objects={'AdamWarmup': AdamWarmup})
86-
87-
def test_legacy(self):
88-
opt = AdamWarmup(
89-
decay_steps=10000,
90-
warmup_steps=5000,
91-
learning_rate=1e-3,
92-
)
93-
if not TF_KERAS:
94-
opt.lr = opt.lr
90+
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
91+
with CustomObjectScope({'AdamWarmup': AdamWarmup}): # Workaround for incorrect global variable used in keras
92+
keras.models.load_model(model_path, custom_objects={'AdamWarmup': AdamWarmup})

tests/test_bert.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ def test_sample(self):
1818
)
1919
model_path = os.path.join(tempfile.gettempdir(), 'keras_bert_%f.h5' % np.random.random())
2020
model.save(model_path)
21-
model = keras.models.load_model(
22-
model_path,
23-
custom_objects=get_custom_objects(),
24-
)
21+
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
22+
with CustomObjectScope(get_custom_objects()): # Workaround for incorrect global variable used in keras
23+
model = keras.models.load_model(
24+
model_path,
25+
custom_objects=get_custom_objects(),
26+
)
2527
model.summary(line_length=200)
2628

2729
def test_task_embed(self):
@@ -38,10 +40,12 @@ def test_task_embed(self):
3840
model = keras.models.Model(inputs, outputs)
3941
model_path = os.path.join(tempfile.gettempdir(), 'keras_bert_%f.h5' % np.random.random())
4042
model.save(model_path)
41-
model = keras.models.load_model(
42-
model_path,
43-
custom_objects=get_custom_objects(),
44-
)
43+
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
44+
with CustomObjectScope(get_custom_objects()): # Workaround for incorrect global variable used in keras
45+
model = keras.models.load_model(
46+
model_path,
47+
custom_objects=get_custom_objects(),
48+
)
4549
model.summary(line_length=200)
4650

4751
def test_save_load_json(self):
@@ -82,10 +86,12 @@ def test_fit(self):
8286
token_list = list(token_dict.keys())
8387
if os.path.exists(model_path):
8488
steps_per_epoch = 10
85-
model = keras.models.load_model(
86-
model_path,
87-
custom_objects=get_custom_objects(),
88-
)
89+
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
90+
with CustomObjectScope(get_custom_objects()): # Workaround for incorrect global variable used in keras
91+
model = keras.models.load_model(
92+
model_path,
93+
custom_objects=get_custom_objects(),
94+
)
8995
else:
9096
steps_per_epoch = 1000
9197
model = get_model(

0 commit comments

Comments
 (0)