Skip to content

Commit 074f0c2

Browse files
[OpenVINO backend] supporting inference for gemma with ov backend
1 parent 54c3465 commit 074f0c2

File tree

7 files changed

+315
-18
lines changed

7 files changed

+315
-18
lines changed

.github/workflows/actions.yml

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,25 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
backend: [tensorflow, jax, torch]
19+
backend: [tensorflow, jax, torch, openvino]
2020
version: [keras-stable]
2121
include:
2222
- backend: jax
2323
version: keras-3.5
2424
- backend: jax
2525
version: keras-nightly
26+
- backend: openvino
27+
version: keras-stable
28+
python-version: '3.10'
2629
runs-on: ubuntu-latest
2730
env:
2831
KERAS_BACKEND: ${{ matrix.backend }}
2932
steps:
3033
- uses: actions/checkout@v4
31-
- name: Set up Python 3.9
34+
- name: Set up Python
3235
uses: actions/setup-python@v5
3336
with:
34-
python-version: 3.9
37+
python-version: ${{ matrix.python-version || '3.9' }}
3538
- name: Get pip cache dir
3639
id: pip-cache
3740
run: |
@@ -48,6 +51,10 @@ jobs:
4851
run: |
4952
pip install -r requirements.txt --progress-bar off
5053
pip install --no-deps -e "." --progress-bar off
54+
if [[ "${{ matrix.backend }}" == "openvino" ]]; then
55+
pip uninstall -y keras
56+
pip install git+https://github.com/keras-team/keras.git@master --upgrade --force-reinstall --progress-bar off
57+
fi
5158
- name: Pin Keras 3.5
5259
if: ${{ matrix.version == 'keras-3.5'}}
5360
run: |
@@ -60,11 +67,20 @@ jobs:
6067
pip install keras-nightly --progress-bar off
6168
- name: Test with pytest
6269
run: |
63-
pytest keras_hub/
70+
if [[ "${{ matrix.backend }}" == "openvino" ]]; then
71+
pytest keras_hub/src/models/gemma/gemma_causal_lm_test.py
72+
else
73+
pytest keras_hub/
74+
fi
6475
- name: Run integration tests
6576
run: |
6677
python pip_build.py --install
67-
cd integration_tests && pytest . -k "not NoTensorflow"
78+
cd integration_tests
79+
if [[ "${{ matrix.backend }}" == "openvino" ]]; then
80+
pytest . --ignore=basic_usage_test.py -k "not NoTensorflow"
81+
else
82+
pytest . -k "not NoTensorflow"
83+
fi
6884
- name: Run no tensorflow integration test
6985
if: ${{ matrix.backend != 'tensorflow'}}
7086
run: |

keras_hub/src/models/causal_lm.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ class CausalLM(Task):
5858

5959
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent constructor.

    When the active Keras backend is OpenVINO, three extra attributes are
    declared up front so later code can rely on their existence:
    `_ov_models`, `struct_outputs` and `ov_infer`.
    """
    super().__init__(*args, **kwargs)
    running_on_openvino = keras.config.backend() == "openvino"
    if running_on_openvino:
        # Compiled models looked up by name.
        self._ov_models = {}
        # Output structure of the most recently traced function.
        self.struct_outputs = None
        # Inference helper; assigned once the generate function is built.
        self.ov_infer = None
6166

6267
def compile(
6368
self,
@@ -132,6 +137,144 @@ def make_generate_function(self):
132137
return self.generate_function
133138

134139
self.generate_function = self.generate_step
140+
if keras.config.backend() == "openvino":
141+
import os
142+
import shutil
143+
144+
import numpy as np
145+
import openvino as ov
146+
import openvino.runtime.opset14 as ov_opset
147+
from keras.src.backend.openvino.core import OPENVINO_DTYPES
148+
from keras.src.backend.openvino.core import OpenVINOKerasTensor
149+
150+
def unpack_singleton(x):
    """Return the only element of a one-item list or tuple.

    Any other value (including longer or empty sequences) is returned
    unchanged.
    """
    is_sequence = isinstance(x, (list, tuple))
    return x[0] if is_sequence and len(x) == 1 else x
154+
155+
def parameterize_inputs(inputs):
    """Mirror a nested input structure with symbolic OpenVINO parameters.

    Lists and tuples are rebuilt as lists, dicts keep their keys, and every
    leaf is replaced by an `OpenVINOKerasTensor` wrapping a fresh parameter:

    - `np.ndarray`: parameter with the array's shape and the dtype looked
      up in `OPENVINO_DTYPES`.
    - `int` / `np.integer`: scalar `i32` parameter.
    - `float` / `np.floating`: scalar `f32` parameter.

    Raises:
        TypeError: if a leaf is of any other type.
    """
    if isinstance(inputs, dict):
        return {
            key: parameterize_inputs(value) for key, value in inputs.items()
        }
    if isinstance(inputs, (list, tuple)):
        return [parameterize_inputs(item) for item in inputs]

    if isinstance(inputs, np.ndarray):
        element_type = OPENVINO_DTYPES[str(inputs.dtype)]
        static_shape = list(inputs.shape)
        parameter = ov_opset.parameter(shape=static_shape, dtype=element_type)
        return OpenVINOKerasTensor(parameter.output(0))

    # Remaining supported leaves are Python / NumPy scalars.
    if isinstance(inputs, (int, np.integer)):
        scalar_type = ov.Type.i32
    elif isinstance(inputs, (float, np.floating)):
        scalar_type = ov.Type.f32
    else:
        raise TypeError(f"Unknown input type: {type(inputs)}")
    parameter = ov_opset.parameter(shape=[], dtype=scalar_type)
    return OpenVINOKerasTensor(parameter.output(0))
175+
176+
def set_struct_outputs(inputs, fn):
    """Trace `fn` on symbolic stand-ins for `inputs`.

    The traced result is stored on `self.struct_outputs` and returned
    together with the symbolic parameters that were fed to `fn`.

    Returns:
        A `(struct_params, struct_outputs)` tuple.
    """
    symbolic_inputs = parameterize_inputs(inputs)
    traced_outputs = fn(symbolic_inputs)
    self.struct_outputs = traced_outputs
    return symbolic_inputs, traced_outputs
180+
181+
def get_outputs_from_model(inputs, model):
    """Run `model` on `inputs` and restore the traced output structure.

    Args:
        inputs: Nested structure of concrete values. It is flattened with
            `tree.flatten` before being handed to `model`.
        model: Callable invoked with the flat list of inputs; its result
            must provide `to_tuple()`.

    Returns:
        The model outputs packed like `self.struct_outputs`, with a
        one-element list/tuple unwrapped to its single item.

    Raises:
        AssertionError: if any flattened input is an `OpenVINOKerasTensor`
            (symbolic tensors cannot be fed to a compiled model).
    """
    flatten_inputs = tree.flatten(inputs)
    # The previous guard (`OpenVINOKerasTensor not in inputs`) tested the
    # class object for membership in the container, which never matches a
    # tensor instance. Check the type of every leaf instead.
    assert not any(
        isinstance(leaf, OpenVINOKerasTensor) for leaf in flatten_inputs
    ), "inputs should be numpy arrays"
    outputs = model(flatten_inputs)
    outputs = unpack_singleton(
        tree.pack_sequence_as(self.struct_outputs, outputs.to_tuple())
    )
    return outputs
193+
194+
def get_model(inputs, fn, ov_model=None, compiled=False):
    """Build (and optionally compile) an OpenVINO model for `fn`.

    `fn` is always traced through `set_struct_outputs`, which also
    refreshes `self.struct_outputs`.

    Args:
        inputs: Concrete inputs whose structure drives the tracing.
        fn: Callable to trace.
        ov_model: Optional pre-built model. When given it is compiled
            directly and `compiled` must be true.
        compiled: If false, return the uncompiled `ov.Model`; if true,
            return the model compiled for the "CPU" device.

    Returns:
        An `ov.Model` or the result of `ov.compile_model`.
    """
    compile_config = {"CACHE_DIR": "openvino_cache"}

    struct_params, _ = set_struct_outputs(inputs, fn)

    if ov_model is None:
        # Assemble the graph from the traced parameters and outputs.
        parameter_nodes = [
            p.output.get_node() for p in tree.flatten(struct_params)
        ]
        result_nodes = [
            ov_opset.result(r.output)
            for r in tree.flatten(self.struct_outputs)
        ]
        ov_model = ov.Model(results=result_nodes, parameters=parameter_nodes)
        if not compiled:
            return ov_model
    else:
        assert compiled, (
            "if you pass a model, you should make compiled=True"
        )

    return ov.compile_model(ov_model, "CPU", compile_config)
220+
221+
def compile_model_disc(inputs, fn, name):
    """Compile `fn` via an on-disk copy stored at `./run_dir/<name>.xml`.

    If the file does not exist yet, the traced model is saved there first.
    The model is then read back from disk and compiled through
    `get_model`.
    """
    xml_path = f"./run_dir/{name}.xml"
    already_saved = os.path.exists(xml_path)
    if not already_saved:
        ov.save_model(get_model(inputs, fn), xml_path)
    loaded_model = ov.Core().read_model(xml_path)
    return get_model(inputs, fn, ov_model=loaded_model, compiled=True)
228+
229+
def ov_infer(
    inputs,
    fn,
    cache=False,
    disc=False,
    name=None,
):
    """Trace `fn` over `inputs`, compile it with OpenVINO and run it.

    Args:
        inputs: Nested structure of concrete inputs.
        fn: Callable to trace and execute.
        cache: If true, the compiled model is kept in `self._ov_models`
            under `name` and reused on later calls.
        disc: If true (and `cache` is false), the model is compiled through
            `compile_model_disc`, which stores it under `./run_dir`.
        name: Identifier of the model; required when `cache` or `disc` is
            set.

    Returns:
        Whatever `get_outputs_from_model` returns for the compiled model.
    """
    if cache:
        assert name is not None, (
            "you should provide name of the model being cached"
        )
        if self._ov_models.get(name) is None:
            self._ov_models[name] = get_model(inputs, fn, compiled=True)
        else:
            # The compiled model is reused, but `self.struct_outputs` must
            # still be refreshed because it is used to repack the outputs.
            set_struct_outputs(inputs, fn)
        compiled_model = self._ov_models[name]
    elif disc:
        assert name is not None, (
            "you should provide the name of the model"
        )
        compiled_model = compile_model_disc(inputs, fn, name)
    else:
        compiled_model = get_model(inputs, fn, compiled=True)
    # No explicit `del` is needed: the local reference goes away on return.
    return get_outputs_from_model(inputs, compiled_model)
258+
259+
def delete_ov_cache():
    """Remove the `openvino_cache` and `run_dir` directories if present.

    Paths are relative to the current working directory; errors raised
    while deleting are ignored.
    """
    for directory in ("openvino_cache", "run_dir"):
        if not os.path.exists(directory):
            continue
        shutil.rmtree(directory, ignore_errors=True)
263+
264+
self.ov_infer = ov_infer
265+
266+
def wrapped_generate_function(inputs, stop_token_ids=None):
    """Run `generate_step` on every element of `inputs`.

    Args:
        inputs: Iterable of per-call inputs for `self.generate_step`.
        stop_token_ids: Forwarded unchanged to `self.generate_step`.

    Returns:
        A list with one output dict per input; every value is converted
        with `ops.convert_to_numpy`.
    """
    final_outputs = []
    os.makedirs("./run_dir", exist_ok=True)
    try:
        # `batch` rather than `input`, to avoid shadowing the builtin.
        for batch in inputs:
            outputs = self.generate_step(batch, stop_token_ids)
            for k, v in outputs.items():
                outputs[k] = ops.convert_to_numpy(v)
            final_outputs.append(outputs)
    finally:
        # Always remove the on-disk artifacts, even if a step fails, so a
        # failed run does not leave stale directories behind.
        delete_ov_cache()
    return final_outputs
276+
277+
self.generate_function = wrapped_generate_function
135278
if keras.config.backend() == "torch":
136279
import torch
137280

@@ -386,7 +529,10 @@ def postprocess(x):
386529
if strip_prompt:
387530
outputs = [strip_prompt_function(generate(x), x) for x in inputs]
388531
else:
389-
outputs = [generate(x) for x in inputs]
532+
if keras.config.backend() == "openvino":
533+
outputs = generate(inputs)
534+
else:
535+
outputs = [generate(x) for x in inputs]
390536

391537
if self.preprocessor is not None:
392538
outputs = [postprocess(x) for x in outputs]

keras_hub/src/models/gemma/gemma_causal_lm.py

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -196,22 +196,88 @@ def call_with_cache(
196196
the final hidden representation of the input tokens, and `cache` is
197197
the decoding cache.
198198
"""
199-
x = self.backbone.token_embedding(token_ids)
200-
x = x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
201-
# Each decoder layer has a cache; we update them separately.
199+
200+
def embed_and_scale_tokens(token_ids):
    """Embed `token_ids` and scale by the square root of `hidden_dim`.

    The scale factor is cast to the dtype of the embeddings before the
    multiplication.
    """
    embeddings = self.backbone.token_embedding(token_ids)
    scale = ops.cast(ops.sqrt(self.backbone.hidden_dim), embeddings.dtype)
    return embeddings * scale
203+
204+
def make_apply_fn(layer):
    """Wrap `layer` so it can be called with a single dict argument.

    The returned function reads the keys `"x"`, `"current_cache"` and
    `"cache_update_index"` from its input dict, calls
    `layer(x, cache=..., cache_update_index=...)` and returns the
    resulting `(x, next_cache)` pair.
    """

    def apply_transformer_layer(inputs):
        hidden, updated_cache = layer(
            inputs["x"],
            cache=inputs["current_cache"],
            cache_update_index=inputs["cache_update_index"],
        )
        return hidden, updated_cache

    return apply_transformer_layer
215+
216+
def finalize_generation_step(inputs):
    """Produce logits, hidden states and the stacked cache.

    Reads `inputs["x"]` and `inputs["caches"]`: `x` goes through the
    backbone's layer norm, the caches are stacked along axis 1, and the
    normalized states are passed to `token_embedding(..., reverse=True)`.

    Returns:
        A `(logits, hidden_states, cache)` tuple.
    """
    hidden_states = self.backbone.layer_norm(inputs["x"])
    stacked_cache = ops.stack(inputs["caches"], axis=1)
    logits = self.backbone.token_embedding(hidden_states, reverse=True)
    return logits, hidden_states, stacked_cache
221+
222+
use_openvino = keras.config.backend() == "openvino"
223+
224+
if use_openvino:
225+
token_ids = ops.convert_to_numpy(token_ids)
226+
cache = ops.convert_to_numpy(cache)
227+
if token_ids.shape[1] == 1:
228+
x = self.ov_infer(
229+
token_ids,
230+
embed_and_scale_tokens,
231+
cache=True,
232+
name="embed_and_scale_tokens",
233+
)
234+
else:
235+
x = self.ov_infer(token_ids, embed_and_scale_tokens)
236+
else:
237+
x = embed_and_scale_tokens(token_ids)
238+
202239
caches = []
203240
for i, transformer_layer in enumerate(self.backbone.transformer_layers):
204241
current_cache = cache[:, i, ...]
205-
x, next_cache = transformer_layer(
206-
x,
207-
cache=current_cache,
208-
cache_update_index=cache_update_index,
209-
)
242+
inputs = {
243+
"x": x,
244+
"current_cache": current_cache,
245+
"cache_update_index": cache_update_index,
246+
}
247+
248+
apply_fn = make_apply_fn(transformer_layer)
249+
250+
if use_openvino:
251+
if token_ids.shape[1] == 1:
252+
x, next_cache = self.ov_infer(
253+
inputs,
254+
apply_fn,
255+
disc=True,
256+
name=f"layer_{i}",
257+
)
258+
else:
259+
x, next_cache = self.ov_infer(inputs, apply_fn)
260+
else:
261+
x, next_cache = apply_fn(inputs)
262+
210263
caches.append(next_cache)
211264

212-
cache = ops.stack(caches, axis=1)
213-
hidden_states = x = self.backbone.layer_norm(x)
214-
logits = self.backbone.token_embedding(x, reverse=True)
265+
inputs = {"x": x, "caches": caches}
266+
if use_openvino:
267+
if token_ids.shape[1] == 1:
268+
logits, hidden_states, cache = self.ov_infer(
269+
inputs,
270+
finalize_generation_step,
271+
cache=True,
272+
name="finalize_generation_step",
273+
)
274+
else:
275+
logits, hidden_states, cache = self.ov_infer(
276+
inputs, finalize_generation_step
277+
)
278+
else:
279+
logits, hidden_states, cache = finalize_generation_step(inputs)
280+
215281
return logits, hidden_states, cache
216282

217283
def _build_cache(self, token_ids):

0 commit comments

Comments
 (0)