Skip to content

Commit 8baea81

Browse files
update causal.lm
1 parent 792273e commit 8baea81

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,14 @@ def make_generate_function(self):
133133

134134
self.generate_function = self.generate_step
135135
if keras.config.backend() == "openvino":
136+
import os
137+
138+
os.environ["OV_ENABLE_EINSUM_DECOMPOSITION"] = "1"
136139
import openvino as ov
137140
import openvino.runtime.opset14 as ov_opset
138141

142+
from keras_hub.src.utils.keras_utils import print_msg
143+
139144
def ov_infer(inputs, stop_token_ids, fn):
140145
def get_outputs(inputs, struct_outputs, compiled_ov_model):
141146
flatten_inputs = tree.flatten(inputs)
@@ -147,28 +152,33 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model):
147152
)
148153
return outputs
149154

150-
struct_params = self._parameterize_data(inputs)
151-
struct_outputs = fn(struct_params, stop_token_ids)
152-
153155
# Try using the existing compiled model
154156
if self.ov_compiled_model is not None:
155157
try:
156158
return get_outputs(
157-
inputs, struct_outputs, self.ov_compiled_model
159+
inputs, self.struct_outputs, self.ov_compiled_model
158160
)
159-
except Exception:
160-
# Delete previous model, then
161+
except RuntimeError as e:
162+
# Delete previous model and struct outputs, then
161163
# Fall through to recompilation if inference fails
164+
print_msg(
165+
"WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
166+
"so we'll Rebuild and compile the model then "
167+
f"try again.\n{e}"
168+
)
162169
del self.ov_compiled_model
170+
del self.struct_outputs
163171
pass
164172

165173
# Rebuild and compile the OpenVINO model
174+
struct_params = self._parameterize_data(inputs)
175+
self.struct_outputs = fn(struct_params, stop_token_ids)
166176
parameters = [
167177
p.output.get_node() for p in tree.flatten(struct_params)
168178
]
169179
results = [
170180
ov_opset.result(r.output)
171-
for r in tree.flatten(struct_outputs)
181+
for r in tree.flatten(self.struct_outputs)
172182
]
173183
ov_model = ov.Model(results=results, parameters=parameters)
174184
for ov_input in ov_model.inputs:
@@ -182,10 +192,11 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model):
182192
# OpenVINO supports only compiling with 'CPU' devices.
183193
self.ov_compiled_model = core.compile_model(ov_model, device)
184194
return get_outputs(
185-
inputs, struct_outputs, self.ov_compiled_model
195+
inputs, self.struct_outputs, self.ov_compiled_model
186196
)
187197

188198
def wrapped_generate_function(inputs, stop_token_ids=None):
199+
# ops.array converts to numpy in openvino backend
189200
inputs = tree.map_structure(ops.array, inputs)
190201
return ov_infer(inputs, stop_token_ids, self.generate_step)
191202

0 commit comments

Comments
 (0)