2
2
import functools
3
3
from pathlib import Path
4
4
5
- import keras
6
5
from keras import tree
7
6
8
7
from keras_hub .src .utils .keras_utils import print_msg
@@ -44,10 +43,9 @@ def setup_openvino_test_config(config_file_path):
44
43
Returns:
45
44
list: Supported paths (whitelist) for OpenVINO testing.
46
45
"""
47
- supported_paths = load_openvino_supported_tools (
46
+ return load_openvino_supported_tools (
48
47
Path (config_file_path ) / "openvino_supported_tests.txt"
49
48
)
50
- return supported_paths
51
49
52
50
53
51
@functools .lru_cache (maxsize = 256 )
@@ -147,7 +145,7 @@ def should_auto_skip_training_test(item):
147
145
Returns:
148
146
bool: True if should skip, False otherwise.
149
147
"""
150
- if not item .fspath . basename .endswith (".py" ):
148
+ if not str ( item .fspath ) .endswith (".py" ):
151
149
return False
152
150
test_name = item .name .split ("[" )[0 ]
153
151
return _contains_training_methods (str (item .fspath ), test_name )
@@ -166,9 +164,6 @@ def get_openvino_skip_reason(item, supported_paths, auto_skip_training=True):
166
164
Returns:
167
165
str or None: Skip reason if should skip, None otherwise.
168
166
"""
169
- if keras .config .backend () != "openvino" :
170
- return None
171
-
172
167
test_name = item .name .split ("[" )[0 ]
173
168
test_path = str (item .fspath )
174
169
@@ -187,35 +182,30 @@ def get_openvino_skip_reason(item, supported_paths, auto_skip_training=True):
187
182
188
183
# Priority 3: Whitelist-based approach - only test supported paths
189
184
if supported_paths :
190
- # Check if this test file/directory is in the whitelist
191
- # Convert test path to relative path format for comparison
192
- test_path_parts = test_path .replace ("\\ " , "/" ).split ("/" )
193
- # Find keras_hub index and create relative path
185
+ parts = test_path .replace ("\\ " , "/" ).split ("/" )
194
186
try :
195
- keras_hub_idx = test_path_parts .index ("keras_hub" )
196
- relative_test_path = "/" .join (test_path_parts [keras_hub_idx :])
187
+ keras_hub_idx = parts .index ("keras_hub" )
188
+ relative_test_path = "/" .join (parts [keras_hub_idx :])
197
189
except ValueError :
198
- relative_test_path = test_path
190
+ relative_test_path = test_path # fall back to absolute
199
191
200
192
for supported_path in supported_paths :
201
- # Exact match or directory prefix match
202
193
if (
203
194
relative_test_path == supported_path
204
195
or relative_test_path .startswith (supported_path + "/" )
205
- or relative_test_path .startswith (supported_path + "\\ " )
206
196
):
207
- return None # Found in whitelist, don't skip
197
+ return None # in whitelist
208
198
209
- # Not found in whitelist, skip it
210
199
return "File/directory not in OpenVINO whitelist"
200
+
211
201
return None
212
202
213
203
214
204
def get_device ():
215
205
"""Detect and return the best available OpenVINO device.
216
206
217
207
Returns:
218
- tuple: (core, device) where device is "GPU" or "CPU".
208
+ str: "GPU" if available, otherwise "CPU".
219
209
"""
220
210
return "GPU" if "GPU" in core .available_devices else "CPU"
221
211
@@ -232,16 +222,17 @@ def compile_model(struct_params, struct_outputs, device, model_dtype):
232
222
Returns:
233
223
Compiled OpenVINO model ready for inference.
234
224
"""
235
- parameters = [p .output .get_node () for p in tree .flatten (struct_params )]
236
- results = [ov_opset .result (r .output ) for r in tree .flatten (struct_outputs )]
225
+ flat_params = tree .flatten (struct_params )
226
+ flat_outputs = tree .flatten (struct_outputs )
227
+ parameters = [p .output .get_node () for p in flat_params ]
228
+ results = [ov_opset .result (r .output ) for r in flat_outputs ]
237
229
ov_model = ov .Model (results = results , parameters = parameters )
238
230
for ov_input in ov_model .inputs :
239
231
rank = ov_input .get_partial_shape ().rank .get_length ()
240
232
ov_input .get_node ().set_partial_shape (ov .PartialShape ([- 1 ] * rank ))
241
233
ov_model .validate_nodes_and_infer_types ()
242
234
config = {"INFERENCE_PRECISION_HINT" : model_dtype }
243
- compiled_model = core .compile_model (ov_model , device , config )
244
- return compiled_model
235
+ return core .compile_model (ov_model , device , config )
245
236
246
237
247
238
def get_outputs (inputs , struct_outputs , compiled_ov_model , unpack_singleton ):
@@ -257,9 +248,9 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
257
248
Structured model outputs matching expected format.
258
249
"""
259
250
flatten_inputs = tree .flatten (inputs )
260
- outputs = compiled_ov_model (flatten_inputs ).to_tuple ()
261
- outputs = unpack_singleton ( tree .pack_sequence_as (struct_outputs , outputs ) )
262
- return outputs
251
+ raw = compiled_ov_model (flatten_inputs ).to_tuple ()
252
+ packed = tree .pack_sequence_as (struct_outputs , raw )
253
+ return unpack_singleton ( packed )
263
254
264
255
265
256
def ov_infer (model , inputs , stop_token_ids , fn ):
@@ -280,7 +271,7 @@ def ov_infer(model, inputs, stop_token_ids, fn):
280
271
"""
281
272
device = get_device ()
282
273
283
- # Try to use existing compiled model
274
+ # Try to use existing compiled model for the same device
284
275
if (
285
276
getattr (model , "ov_compiled_model" , None ) is not None
286
277
and getattr (model , "ov_device" , None ) is not None
@@ -298,8 +289,8 @@ def ov_infer(model, inputs, stop_token_ids, fn):
298
289
"WARNING: OpenVINO inference \033 [1mFAILED\033 [0m, "
299
290
"recompiling model and trying again.\n " + str (e )
300
291
)
301
- del model .ov_compiled_model
302
- del model .struct_outputs
292
+ model .ov_compiled_model = None
293
+ model .struct_outputs = None
303
294
304
295
# Compile a new model
305
296
struct_params = model ._parameterize_data (inputs )
@@ -309,7 +300,6 @@ def ov_infer(model, inputs, stop_token_ids, fn):
309
300
model .ov_compiled_model = compile_model (
310
301
struct_params , model .struct_outputs , device , model_dtype
311
302
)
312
-
313
303
return get_outputs (
314
304
inputs ,
315
305
model .struct_outputs ,
0 commit comments