Commit d748dd5

remove disc and separate layers mechanism
1 parent 692ae90

2 files changed: +47 −72 lines

keras_hub/src/models/causal_lm.py

Lines changed: 3 additions & 30 deletions
@@ -138,12 +138,9 @@ def make_generate_function(self):
 
         self.generate_function = self.generate_step
         if keras.config.backend() == "openvino":
-            import os
-            import shutil
-
             import numpy as np
             import openvino as ov
-            import openvino.runtime.opset14 as ov_opset
+            import openvino.runtime.opset15 as ov_opset
             from keras.src.backend.openvino.core import OPENVINO_DTYPES
             from keras.src.backend.openvino.core import OpenVINOKerasTensor
 
@@ -192,17 +189,13 @@ def get_outputs_from_model(inputs, model):
                 return outputs
 
             def get_model(inputs, fn, ov_model=None, compiled=False):
-                config = {
-                    "CACHE_DIR": "openvino_cache",
-                }
-
                 struct_params, _ = set_struct_outputs(inputs, fn)
 
                 if ov_model is not None:
                     assert compiled, (
                         "if you pass a model, you should make compiled=True"
                     )
-                    return ov.compile_model(ov_model, "CPU", config)
+                    return ov.compile_model(ov_model, "CPU")
 
                 parameters = [
                     p.output.get_node() for p in tree.flatten(struct_params)
@@ -216,15 +209,7 @@ def get_model(inputs, fn, ov_model=None, compiled=False):
                 if not compiled:
                     return ov_model
 
-                return ov.compile_model(ov_model, "CPU", config)
-
-            def compile_model_disc(inputs, fn, name):
-                model_path = f"./run_dir/{name}.xml"
-                if not os.path.exists(model_path):
-                    ov_model = get_model(inputs, fn)
-                    ov.save_model(ov_model, model_path)
-                model = ov.Core().read_model(model_path)
-                return get_model(inputs, fn, ov_model=model, compiled=True)
+                return ov.compile_model(ov_model, "CPU")
 
             def ov_infer(
                 inputs,
@@ -245,33 +230,21 @@ def ov_infer(
                     else:
                         set_struct_outputs(inputs, fn)
                     compiled_model = self._ov_mem[name]
-                elif disc:
-                    assert name is not None, (
-                        "you should provide the name of thr model"
-                    )
-                    compiled_model = compile_model_disc(inputs, fn, name)
                 else:
                     compiled_model = get_model(inputs, fn, compiled=True)
                 outputs = get_outputs_from_model(inputs, compiled_model)
                 del compiled_model
                 return outputs
 
-            def delete_ov_cache():
-                for path in ["openvino_cache", "run_dir"]:
-                    if os.path.exists(path):
-                        shutil.rmtree(path, ignore_errors=True)
-
            self.ov_infer = ov_infer
 
            def wrapped_generate_function(inputs, stop_token_ids=None):
                final_outputs = []
-                os.makedirs("./run_dir", exist_ok=True)
                for input in inputs:
                    outputs = self.generate_step(input, stop_token_ids)
                    for k, v in outputs.items():
                        outputs[k] = ops.convert_to_numpy(v)
                    final_outputs.append(outputs)
-                delete_ov_cache()
                return final_outputs
 
            self.generate_function = wrapped_generate_function
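
Note: the net effect in this file is that the `disc` path is gone. `compile_model_disc` no longer round-trips each model through `./run_dir/<name>.xml`, the `CACHE_DIR` compile option is dropped, and the `delete_ov_cache()` cleanup becomes unnecessary because nothing is written to disk. A minimal sketch of the in-memory compile path that remains (the toy graph below is illustrative, not from this repo; the real `get_model()` builds its parameters and results from traced Keras tensors):

import numpy as np
import openvino as ov
import openvino.runtime.opset15 as ov_opset

# Hypothetical stand-in graph: one input, doubled.
param = ov_opset.parameter([1, 4], dtype=np.float32)
doubled = ov_opset.multiply(param, ov_opset.constant(2.0, dtype=np.float32))
ov_model = ov.Model([doubled], [param], "toy")

# Before this commit: ov.compile_model(ov_model, "CPU", {"CACHE_DIR": "openvino_cache"}),
# optionally serialized to ./run_dir/<name>.xml and re-read from disk.
# After: compile directly in memory; nothing is persisted.
compiled = ov.compile_model(ov_model, "CPU")
print(compiled(np.ones((1, 4), dtype=np.float32)))

The trade-off is that compiled models are no longer reused across processes: they now live only as long as the Python objects holding them (`self._ov_mem` for the named functions, with `del compiled_model` releasing the rest).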

keras_hub/src/models/gemma/gemma_causal_lm.py

Lines changed: 44 additions & 42 deletions
@@ -196,31 +196,36 @@ def call_with_cache(
             the final hidden representation of the input tokens, and `cache` is
             the decoding cache.
         """
+        caches = []
+
+        use_openvino = keras.config.backend() == "openvino"
 
         def embed_and_scale_tokens(token_ids):
             x = self.backbone.token_embedding(token_ids)
             return x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
 
-        def make_apply_fn(layer):
-            def apply_transformer_layer(inputs):
-                x = inputs["x"]
-                current_cache = inputs["current_cache"]
-                index = inputs["cache_update_index"]
-                x, next_cache = layer(
-                    x, cache=current_cache, cache_update_index=index
+        def apply_transformer_layers(inputs):
+            x = inputs["x"]
+            current_cache = inputs["current_cache"]
+            cache_update_index = inputs["cache_update_index"]
+            for i, transformer_layer in enumerate(
+                self.backbone.transformer_layers
+            ):
+                current_cache = cache[:, i, ...]
+                x, next_cache = transformer_layer(
+                    x,
+                    cache=current_cache,
+                    cache_update_index=cache_update_index,
                 )
-                return x, next_cache
-
-            return apply_transformer_layer
+                caches.append(next_cache)
+            return x, next_cache
 
         def finalize_generation_step(inputs):
             x = self.backbone.layer_norm(inputs["x"])
             cache = ops.stack(inputs["caches"], axis=1)
             logits = self.backbone.token_embedding(x, reverse=True)
             return logits, x, cache
 
-        use_openvino = keras.config.backend() == "openvino"
-
         if use_openvino:
             token_ids = ops.convert_to_numpy(token_ids)
             cache = ops.convert_to_numpy(cache)
@@ -233,56 +238,53 @@ def finalize_generation_step(inputs):
                 )
             else:
                 ov_cache = self._ov_mem.get("cache")
-                if ov_cache is not None and cache.shape == ov_cache.shape:
+                if ov_cache is not None and cache.shape == ov_cache.shape:
                     return None, self._ov_mem["hidden_states"], ov_cache
                 x = self.ov_infer(token_ids, embed_and_scale_tokens)
         else:
             x = embed_and_scale_tokens(token_ids)
 
-        caches = []
-        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
-            current_cache = cache[:, i, ...]
-
-            inputs = {
-                "x": x,
-                "current_cache": current_cache,
-                "cache_update_index": cache_update_index,
-            }
-
-            apply_fn = make_apply_fn(transformer_layer)
-
-            if use_openvino:
-                if token_ids.shape[1] == 1:
-                    x, next_cache = self.ov_infer(
-                        inputs,
-                        apply_fn,
-                        disc=True,
-                        name=f"layer_{i}",
-                    )
-                else:
-                    x, next_cache = self.ov_infer(inputs, apply_fn)
+        if use_openvino:
+            if token_ids.shape[1] == 1:
+                x, cache = self.ov_infer(
+                    {"x": x, "current_cache": cache, "cache_update_index": 0},
+                    apply_transformer_layers,
+                    cache=True,
+                    name="apply_transformer_layers",
+                )
             else:
-                x, next_cache = apply_fn(inputs)
-
-            caches.append(next_cache)
+                x, cache = self.ov_infer(
+                    {"x": x, "current_cache": cache, "cache_update_index": 0},
+                    apply_transformer_layers,
+                )
+            self._ov_mem["cache"] = cache
+        else:
+            x, cache = apply_transformer_layers(
+                {
+                    "x": x,
+                    "current_cache": cache,
+                    "cache_update_index": cache_update_index,
+                }
+            )
 
-        inputs = {"x": x, "caches": caches}
         if use_openvino:
             if token_ids.shape[1] == 1:
                 logits, hidden_states, cache = self.ov_infer(
-                    inputs,
+                    {"x": x, "caches": caches},
                     finalize_generation_step,
                     cache=True,
                     name="finalize_generation_step",
                 )
             else:
                 logits, hidden_states, cache = self.ov_infer(
-                    inputs, finalize_generation_step
+                    {"x": x, "caches": caches}, finalize_generation_step
                 )
             self._ov_mem["cache"] = cache
             self._ov_mem["hidden_states"] = hidden_states
         else:
-            logits, hidden_states, cache = finalize_generation_step(inputs)
+            logits, hidden_states, cache = finalize_generation_step(
+                {"x": x, "caches": caches}
+            )
 
         return logits, hidden_states, cache
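
Note: the Gemma side mirrors that removal. Instead of `make_apply_fn(layer)` producing one function, and therefore one compiled OpenVINO model, per transformer layer (previously compiled through the removed `disc=True` path under `name=f"layer_{i}"`), `apply_transformer_layers` loops over `self.backbone.transformer_layers` inside a single function, so the whole stack is traced and compiled once and, on the single-token decode path, memoized in `self._ov_mem` under one name. A rough before/after sketch of the pattern (toy stand-ins; `layers` and the `compile_fn` hook are illustrative, not the KerasHub API):

def run_per_layer(x, cache, layers, compile_fn):
    # Old shape: one compiled model per layer, i.e. N traces and
    # N compiles (and, with disc=True, N XML files) for N layers.
    caches = []
    for i, layer in enumerate(layers):
        step = compile_fn(layer, name=f"layer_{i}")
        x, next_cache = step(x, cache[:, i, ...])
        caches.append(next_cache)
    return x, caches

def run_fused(x, cache, layers, compile_fn):
    # New shape: the loop lives inside the traced function, so the
    # entire layer stack compiles to a single model.
    def apply_transformer_layers(x, cache):
        caches = []
        for i, layer in enumerate(layers):
            x, next_cache = layer(x, cache[:, i, ...])
            caches.append(next_cache)
        return x, caches

    step = compile_fn(apply_transformer_layers, name="apply_transformer_layers")
    return step(x, cache)

With one fused model, `ov_infer(..., cache=True, name="apply_transformer_layers")` can reuse the compiled model across generation steps when `token_ids.shape[1] == 1`, instead of managing a separate compiled artifact per layer.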
