Skip to content

Commit bc0afe5

Browse files
update causal.lm
1 parent 792273e commit bc0afe5

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,14 @@ def make_generate_function(self):
133133

134134
self.generate_function = self.generate_step
135135
if keras.config.backend() == "openvino":
136+
import os
137+
138+
os.environ["OV_ENABLE_EINSUM_DECOMPOSITION"] = "1"
136139
import openvino as ov
137140
import openvino.runtime.opset14 as ov_opset
138141

142+
from keras_hub.src.utils.keras_utils import print_msg
143+
139144
def ov_infer(inputs, stop_token_ids, fn):
140145
def get_outputs(inputs, struct_outputs, compiled_ov_model):
141146
flatten_inputs = tree.flatten(inputs)
@@ -147,28 +152,32 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model):
147152
)
148153
return outputs
149154

150-
struct_params = self._parameterize_data(inputs)
151-
struct_outputs = fn(struct_params, stop_token_ids)
152-
153155
# Try using the existing compiled model
154156
if self.ov_compiled_model is not None:
155157
try:
156158
return get_outputs(
157-
inputs, struct_outputs, self.ov_compiled_model
159+
inputs, self.struct_outputs, self.ov_compiled_model
158160
)
159-
except Exception:
160-
# Delete previous model, then
161+
except RuntimeError as e:
162+
# Delete previous model and struct outputs, then
161163
# Fall through to recompilation if inference fails
164+
print_msg(
165+
"WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
166+
f"so we'll recompile the model and try again.\n{e}"
167+
)
162168
del self.ov_compiled_model
169+
del self.struct_outputs
163170
pass
164171

165172
# Rebuild and compile the OpenVINO model
173+
struct_params = self._parameterize_data(inputs)
174+
self.struct_outputs = fn(struct_params, stop_token_ids)
166175
parameters = [
167176
p.output.get_node() for p in tree.flatten(struct_params)
168177
]
169178
results = [
170179
ov_opset.result(r.output)
171-
for r in tree.flatten(struct_outputs)
180+
for r in tree.flatten(self.struct_outputs)
172181
]
173182
ov_model = ov.Model(results=results, parameters=parameters)
174183
for ov_input in ov_model.inputs:
@@ -182,10 +191,11 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model):
182191
# OpenVINO supports only compiling with 'CPU' devices.
183192
self.ov_compiled_model = core.compile_model(ov_model, device)
184193
return get_outputs(
185-
inputs, struct_outputs, self.ov_compiled_model
194+
inputs, self.struct_outputs, self.ov_compiled_model
186195
)
187196

188197
def wrapped_generate_function(inputs, stop_token_ids=None):
198+
# ops.array converts yo numpy in openvino backend
189199
inputs = tree.map_structure(ops.array, inputs)
190200
return ov_infer(inputs, stop_token_ids, self.generate_step)
191201

0 commit comments

Comments
 (0)