@@ -58,6 +58,11 @@ class CausalLM(Task):
58
58
59
59
def __init__(self, *args, **kwargs):
    """Initialize the task, adding OpenVINO-only bookkeeping state.

    On the OpenVINO backend this declares a per-name compiled-model
    cache, a slot for the traced structured outputs, and a slot for
    the inference helper bound later by `make_generate_function`.
    Other backends get no extra attributes.
    """
    super().__init__(*args, **kwargs)
    if keras.config.backend() != "openvino":
        return
    # Only the OpenVINO backend needs these declarations.
    self._ov_models = {}
    self.struct_outputs = None
    self.ov_infer = None
61
66
62
67
def compile (
63
68
self ,
@@ -132,6 +137,144 @@ def make_generate_function(self):
132
137
return self .generate_function
133
138
134
139
self .generate_function = self .generate_step
140
+ if keras .config .backend () == "openvino" :
141
+ import os
142
+ import shutil
143
+
144
+ import numpy as np
145
+ import openvino as ov
146
+ import openvino .runtime .opset14 as ov_opset
147
+ from keras .src .backend .openvino .core import OPENVINO_DTYPES
148
+ from keras .src .backend .openvino .core import OpenVINOKerasTensor
149
+
150
def unpack_singleton(x):
    """Collapse a one-element list/tuple to its sole element.

    Anything that is not a length-1 list or tuple is returned
    unchanged.
    """
    is_singleton = isinstance(x, (list, tuple)) and len(x) == 1
    return x[0] if is_singleton else x
154
+
155
def parameterize_inputs(inputs):
    """Mirror `inputs` as OpenVINO parameter nodes, same nesting.

    Recurses through lists/tuples/dicts; every leaf becomes a fresh
    `ov_opset.parameter` wrapped in an `OpenVINOKerasTensor`.
    Arrays keep their shape/dtype; Python/numpy scalars become rank-0
    i32 or f32 parameters. Raises TypeError for any other leaf.
    """
    if isinstance(inputs, (list, tuple)):
        return [parameterize_inputs(item) for item in inputs]
    if isinstance(inputs, dict):
        return {key: parameterize_inputs(val) for key, val in inputs.items()}
    if isinstance(inputs, np.ndarray):
        param = ov_opset.parameter(
            shape=list(inputs.shape),
            dtype=OPENVINO_DTYPES[str(inputs.dtype)],
        )
        return OpenVINOKerasTensor(param.output(0))
    if isinstance(inputs, (int, np.integer)):
        # NOTE: Python bools also land here (bool subclasses int).
        param = ov_opset.parameter(shape=[], dtype=ov.Type.i32)
        return OpenVINOKerasTensor(param.output(0))
    if isinstance(inputs, (float, np.floating)):
        param = ov_opset.parameter(shape=[], dtype=ov.Type.f32)
        return OpenVINOKerasTensor(param.output(0))
    raise TypeError(f"Unknown input type: {type(inputs)}")
175
+
176
def set_struct_outputs(inputs, fn):
    """Trace `fn` symbolically and remember its structured outputs.

    Side effect: stores the traced output structure on
    `self.struct_outputs` for later repacking of real outputs.
    Returns the (parameters, outputs) pair.
    """
    params = parameterize_inputs(inputs)
    self.struct_outputs = fn(params)
    return params, self.struct_outputs
180
+
181
def get_outputs_from_model(inputs, model):
    """Run the compiled `model` on numpy `inputs` and repack outputs.

    Flattens `inputs`, invokes the model, then re-nests the raw
    output tuple to match `self.struct_outputs` (unwrapping a
    single output from its one-element container).
    """
    flatten_inputs = tree.flatten(inputs)
    # Bug fix: the original `OpenVINOKerasTensor not in inputs` only
    # asked whether the class object itself was an element (or dict
    # key) of `inputs`, which is never true; check the actual leaves.
    assert not any(
        isinstance(x, OpenVINOKerasTensor) for x in flatten_inputs
    ), "inputs should be numpy arrays"
    outputs = model(flatten_inputs)
    outputs = unpack_singleton(
        tree.pack_sequence_as(self.struct_outputs, outputs.to_tuple())
    )
    return outputs
193
+
194
def get_model(inputs, fn, ov_model=None, compiled=False):
    """Build (and optionally compile) an `ov.Model` tracing `fn`.

    Always re-traces `fn` over `inputs` so `self.struct_outputs`
    is refreshed. If `ov_model` is supplied it is compiled directly
    (requires `compiled=True`); otherwise a new model is assembled
    from the traced parameters and results.
    """
    config = {"CACHE_DIR": "openvino_cache"}
    struct_params, _ = set_struct_outputs(inputs, fn)

    if ov_model is not None:
        assert compiled, "if you pass a model, you should make compiled=True"
        return ov.compile_model(ov_model, "CPU", config)

    parameters = [
        param.output.get_node() for param in tree.flatten(struct_params)
    ]
    results = [
        ov_opset.result(out.output)
        for out in tree.flatten(self.struct_outputs)
    ]
    model = ov.Model(results=results, parameters=parameters)
    return ov.compile_model(model, "CPU", config) if compiled else model
220
+
221
def compile_model_disc(inputs, fn, name):
    """Compile `fn`, persisting its serialized model under ./run_dir.

    On first use the traced model is saved as `<name>.xml`; every
    call then reads the serialized artifact back and compiles it,
    so repeated runs share the on-disk representation.
    """
    model_path = f"./run_dir/{name}.xml"
    if not os.path.exists(model_path):
        ov.save_model(get_model(inputs, fn), model_path)
    serialized = ov.Core().read_model(model_path)
    return get_model(inputs, fn, ov_model=serialized, compiled=True)
228
+
229
def ov_infer(inputs, fn, cache=False, disc=False, name=None):
    """Execute `fn` over `inputs` through an OpenVINO compiled model.

    Args:
        inputs: nested numpy inputs for the traced function.
        fn: the function to trace/compile.
        cache: keep the compiled model in `self._ov_models[name]`.
        disc: serialize the model under ./run_dir and reuse it.
        name: model identifier; required when `cache` or `disc` is set.

    Returns:
        `fn`'s outputs repacked to its structured output layout.
    """
    compiled_model = None
    if cache:
        assert name is not None, (
            "you should provide name of the model being cached"
        )
        if self._ov_models.get(name) is None:
            self._ov_models[name] = get_model(inputs, fn, compiled=True)
        else:
            # Re-trace so self.struct_outputs matches this call's inputs.
            set_struct_outputs(inputs, fn)
        compiled_model = self._ov_models[name]
    elif disc:
        # Fix: message previously read "name of thr model".
        assert name is not None, (
            "you should provide the name of the model"
        )
        compiled_model = compile_model_disc(inputs, fn, name)
    else:
        compiled_model = get_model(inputs, fn, compiled=True)
    outputs = get_outputs_from_model(inputs, compiled_model)
    # Drop the local reference promptly; compiled models hold large buffers.
    del compiled_model
    return outputs
258
+
259
def delete_ov_cache():
    """Best-effort removal of the OpenVINO cache/run directories."""
    for cache_dir in ("openvino_cache", "run_dir"):
        if os.path.exists(cache_dir):
            # ignore_errors: a partially-removed tree is acceptable here.
            shutil.rmtree(cache_dir, ignore_errors=True)
263
+
264
+ self .ov_infer = ov_infer
265
+
266
def wrapped_generate_function(inputs, stop_token_ids=None):
    """Run generate_step per batch and return numpy results.

    Iterates the list of preprocessed batches, converts every value
    in each output dict to numpy, and finally removes the OpenVINO
    cache directories created during generation.
    """
    final_outputs = []
    os.makedirs("./run_dir", exist_ok=True)
    # Fix: loop variable renamed from `input`, which shadowed the builtin.
    for batch in inputs:
        outputs = self.generate_step(batch, stop_token_ids)
        for key, value in outputs.items():
            outputs[key] = ops.convert_to_numpy(value)
        final_outputs.append(outputs)
    delete_ov_cache()
    return final_outputs
276
+
277
+ self .generate_function = wrapped_generate_function
135
278
if keras .config .backend () == "torch" :
136
279
import torch
137
280
@@ -386,7 +529,10 @@ def postprocess(x):
386
529
if strip_prompt :
387
530
outputs = [strip_prompt_function (generate (x ), x ) for x in inputs ]
388
531
else :
389
- outputs = [generate (x ) for x in inputs ]
532
+ if keras .config .backend () == "openvino" :
533
+ outputs = generate (inputs )
534
+ else :
535
+ outputs = [generate (x ) for x in inputs ]
390
536
391
537
if self .preprocessor is not None :
392
538
outputs = [postprocess (x ) for x in outputs ]
0 commit comments