Skip to content

Commit 110b9c3

Browse files
update PR
1 parent ea37ac5 commit 110b9c3

File tree

5 files changed

+129
-402
lines changed

5 files changed

+129
-402
lines changed

keras_hub/src/models/gemma/gemma_causal_lm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,6 @@ def next(prompt, cache, index):
258258
cache_update_index = index - 1
259259
batch_size = ops.shape(prompt)[0]
260260
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
261-
if keras.config.backend() == "openvino":
262-
# Avoid returning dynamic shape by openvino slice
263-
prompt = ops.reshape(prompt, [batch_size, 1])
264261
logits, hidden_states, cache = self.call_with_cache(
265262
prompt,
266263
cache,

keras_hub/src/models/gpt2/gpt2_causal_lm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,6 @@ def next(prompt, cache, index):
246246
cache_update_index = index - 1
247247
batch_size = ops.shape(prompt)[0]
248248
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
249-
if keras.config.backend() == "openvino":
250-
# Avoid returning dynamic shape by openvino slice
251-
prompt = ops.reshape(prompt, [batch_size, 1])
252249
logits, hidden_states, cache = self.call_with_cache(
253250
prompt,
254251
cache,

keras_hub/src/models/mistral/mistral_causal_lm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,6 @@ def next(prompt, cache, index):
145145
cache_update_index = index - 1
146146
batch_size = ops.shape(prompt)[0]
147147
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
148-
if keras.config.backend() == "openvino":
149-
# Avoid returning dynamic shape by openvino slice
150-
prompt = ops.reshape(prompt, [batch_size, 1])
151148
logits, hidden_states, cache = self.call_with_cache(
152149
prompt,
153150
cache,

keras_hub/src/utils/openvino_utils.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import functools
33
from pathlib import Path
44

5-
import keras
65
from keras import tree
76

87
from keras_hub.src.utils.keras_utils import print_msg
@@ -44,10 +43,9 @@ def setup_openvino_test_config(config_file_path):
4443
Returns:
4544
list: Supported paths (whitelist) for OpenVINO testing.
4645
"""
47-
supported_paths = load_openvino_supported_tools(
46+
return load_openvino_supported_tools(
4847
Path(config_file_path) / "openvino_supported_tests.txt"
4948
)
50-
return supported_paths
5149

5250

5351
@functools.lru_cache(maxsize=256)
@@ -147,7 +145,7 @@ def should_auto_skip_training_test(item):
147145
Returns:
148146
bool: True if should skip, False otherwise.
149147
"""
150-
if not item.fspath.basename.endswith(".py"):
148+
if not str(item.fspath).endswith(".py"):
151149
return False
152150
test_name = item.name.split("[")[0]
153151
return _contains_training_methods(str(item.fspath), test_name)
@@ -166,9 +164,6 @@ def get_openvino_skip_reason(item, supported_paths, auto_skip_training=True):
166164
Returns:
167165
str or None: Skip reason if should skip, None otherwise.
168166
"""
169-
if keras.config.backend() != "openvino":
170-
return None
171-
172167
test_name = item.name.split("[")[0]
173168
test_path = str(item.fspath)
174169

@@ -187,35 +182,30 @@ def get_openvino_skip_reason(item, supported_paths, auto_skip_training=True):
187182

188183
# Priority 3: Whitelist-based approach - only test supported paths
189184
if supported_paths:
190-
# Check if this test file/directory is in the whitelist
191-
# Convert test path to relative path format for comparison
192-
test_path_parts = test_path.replace("\\", "/").split("/")
193-
# Find keras_hub index and create relative path
185+
parts = test_path.replace("\\", "/").split("/")
194186
try:
195-
keras_hub_idx = test_path_parts.index("keras_hub")
196-
relative_test_path = "/".join(test_path_parts[keras_hub_idx:])
187+
keras_hub_idx = parts.index("keras_hub")
188+
relative_test_path = "/".join(parts[keras_hub_idx:])
197189
except ValueError:
198-
relative_test_path = test_path
190+
relative_test_path = test_path # fall back to absolute
199191

200192
for supported_path in supported_paths:
201-
# Exact match or directory prefix match
202193
if (
203194
relative_test_path == supported_path
204195
or relative_test_path.startswith(supported_path + "/")
205-
or relative_test_path.startswith(supported_path + "\\")
206196
):
207-
return None # Found in whitelist, don't skip
197+
return None # in whitelist
208198

209-
# Not found in whitelist, skip it
210199
return "File/directory not in OpenVINO whitelist"
200+
211201
return None
212202

213203

214204
def get_device():
215205
"""Detect and return the best available OpenVINO device.
216206
217207
Returns:
218-
tuple: (core, device) where device is "GPU" or "CPU".
208+
str: "GPU" if available, otherwise "CPU".
219209
"""
220210
return "GPU" if "GPU" in core.available_devices else "CPU"
221211

@@ -232,16 +222,17 @@ def compile_model(struct_params, struct_outputs, device, model_dtype):
232222
Returns:
233223
Compiled OpenVINO model ready for inference.
234224
"""
235-
parameters = [p.output.get_node() for p in tree.flatten(struct_params)]
236-
results = [ov_opset.result(r.output) for r in tree.flatten(struct_outputs)]
225+
flat_params = tree.flatten(struct_params)
226+
flat_outputs = tree.flatten(struct_outputs)
227+
parameters = [p.output.get_node() for p in flat_params]
228+
results = [ov_opset.result(r.output) for r in flat_outputs]
237229
ov_model = ov.Model(results=results, parameters=parameters)
238230
for ov_input in ov_model.inputs:
239231
rank = ov_input.get_partial_shape().rank.get_length()
240232
ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank))
241233
ov_model.validate_nodes_and_infer_types()
242234
config = {"INFERENCE_PRECISION_HINT": model_dtype}
243-
compiled_model = core.compile_model(ov_model, device, config)
244-
return compiled_model
235+
return core.compile_model(ov_model, device, config)
245236

246237

247238
def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
@@ -257,9 +248,9 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
257248
Structured model outputs matching expected format.
258249
"""
259250
flatten_inputs = tree.flatten(inputs)
260-
outputs = compiled_ov_model(flatten_inputs).to_tuple()
261-
outputs = unpack_singleton(tree.pack_sequence_as(struct_outputs, outputs))
262-
return outputs
251+
raw = compiled_ov_model(flatten_inputs).to_tuple()
252+
packed = tree.pack_sequence_as(struct_outputs, raw)
253+
return unpack_singleton(packed)
263254

264255

265256
def ov_infer(model, inputs, stop_token_ids, fn):
@@ -280,7 +271,7 @@ def ov_infer(model, inputs, stop_token_ids, fn):
280271
"""
281272
device = get_device()
282273

283-
# Try to use existing compiled model
274+
# Try to use existing compiled model for the same device
284275
if (
285276
getattr(model, "ov_compiled_model", None) is not None
286277
and getattr(model, "ov_device", None) is not None
@@ -298,8 +289,8 @@ def ov_infer(model, inputs, stop_token_ids, fn):
298289
"WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
299290
"recompiling model and trying again.\n" + str(e)
300291
)
301-
del model.ov_compiled_model
302-
del model.struct_outputs
292+
model.ov_compiled_model = None
293+
model.struct_outputs = None
303294

304295
# Compile a new model
305296
struct_params = model._parameterize_data(inputs)
@@ -309,7 +300,6 @@ def ov_infer(model, inputs, stop_token_ids, fn):
309300
model.ov_compiled_model = compile_model(
310301
struct_params, model.struct_outputs, device, model_dtype
311302
)
312-
313303
return get_outputs(
314304
inputs,
315305
model.struct_outputs,

0 commit comments

Comments (0)