2 changes: 2 additions & 0 deletions .github/workflows/actions.yml
@@ -23,6 +23,8 @@ jobs:
version: keras-3.8
- backend: jax
version: keras-nightly
- backend: openvino
version: keras-nightly
runs-on: ubuntu-latest
env:
KERAS_BACKEND: ${{ matrix.backend }}
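With the new matrix entry, the same test job also runs with KERAS_BACKEND=openvino against keras-nightly. A minimal sketch of reproducing that environment locally (assuming openvino and keras-nightly are already installed; this snippet is not part of the PR):

import os

os.environ["KERAS_BACKEND"] = "openvino"  # must be set before keras is imported
import keras

print(keras.config.backend())  # -> "openvino"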
53 changes: 53 additions & 0 deletions conftest.py
@@ -3,6 +3,22 @@
import keras
import pytest

# OpenVINO-supported test paths
OPENVINO_SUPPORTED_PATHS = [
"keras-hub/integration_tests",
"keras_hub/src/models/gemma",
"keras_hub/src/models/gpt2",
"keras_hub/src/models/mistral",
"keras_hub/src/tokenizers",
]

# OpenVINO-specific test skips
OPENVINO_SPECIFIC_SKIPPING_TESTS = {
"test_backbone_basics": "bfloat16 dtype not supported",
"test_score_loss": "Non-implemented roll operation",
"test_causal_lm_basics": "Missing ops and requires trainable backend",
}


def pytest_addoption(parser):
parser.addoption(
@@ -32,6 +48,15 @@ def pytest_addoption(parser):


def pytest_configure(config):
# Monkey-patch training methods for OpenVINO backend
if keras.config.backend() == "openvino":
keras.Model.fit = lambda *args, **kwargs: pytest.skip(
"Model.fit() not supported on OpenVINO backend"
)
keras.Model.train_on_batch = lambda *args, **kwargs: pytest.skip(
"Model.train_on_batch() not supported on OpenVINO backend"
)

# Verify that the device has a GPU and that the backend detects it
if config.getoption("--check_gpu"):
found_gpu = False
@@ -110,6 +135,34 @@ def pytest_collection_modifyitems(config, items):
if "kaggle_key_required" in item.keywords:
item.add_marker(kaggle_key_required)

# OpenVINO-specific test skipping
if keras.config.backend() == "openvino":
test_name = item.name.split("[")[0]

if test_name in OPENVINO_SPECIFIC_SKIPPING_TESTS:
item.add_marker(
pytest.mark.skipif(
True,
reason="OpenVINO: "
f"{OPENVINO_SPECIFIC_SKIPPING_TESTS[test_name]}",
)
)
continue

is_whitelisted = any(
item.nodeid.startswith(supported_path + "/")
or item.nodeid.startswith(supported_path + "::")
or item.nodeid == supported_path
for supported_path in OPENVINO_SUPPORTED_PATHS
)

if not is_whitelisted:
item.add_marker(
pytest.mark.skipif(
True, reason="OpenVINO: File/directory not in whitelist"
)
)


# Disable traceback filtering for quicker debugging of test failures.
keras.config.disable_traceback_filtering()
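A minimal sketch (not part of the PR, reusing the two constants defined above) of how the conftest.py gates interact for a hypothetical gemma node id: the file sits under a whitelisted path, but the test name is still skipped via OPENVINO_SPECIFIC_SKIPPING_TESTS:

nodeid = (
    "keras_hub/src/models/gemma/gemma_backbone_test.py"
    "::GemmaBackboneTest::test_backbone_basics"
)
test_name = nodeid.split("::")[-1].split("[")[0]
in_whitelist = any(
    nodeid.startswith(path + "/") or nodeid.startswith(path + "::") or nodeid == path
    for path in OPENVINO_SUPPORTED_PATHS
)
print(in_whitelist)  # True: the file lives under keras_hub/src/models/gemma
print(test_name in OPENVINO_SPECIFIC_SKIPPING_TESTS)  # True: skipped by name first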
11 changes: 11 additions & 0 deletions keras_hub/src/models/causal_lm.py
@@ -132,6 +132,17 @@ def make_generate_function(self):
return self.generate_function

self.generate_function = self.generate_step
if keras.config.backend() == "openvino":
from keras_hub.src.utils.openvino_utils import ov_infer

def wrapped_generate_function(inputs, stop_token_ids=None):
# Convert to numpy for OpenVINO backend
inputs = tree.map_structure(ops.array, inputs)
return ov_infer(
self, inputs, stop_token_ids, self.generate_step
)

self.generate_function = wrapped_generate_function
if keras.config.backend() == "torch":
import torch

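From the caller's side nothing changes; generate() simply routes through the wrapper. A usage sketch (not part of the PR; GPT-2 is used here only because keras_hub/src/models/gpt2 is on the whitelist above):

import os

os.environ["KERAS_BACKEND"] = "openvino"  # before importing keras / keras_hub
import keras_hub

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
# First call: ov_infer parameterizes the inputs, compiles an OpenVINO model
# from generate_step, and caches it on the model object.
print(causal_lm.generate("Keras is", max_length=30))
# Second call: the cached compiled model is reused as long as the detected
# device (GPU/CPU) has not changed.
print(causal_lm.generate("OpenVINO is", max_length=30))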
12 changes: 6 additions & 6 deletions keras_hub/src/samplers/beam_sampler.py
@@ -95,15 +95,15 @@ def unflatten_beams(x):
)
log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0))

def cond(prompt, cache, index, log_probs):
def cond(prompt, cache, index, mask, log_probs):
if stop_token_ids is None:
return True
return ops.convert_to_tensor(True, dtype="bool")
# Stop if all sequences have produced a *new* stop token.
end_tokens = any_equal(prompt, stop_token_ids, ~mask)
prompt_done = ops.any(end_tokens, axis=-1)
return ops.logical_not(ops.all(prompt_done))

def body(prompt, cache, index, log_probs):
def body(prompt, cache, index, mask, log_probs):
# Compute the softmax distribution for the next token.
logits, _, cache = next(prompt, cache, index)
vocab_size = ops.shape(logits)[-1]
@@ -150,12 +150,12 @@ def gather_beams(x):
next_token = next_token[:, None]
prompt = ops.slice_update(prompt, [0, index], next_token)
# Return the iteration of the loop state.
return (prompt, cache, index + 1, log_probs)
return (prompt, cache, index + 1, mask, log_probs)

prompt, _, _, log_probs = self.run_loop(
prompt, _, _, _, log_probs = self.run_loop(
cond=cond,
body=body,
loop_vars=(prompt, cache, index, log_probs),
loop_vars=(prompt, cache, index, mask, log_probs),
maximum_iterations=(max_length - index),
model=model,
)
14 changes: 8 additions & 6 deletions keras_hub/src/samplers/sampler.py
@@ -92,16 +92,18 @@ def __call__(
# `ops.while_loop` will not accept `None` as a value for `loop_vars`.
cache = () if cache is None else cache

def cond(prompt, cache, index):
# OpenVINO requires every tensor the loop reads to be an explicit
# loop variable, so `mask` is threaded through as well.
def cond(prompt, cache, index, mask):
if stop_token_ids is None:
return True
return ops.convert_to_tensor(True, dtype="bool")
# Stop if all sequences have produced a *new* id from
# stop_token_ids.
end_tokens = any_equal(prompt, stop_token_ids, ~mask)
prompt_done = ops.any(end_tokens, axis=-1)
return ops.logical_not(ops.all(prompt_done))

def body(prompt, cache, index):
def body(prompt, cache, index, mask):
# Compute the softmax distribution for the next token.
logits, _, cache = next(prompt, cache, index)
probabilities = self.compute_probabilities(logits)
@@ -115,12 +117,12 @@ def body(prompt, cache, index):
prompt = ops.slice_update(prompt, [0, index], next_token)

# Return the next prompt, cache and incremented index.
return (prompt, cache, index + 1)
return (prompt, cache, index + 1, mask)

prompt, _, _ = self.run_loop(
prompt, _, _, _ = self.run_loop(
cond,
body,
loop_vars=(prompt, cache, index),
loop_vars=(prompt, cache, index, mask),
maximum_iterations=(max_length - index),
model=model,
)
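The mask never changes inside the loop; it is passed along only so that every tensor the loop reads is an explicit loop variable. A toy sketch of that convention with keras.ops.while_loop (illustrative names, not from the PR):

from keras import ops

def cond(index, mask):
    # Mirrors the samplers: `mask` is read but never updated.
    return ops.logical_and(ops.less(index, 5), ops.any(mask))

def body(index, mask):
    return (index + 1, mask)

index, mask = ops.while_loop(
    cond,
    body,
    loop_vars=(ops.convert_to_tensor(0), ops.ones((2, 3), dtype="bool")),
)
print(index)  # -> 5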
141 changes: 141 additions & 0 deletions keras_hub/src/utils/openvino_utils.py
@@ -0,0 +1,141 @@
from keras import tree

from keras_hub.src.utils.keras_utils import print_msg

try:
import openvino as ov
import openvino.opset14 as ov_opset
from openvino import Core
except ImportError:
ov = None
ov_opset = None
Core = None


_core = None


def get_core():
"""Get or create OpenVINO Core instance.

Returns:
openvino.Core: OpenVINO Core instance,
or None if OpenVINO not available.
"""
global _core
if _core is None and Core is not None:
_core = Core()
return _core


def get_device():
"""Detect and return the best available OpenVINO device.

Returns:
str: "GPU" if available, otherwise "CPU".
"""
core = get_core()
if core is None:
return "CPU"
return "GPU" if "GPU" in core.available_devices else "CPU"


def compile_model(struct_params, struct_outputs, device, model_dtype):
"""Compile OpenVINO model with dynamic shapes and precision hints.

Args:
struct_params: Model parameters structure.
struct_outputs: Model outputs structure.
device: Target device ("GPU" or "CPU").
model_dtype: Model precision ("f16" or "f32").

Returns:
Compiled OpenVINO model ready for inference.
"""
flat_params = tree.flatten(struct_params)
flat_outputs = tree.flatten(struct_outputs)
parameters = [p.output.get_node() for p in flat_params]
results = [ov_opset.result(r.output) for r in flat_outputs]
ov_model = ov.Model(results=results, parameters=parameters)
for ov_input in ov_model.inputs:
rank = ov_input.get_partial_shape().rank.get_length()
ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank))
ov_model.validate_nodes_and_infer_types()
config = {"INFERENCE_PRECISION_HINT": model_dtype}
core = get_core()
if core is None:
raise RuntimeError("OpenVINO not available")
return core.compile_model(ov_model, device, config)


def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton):
"""Execute compiled OpenVINO model and return structured outputs.

Args:
inputs: Input tensors for inference.
struct_outputs: Expected output structure.
compiled_ov_model: Compiled OpenVINO model.
unpack_singleton: Function to unpack singleton outputs.

Returns:
Structured model outputs matching expected format.
"""
flatten_inputs = tree.flatten(inputs)
raw = compiled_ov_model(flatten_inputs).to_tuple()
packed = tree.pack_sequence_as(struct_outputs, raw)
return unpack_singleton(packed)


def ov_infer(model, inputs, stop_token_ids, fn):
"""High-level OpenVINO inference with model reuse and compilation.

This function manages OpenVINO model compilation and caching. It reuses
existing compiled models when possible, or compiles new ones as needed.
Handles device detection and automatic precision selection.

Args:
model: Keras model with OpenVINO backend support.
inputs: Input tensors for inference.
stop_token_ids: Token IDs that should stop generation.
fn: Function to execute with the parameterized inputs.

Returns:
Model outputs from OpenVINO inference.
"""
device = get_device()

# Try to use existing compiled model for the same device
if (
getattr(model, "ov_compiled_model", None) is not None
and getattr(model, "ov_device", None) is not None
and device == model.ov_device
):
try:
return get_outputs(
inputs,
model.struct_outputs,
model.ov_compiled_model,
model._unpack_singleton,
)
except RuntimeError as e:
print_msg(
"WARNING: OpenVINO inference \033[1mFAILED\033[0m, "
"recompiling model and trying again.\n" + str(e)
)
model.ov_compiled_model = None
model.struct_outputs = None

# Compile a new model
struct_params = model._parameterize_data(inputs)
model.struct_outputs = fn(struct_params, stop_token_ids)
model.ov_device = device
model_dtype = "f16" if model.dtype in ("float16", "bfloat16") else "f32"
model.ov_compiled_model = compile_model(
struct_params, model.struct_outputs, device, model_dtype
)
return get_outputs(
inputs,
model.struct_outputs,
model.ov_compiled_model,
model._unpack_singleton,
)
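For readers less familiar with the OpenVINO Python API, here is a standalone toy version (not from the PR) of the dynamic-shape relaxation and precision hint that compile_model applies above:

import numpy as np
import openvino as ov
import openvino.opset14 as ov_opset

param = ov_opset.parameter([2, 3], dtype=np.float32)
result = ov_opset.result(ov_opset.relu(param))
toy_model = ov.Model(results=[result], parameters=[param])
for toy_input in toy_model.inputs:
    rank = toy_input.get_partial_shape().rank.get_length()
    # Relax every dimension to -1 so the compiled model accepts any shape.
    toy_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank))
toy_model.validate_nodes_and_infer_types()
compiled = ov.Core().compile_model(
    toy_model, "CPU", {"INFERENCE_PRECISION_HINT": "f32"}
)
out = compiled([np.ones((5, 7), dtype=np.float32)])
print(out[0].shape)  # (5, 7), despite the original [2, 3] parameter shape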