@@ -133,8 +133,13 @@ def make_generate_function(self):
133
133
134
134
self .generate_function = self .generate_step
135
135
if keras .config .backend () == "openvino" :
136
+ import os
137
+ import tempfile
138
+
136
139
import openvino as ov
137
140
import openvino .runtime .opset14 as ov_opset
141
+ from nncf import CompressWeightsMode
142
+ from nncf import compress_weights
138
143
139
144
from keras_hub .src .utils .openvino_utils import get_outputs
140
145
from keras_hub .src .utils .openvino_utils import get_struct_outputs
@@ -143,17 +148,45 @@ def ov_infer(inputs, stop_token_ids, fn):
143
148
struct_params , struct_outputs = get_struct_outputs (
144
149
inputs , stop_token_ids , fn
145
150
)
146
- parameters = [
147
- p .output .get_node () for p in tree .flatten (struct_params )
148
- ]
149
- results = [
150
- ov_opset .result (r .output )
151
- for r in tree .flatten (struct_outputs )
152
- ]
153
- core = ov .Core ()
154
- ov_model = ov .Model (results = results , parameters = parameters )
155
- compile_ov_model = core .compile_model (ov_model , "CPU" )
156
- return get_outputs (inputs , struct_outputs , compile_ov_model )
151
+ if not hasattr (ov_infer , "compiled_model" ):
152
+ ov_infer .compiled_model = None
153
+ parameters = [
154
+ p .output .get_node () for p in tree .flatten (struct_params )
155
+ ]
156
+ results = [
157
+ ov_opset .result (r .output )
158
+ for r in tree .flatten (struct_outputs )
159
+ ]
160
+ core = ov .Core ()
161
+ ov_model = ov .Model (results = results , parameters = parameters )
162
+ for ov_input in ov_model .inputs :
163
+ rank = ov_input .get_partial_shape ().rank .get_length ()
164
+ ov_input .get_node ().set_partial_shape (
165
+ ov .PartialShape ([- 1 ] * rank )
166
+ )
167
+ ov_model .validate_nodes_and_infer_types ()
168
+ with tempfile .TemporaryDirectory () as tmpdir :
169
+ path = os .path .join (tmpdir , "ov_model.xml" )
170
+ ov .serialize (ov_model , path )
171
+ del ov_model
172
+ ov_model = core .read_model (path )
173
+ group_sizes = [128 , 64 , 16 , 4 ]
174
+ for group_size in group_sizes :
175
+ try :
176
+ final_model = compress_weights (
177
+ ov_model ,
178
+ mode = CompressWeightsMode .INT4_SYM ,
179
+ group_size = group_size ,
180
+ )
181
+ break
182
+ except Exception :
183
+ continue
184
+ ov_infer .compile_ov_model = core .compile_model (
185
+ final_model , "CPU"
186
+ )
187
+ return get_outputs (
188
+ inputs , struct_outputs , ov_infer .compile_ov_model
189
+ )
157
190
158
191
def wrapped_generate_function (inputs , stop_token_ids = None ):
159
192
inputs = tree .map_structure (ops .array , inputs )
0 commit comments