@@ -196,31 +196,36 @@ def call_with_cache(
         the final hidden representation of the input tokens, and `cache` is
         the decoding cache.
         """
+        caches = []
+
+        use_openvino = keras.config.backend() == "openvino"

         def embed_and_scale_tokens(token_ids):
             x = self.backbone.token_embedding(token_ids)
             return x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)

-        def make_apply_fn(layer):
-            def apply_transformer_layer(inputs):
-                x = inputs["x"]
-                current_cache = inputs["current_cache"]
-                index = inputs["cache_update_index"]
-                x, next_cache = layer(
-                    x, cache=current_cache, cache_update_index=index
+        def apply_transformer_layers(inputs):
+            x = inputs["x"]
+            current_cache = inputs["current_cache"]
+            cache_update_index = inputs["cache_update_index"]
+            for i, transformer_layer in enumerate(
+                self.backbone.transformer_layers
+            ):
+                current_cache = cache[:, i, ...]
+                x, next_cache = transformer_layer(
+                    x,
+                    cache=current_cache,
+                    cache_update_index=cache_update_index,
                 )
-                return x, next_cache
-
-            return apply_transformer_layer
+                caches.append(next_cache)
+            return x, next_cache

         def finalize_generation_step(inputs):
             x = self.backbone.layer_norm(inputs["x"])
             cache = ops.stack(inputs["caches"], axis=1)
             logits = self.backbone.token_embedding(x, reverse=True)
             return logits, x, cache

-        use_openvino = keras.config.backend() == "openvino"
-
         if use_openvino:
             token_ids = ops.convert_to_numpy(token_ids)
             cache = ops.convert_to_numpy(cache)
@@ -233,56 +238,53 @@ def finalize_generation_step(inputs):
                 )
             else:
                 ov_cache = self._ov_mem.get("cache")
-                if ov_cache is not None and cache.shape == ov_cache.shape:
+                if ov_cache is not None and cache.shape == ov_cache.shape:
                     return None, self._ov_mem["hidden_states"], ov_cache
                 x = self.ov_infer(token_ids, embed_and_scale_tokens)
         else:
             x = embed_and_scale_tokens(token_ids)

-        caches = []
-        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
-            current_cache = cache[:, i, ...]
-
-            inputs = {
-                "x": x,
-                "current_cache": current_cache,
-                "cache_update_index": cache_update_index,
-            }
-
-            apply_fn = make_apply_fn(transformer_layer)
-
-            if use_openvino:
-                if token_ids.shape[1] == 1:
-                    x, next_cache = self.ov_infer(
-                        inputs,
-                        apply_fn,
-                        disc=True,
-                        name=f"layer_{i}",
-                    )
-                else:
-                    x, next_cache = self.ov_infer(inputs, apply_fn)
+        if use_openvino:
+            if token_ids.shape[1] == 1:
+                x, cache = self.ov_infer(
+                    {"x": x, "current_cache": cache, "cache_update_index": 0},
+                    apply_transformer_layers,
+                    cache=True,
+                    name="apply_transformer_layers",
+                )
             else:
-                x, next_cache = apply_fn(inputs)
-
-            caches.append(next_cache)
+                x, cache = self.ov_infer(
+                    {"x": x, "current_cache": cache, "cache_update_index": 0},
+                    apply_transformer_layers,
+                )
+                self._ov_mem["cache"] = cache
+        else:
+            x, cache = apply_transformer_layers(
+                {
+                    "x": x,
+                    "current_cache": cache,
+                    "cache_update_index": cache_update_index,
+                }
+            )

-        inputs = {"x": x, "caches": caches}
         if use_openvino:
             if token_ids.shape[1] == 1:
                 logits, hidden_states, cache = self.ov_infer(
-                    inputs,
+                    {"x": x, "caches": caches},
                     finalize_generation_step,
                     cache=True,
                     name="finalize_generation_step",
                 )
             else:
                 logits, hidden_states, cache = self.ov_infer(
-                    inputs, finalize_generation_step
+                    {"x": x, "caches": caches}, finalize_generation_step
                 )
             self._ov_mem["cache"] = cache
             self._ov_mem["hidden_states"] = hidden_states
         else:
-            logits, hidden_states, cache = finalize_generation_step(inputs)
+            logits, hidden_states, cache = finalize_generation_step(
+                {"x": x, "caches": caches}
+            )

         return logits, hidden_states, cache
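For readers unfamiliar with the cache handling in this diff, here is a minimal standalone sketch of the slicing/stacking pattern it relies on: `apply_transformer_layers` reads layer `i`'s cache with `cache[:, i, ...]` and `finalize_generation_step` re-stacks the per-layer caches with `ops.stack(..., axis=1)`. Only that pattern comes from the code above; the concrete cache shape, dimension sizes, and key/value axis below are assumptions made purely for illustration.

```python
# Illustrative sketch only, not part of the diff.
import numpy as np
from keras import ops

# Hypothetical dimensions for the example.
batch, num_layers, seq_len, num_heads, head_dim = 1, 4, 8, 2, 16

# Assumed combined cache layout: axis 1 indexes the transformer layer,
# axis 2 holds the key/value pair.
cache = np.zeros(
    (batch, num_layers, 2, seq_len, num_heads, head_dim), "float32"
)

caches = []
for i in range(num_layers):
    # Per-layer slice, as in apply_transformer_layers.
    current_cache = cache[:, i, ...]
    # ... transformer_layer(x, cache=current_cache, ...) would run here ...
    caches.append(current_cache)

# finalize_generation_step restores the combined layout with ops.stack.
stacked = ops.stack(caches, axis=1)
print(tuple(stacked.shape) == cache.shape)  # True
```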