Skip to content

Commit 553c3d7

Browse files
add openvino_utils
1 parent 954215d commit 553c3d7

File tree

3 files changed

+408
-71
lines changed

3 files changed

+408
-71
lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 5 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -133,80 +133,14 @@ def make_generate_function(self):
133133

134134
self.generate_function = self.generate_step
135135
if keras.config.backend() == "openvino":
136-
import openvino as ov
137-
import openvino.runtime.opset14 as ov_opset
138-
139-
from keras_hub.src.utils.keras_utils import print_msg
140-
141-
def ov_infer(inputs, stop_token_ids, fn):
142-
def get_outputs(inputs, struct_outputs, compiled_ov_model):
143-
flatten_inputs = tree.flatten(inputs)
144-
outputs = compiled_ov_model(flatten_inputs).to_tuple()
145-
outputs = self._unpack_singleton(
146-
tree.pack_sequence_as(struct_outputs, outputs)
147-
)
148-
return outputs
149-
150-
core = ov.Core()
151-
device = "GPU" if "GPU" in core.available_devices else "CPU"
152-
153-
# Try using the existing compiled model
154-
if (
155-
self.ov_compiled_model is not None
156-
and getattr(self, "ov_device", None) is not None
157-
and device == self.ov_device
158-
):
159-
try:
160-
return get_outputs(
161-
inputs, self.struct_outputs, self.ov_compiled_model
162-
)
163-
except RuntimeError as e:
164-
# Delete previous model and struct outputs, then
165-
# Fall through to recompilation if inference fails
166-
print_msg(
167-
"WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
168-
"so we'll Rebuild and compile the model then "
169-
f"try again.\n{e}"
170-
)
171-
del self.ov_compiled_model
172-
del self.struct_outputs
173-
174-
# Rebuild and compile the OpenVINO model
175-
struct_params = self._parameterize_data(inputs)
176-
self.struct_outputs = fn(struct_params, stop_token_ids)
177-
parameters = [
178-
p.output.get_node() for p in tree.flatten(struct_params)
179-
]
180-
results = [
181-
ov_opset.result(r.output)
182-
for r in tree.flatten(self.struct_outputs)
183-
]
184-
ov_model = ov.Model(results=results, parameters=parameters)
185-
for ov_input in ov_model.inputs:
186-
rank = ov_input.get_partial_shape().rank.get_length()
187-
ov_input.get_node().set_partial_shape(
188-
ov.PartialShape([-1] * rank)
189-
)
190-
ov_model.validate_nodes_and_infer_types()
191-
192-
self.ov_device = device
193-
model_dtype = (
194-
"f16"
195-
if self.dtype == "float16" or self.dtype == "bfloat16"
196-
else "f32"
197-
)
198-
config = {"INFERENCE_PRECISION_HINT": model_dtype}
199-
self.ov_compiled_model = core.compile_model(
200-
ov_model, device, config
201-
)
202-
return get_outputs(
203-
inputs, self.struct_outputs, self.ov_compiled_model
204-
)
136+
from keras_hub.src.utils.openvino_utils import ov_infer
205137

206138
def wrapped_generate_function(inputs, stop_token_ids=None):
207-
# ops.array converts to numpy in openvino backend
139+
# Convert to numpy for OpenVINO backend
208140
inputs = tree.map_structure(ops.array, inputs)
209-
return ov_infer(inputs, stop_token_ids, self.generate_step)
141+
return ov_infer(
142+
self, inputs, stop_token_ids, self.generate_step
143+
)
210144

211145
self.generate_function = wrapped_generate_function
212146
if keras.config.backend() == "torch":

keras_hub/src/utils/openvino_utils.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from keras import tree
2+
3+
from keras_hub.src.utils.keras_utils import print_msg
4+
5+
try:
6+
import openvino as ov
7+
import openvino.opset14 as ov_opset
8+
from openvino import Core
9+
10+
core = Core()
11+
except ImportError:
12+
ov = None
13+
ov_opset = None
14+
core = None
15+
16+
17+
def get_device():
    """Detect and return the best available OpenVINO device.

    Returns:
        str: `"GPU"` if the OpenVINO runtime reports at least one GPU
            device, otherwise `"CPU"`.

    Raises:
        ImportError: If the `openvino` package could not be imported, in
            which case the module-level `core` is `None`.
    """
    if core is None:
        raise ImportError(
            "OpenVINO is required to select an inference device, but the "
            "`openvino` package could not be imported. Install it with "
            "`pip install openvino`."
        )
    # Per the OpenVINO `Core.available_devices` documentation, several
    # devices of the same type are listed with a numeric suffix ("GPU.0",
    # "GPU.1", ...), so an exact match on "GPU" alone would miss them.
    has_gpu = any(
        name == "GPU" or name.startswith("GPU.")
        for name in core.available_devices
    )
    return "GPU" if has_gpu else "CPU"
24+
25+
26+
def compile_model(struct_params, struct_outputs, device, model_dtype):
    """Build an OpenVINO model from traced inputs/outputs and compile it.

    Args:
        struct_params: Nested structure of traced inputs; each leaf must
            expose `.output.get_node()`, which is used as a model parameter.
        struct_outputs: Nested structure of traced outputs; each leaf's
            `.output` is wrapped in an OpenVINO result node.
        device: Device name forwarded to `core.compile_model`.
        model_dtype: Value used for the `INFERENCE_PRECISION_HINT`
            compile option.

    Returns:
        The compiled model returned by `core.compile_model`.
    """
    param_nodes = []
    for leaf in tree.flatten(struct_params):
        param_nodes.append(leaf.output.get_node())

    result_nodes = []
    for leaf in tree.flatten(struct_outputs):
        result_nodes.append(ov_opset.result(leaf.output))

    graph = ov.Model(results=result_nodes, parameters=param_nodes)

    # Keep each input's rank but mark every dimension as dynamic (-1).
    for port in graph.inputs:
        n_dims = port.get_partial_shape().rank.get_length()
        node = port.get_node()
        node.set_partial_shape(ov.PartialShape([-1] * n_dims))

    # Re-run validation/type inference after the input shapes were changed.
    graph.validate_nodes_and_infer_types()

    compile_options = {"INFERENCE_PRECISION_HINT": model_dtype}
    return core.compile_model(graph, device, compile_options)
52+
53+
54+
def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
    """Run a compiled OpenVINO model and restore the output structure.

    Args:
        inputs: Nested structure of input values; it is flattened before
            being passed to the compiled model.
        struct_outputs: Structure the flat results are packed back into.
        compiled_ov_model: Callable compiled model whose call result
            exposes `.to_tuple()`.
        unpack_singleton: Callable applied to the packed outputs before
            they are returned.

    Returns:
        The result of `unpack_singleton` applied to the model outputs
        packed into the structure of `struct_outputs`.
    """
    flat_inputs = tree.flatten(inputs)
    inference_result = compiled_ov_model(flat_inputs)
    flat_outputs = inference_result.to_tuple()
    packed = tree.pack_sequence_as(struct_outputs, flat_outputs)
    return unpack_singleton(packed)
70+
71+
72+
def ov_infer(model, inputs, stop_token_ids, fn):
    """Run `fn` through OpenVINO, reusing a cached compiled model if any.

    The compiled model, its output structure and the device it was compiled
    for are cached on `model` as `ov_compiled_model`, `struct_outputs` and
    `ov_device`. The cache is used when a compiled model exists and its
    device equals the currently detected one. If cached inference raises
    `RuntimeError`, a warning is printed, the cache is dropped and the model
    is rebuilt and compiled again.

    Args:
        model: Object providing `_parameterize_data`, `_unpack_singleton`
            and `dtype`; the cache attributes above are stored on it.
        inputs: Nested structure of input values for inference.
        stop_token_ids: Forwarded unchanged as the second argument of `fn`.
        fn: Callable invoked as `fn(struct_params, stop_token_ids)` to
            produce the output structure that is compiled.

    Returns:
        The structured outputs produced by `get_outputs`.
    """
    target_device = get_device()

    # NOTE(review): reuse is decided on the device only; `stop_token_ids`
    # and `fn` are not compared with those used at compile time -- confirm
    # callers reset the cache when they change.
    cached_model = getattr(model, "ov_compiled_model", None)
    cached_device = getattr(model, "ov_device", None)
    cache_usable = (
        cached_model is not None
        and cached_device is not None
        and target_device == cached_device
    )

    if cache_usable:
        try:
            return get_outputs(
                inputs,
                model.struct_outputs,
                cached_model,
                model._unpack_singleton,
            )
        except RuntimeError as err:
            print_msg(
                "WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
                "recompiling model and trying again.\n" + str(err)
            )
            # Drop the stale cache, then fall through to recompilation.
            del model.ov_compiled_model
            del model.struct_outputs

    # No usable cache (or it just failed): build and compile a new model.
    params = model._parameterize_data(inputs)
    model.struct_outputs = fn(params, stop_token_ids)
    model.ov_device = target_device

    # Half-precision model dtypes map to the "f16" hint, the rest to "f32".
    if model.dtype in ("float16", "bfloat16"):
        precision = "f16"
    else:
        precision = "f32"

    model.ov_compiled_model = compile_model(
        params, model.struct_outputs, target_device, precision
    )

    return get_outputs(
        inputs,
        model.struct_outputs,
        model.ov_compiled_model,
        model._unpack_singleton,
    )

0 commit comments

Comments
 (0)