|
| 1 | +from keras import tree |
| 2 | + |
| 3 | +from keras_hub.src.utils.keras_utils import print_msg |
| 4 | + |
| 5 | +try: |
| 6 | + import openvino as ov |
| 7 | + import openvino.opset14 as ov_opset |
| 8 | + from openvino import Core |
| 9 | + |
| 10 | + core = Core() |
| 11 | +except ImportError: |
| 12 | + ov = None |
| 13 | + ov_opset = None |
| 14 | + core = None |
| 15 | + |
| 16 | + |
def get_device():
    """Detect and return the best available OpenVINO device.

    The module-level `core` object is queried for its available devices.
    Note that `core` is `None` when the `openvino` import at the top of this
    module failed, in which case this function raises `AttributeError`.

    Returns:
        str: `"GPU"` if `"GPU"` is listed in `core.available_devices`,
        otherwise `"CPU"`.
    """
    return "GPU" if "GPU" in core.available_devices else "CPU"
| 24 | + |
| 25 | + |
def compile_model(struct_params, struct_outputs, device, model_dtype):
    """Build an OpenVINO model from symbolic tensors and compile it.

    The flattened parameters and outputs are assembled into an `ov.Model`.
    Every model input is then relaxed to a fully dynamic shape of the same
    rank before the model is compiled for `device`.

    Args:
        struct_params: Nested structure of model parameters. Each leaf must
            expose an `output` attribute with a `get_node()` method.
        struct_outputs: Nested structure of model outputs. Each leaf must
            expose an `output` attribute.
        device: Target device name passed to `core.compile_model`
            (e.g. `"GPU"` or `"CPU"`).
        model_dtype: Value used for the `INFERENCE_PRECISION_HINT` compile
            option (e.g. `"f16"` or `"f32"`).

    Returns:
        The compiled OpenVINO model returned by `core.compile_model`.
    """
    param_nodes = []
    for param in tree.flatten(struct_params):
        param_nodes.append(param.output.get_node())

    result_nodes = []
    for out in tree.flatten(struct_outputs):
        result_nodes.append(ov_opset.result(out.output))

    ov_model = ov.Model(results=result_nodes, parameters=param_nodes)

    # Replace every input dimension with -1 (dynamic) while keeping the rank.
    for model_input in ov_model.inputs:
        ndim = model_input.get_partial_shape().rank.get_length()
        dynamic_shape = ov.PartialShape([-1] * ndim)
        model_input.get_node().set_partial_shape(dynamic_shape)

    # Re-run validation so the new input shapes propagate through the graph.
    ov_model.validate_nodes_and_infer_types()

    compile_options = {"INFERENCE_PRECISION_HINT": model_dtype}
    return core.compile_model(ov_model, device, compile_options)
| 52 | + |
| 53 | + |
def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
    """Run a compiled OpenVINO model and restructure its results.

    Args:
        inputs: Nested structure of input tensors. It is flattened before
            being passed to the compiled model.
        struct_outputs: Structure used as the template for packing the flat
            inference results.
        compiled_ov_model: Callable compiled OpenVINO model. Its call result
            must provide a `to_tuple()` method.
        unpack_singleton: Callable applied to the packed outputs before they
            are returned.

    Returns:
        The value of `unpack_singleton` applied to the inference results
        packed into the structure of `struct_outputs`.
    """
    flat_inputs = tree.flatten(inputs)
    inference_result = compiled_ov_model(flat_inputs)
    flat_outputs = inference_result.to_tuple()
    packed_outputs = tree.pack_sequence_as(struct_outputs, flat_outputs)
    return unpack_singleton(packed_outputs)
| 70 | + |
| 71 | + |
def ov_infer(model, inputs, stop_token_ids, fn):
    """Run OpenVINO inference, reusing a cached compiled model if possible.

    A compiled model cached on `model` is reused when it was compiled for
    the currently detected device. If inference with the cached model raises
    `RuntimeError`, a warning is printed, the cache is dropped, and the model
    is recompiled. Otherwise a new model is compiled and cached.

    The cache lives in three attributes on `model`: `ov_compiled_model`,
    `ov_device` and `struct_outputs`. They are only written after
    compilation succeeds, so a failed compilation never leaves a stale
    compiled model paired with a new device or output structure.

    Args:
        model: Model object. Must provide `_parameterize_data`,
            `_unpack_singleton` and `dtype`.
        inputs: Input tensors for inference.
        stop_token_ids: Token IDs forwarded as the second argument to `fn`.
        fn: Callable invoked as `fn(struct_params, stop_token_ids)` to
            produce the output structure for compilation.

    Returns:
        Model outputs from OpenVINO inference, as returned by `get_outputs`.
    """
    device = get_device()

    # Fast path: reuse the cached compiled model when the device matches.
    cached_model = getattr(model, "ov_compiled_model", None)
    if cached_model is not None and getattr(model, "ov_device", None) == device:
        try:
            return get_outputs(
                inputs,
                model.struct_outputs,
                cached_model,
                model._unpack_singleton,
            )
        except RuntimeError as e:
            print_msg(
                "WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
                "recompiling model and trying again.\n" + str(e)
            )
            # Drop the failing cache entry before recompiling.
            del model.ov_compiled_model
            del model.struct_outputs

    # Slow path: build and compile into locals first, so an exception here
    # cannot leave the cache attributes in an inconsistent state.
    struct_params = model._parameterize_data(inputs)
    struct_outputs = fn(struct_params, stop_token_ids)
    # Half-precision model dtypes use the "f16" hint; the rest use "f32".
    model_dtype = "f16" if model.dtype in ("float16", "bfloat16") else "f32"
    compiled_model = compile_model(
        struct_params, struct_outputs, device, model_dtype
    )

    # Compilation succeeded: publish the new cache entry.
    model.struct_outputs = struct_outputs
    model.ov_device = device
    model.ov_compiled_model = compiled_model

    return get_outputs(
        inputs,
        model.struct_outputs,
        model.ov_compiled_model,
        model._unpack_singleton,
    )
0 commit comments