From ab31c72f73a610fe639a6f50b3c79be1d61300e8 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sat, 2 Aug 2025 22:57:42 +0300 Subject: [PATCH 01/17] [OpenVINO backend] support inference for Mistral & Gemma & GPT2 using OpenVINO backend --- .github/workflows/actions.yml | 21 +- conftest.py | 41 ++++ integration_tests/basic_usage_test.py | 4 + .../src/layers/modeling/alibi_bias_test.py | 2 + .../layers/modeling/anchor_generator_test.py | 1 + .../cached_multi_head_attention_test.py | 3 + .../src/layers/modeling/f_net_encoder_test.py | 3 + .../layers/modeling/masked_lm_head_test.py | 3 + .../modeling/position_embedding_test.py | 3 + .../modeling/reversible_embedding_test.py | 2 + .../layers/modeling/rotary_embedding_test.py | 3 + .../modeling/sine_position_encoding_test.py | 3 + .../token_and_position_embedding_test.py | 2 + .../modeling/transformer_decoder_test.py | 4 + .../modeling/transformer_encoder_test.py | 3 + keras_hub/src/models/causal_lm.py | 48 ++++- .../src/models/gemma/gemma_backbone_test.py | 2 + keras_hub/src/models/gemma/gemma_causal_lm.py | 3 + .../src/models/gemma/gemma_causal_lm_test.py | 9 + keras_hub/src/models/gemma/gemma_lora_test.py | 2 + .../src/models/gpt2/gpt2_backbone_test.py | 1 + keras_hub/src/models/gpt2/gpt2_causal_lm.py | 3 + .../src/models/gpt2/gpt2_causal_lm_test.py | 1 + .../models/mistral/mistral_backbone_test.py | 1 + .../src/models/mistral/mistral_causal_lm.py | 3 + .../models/mistral/mistral_causal_lm_test.py | 1 + keras_hub/src/samplers/sampler.py | 14 +- keras_hub/src/utils/openvino_utils.py | 126 ++++++++++++ keras_hub/src/utils/openvino_utils_test.py | 192 ++++++++++++++++++ keras_hub/src/utils/pipeline_model_test.py | 10 + openvino_excluded_concrete_tests.txt | 24 +++ openvino_excluded_tests.txt | 75 +++++++ requirements-common.txt | 1 + 33 files changed, 603 insertions(+), 11 deletions(-) create mode 100644 keras_hub/src/utils/openvino_utils.py create mode 100644 keras_hub/src/utils/openvino_utils_test.py create mode 100644 openvino_excluded_concrete_tests.txt create mode 100644 openvino_excluded_tests.txt diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 57a248d711..e3620d50c2 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -16,22 +16,25 @@ jobs: strategy: fail-fast: false matrix: - backend: [tensorflow, jax, torch] + backend: [tensorflow, jax, torch, openvino] version: [keras-stable] include: - backend: jax version: keras-3.5 - backend: jax version: keras-nightly + - backend: openvino + version: keras-stable + python-version: '3.10' runs-on: ubuntu-latest env: KERAS_BACKEND: ${{ matrix.backend }} steps: - uses: actions/checkout@v4 - - name: Set up Python 3.9 + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: ${{ matrix.python-version || '3.9' }} - name: Get pip cache dir id: pip-cache run: | @@ -48,6 +51,10 @@ jobs: run: | pip install -r requirements.txt --progress-bar off pip install --no-deps -e "." --progress-bar off + if [[ "${{ matrix.backend }}" == "openvino" ]]; then + pip uninstall -y keras + pip install git+https://github.com/Mohamed-Ashraf273/keras.git@gsoc2025 --upgrade --force-reinstall --progress-bar off + fi - name: Pin Keras 3.5 if: ${{ matrix.version == 'keras-3.5'}} run: | @@ -60,7 +67,13 @@ jobs: pip install keras-nightly --progress-bar off - name: Test with pytest run: | - pytest keras_hub/ + if [ "${{ matrix.backend }}" == "openvino" ]; then + IGNORE_FILE="openvino_excluded_tests.txt" + IGNORE_ARGS=$(awk '{print "--ignore=" $0}' "$IGNORE_FILE") + else + IGNORE_ARGS="" + fi + pytest keras_hub/ $IGNORE_ARGS - name: Run integration tests run: | python pip_build.py --install diff --git a/conftest.py b/conftest.py index a5f40eb789..f5613930f8 100644 --- a/conftest.py +++ b/conftest.py @@ -2,6 +2,7 @@ import keras import pytest +from keras.src.backend import backend def pytest_addoption(parser): @@ -70,6 +71,10 @@ def pytest_configure(config): "markers", "kaggle_key_required: mark test needing a kaggle key", ) + config.addinivalue_line( + "markers", + "requires_trainable_backend: mark test for trainable backend only", + ) def pytest_collection_modifyitems(config, items): @@ -110,6 +115,42 @@ def pytest_collection_modifyitems(config, items): if "kaggle_key_required" in item.keywords: item.add_marker(kaggle_key_required) + openvino_skipped_tests = [] + if backend() == "openvino": + from pathlib import Path + + workspace_root = Path(__file__).resolve().parents[0] + file_path = workspace_root / "openvino_excluded_concrete_tests.txt" + with open(file_path, "r") as file: + openvino_skipped_tests = [ + line.strip() for line in file if line.strip() + ] + + requires_trainable_backend = pytest.mark.skipif( + backend() in ["openvino"], + reason="Trainer not implemented for OpenVINO backend.", + ) + + for item in items: + if "requires_trainable_backend" in item.keywords: + item.add_marker(requires_trainable_backend) + # also, skip concrete tests for openvino, listed in the special file + # this is more granular mechanism to exclude tests rather + # than using --ignore option + for skipped_test in openvino_skipped_tests: + if skipped_test in item.nodeid: + item.add_marker( + skip_if_backend( + "openvino", + "Not supported operation by openvino backend", + ) + ) + break + + +def skip_if_backend(given_backend, reason): + return pytest.mark.skipif(backend() == given_backend, reason=reason) + # Disable traceback filtering for quicker debugging of tests failures. keras.config.disable_traceback_filtering() diff --git a/integration_tests/basic_usage_test.py b/integration_tests/basic_usage_test.py index 7fd73bb9e5..75af52d577 100644 --- a/integration_tests/basic_usage_test.py +++ b/integration_tests/basic_usage_test.py @@ -6,6 +6,10 @@ import keras_hub +@unittest.skipIf( + keras.backend.backend() == "openvino", + "Skip for non-trainable backends like OpenVINO", +) class BasicUsageTest(unittest.TestCase): def test_transformer(self): # Tokenize some inputs with a binary label. diff --git a/keras_hub/src/layers/modeling/alibi_bias_test.py b/keras_hub/src/layers/modeling/alibi_bias_test.py index 627cede5a2..6e9d454472 100644 --- a/keras_hub/src/layers/modeling/alibi_bias_test.py +++ b/keras_hub/src/layers/modeling/alibi_bias_test.py @@ -1,4 +1,5 @@ import keras +import pytest from keras import ops from keras import random @@ -7,6 +8,7 @@ class AlibiBiasTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): alibi_bias_max = 8 batch_size = 4 diff --git a/keras_hub/src/layers/modeling/anchor_generator_test.py b/keras_hub/src/layers/modeling/anchor_generator_test.py index e5918cdfda..4fd67a732e 100644 --- a/keras_hub/src/layers/modeling/anchor_generator_test.py +++ b/keras_hub/src/layers/modeling/anchor_generator_test.py @@ -14,6 +14,7 @@ reason="Bbox utils are not supported before keras < 3.8.0", ) class AnchorGeneratorTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): images_shape = (8, 128, 128, 3) self.run_layer_test( diff --git a/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py b/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py index 6690667ade..7d41589a3c 100644 --- a/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py +++ b/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py @@ -1,3 +1,4 @@ +import pytest from keras import ops from keras import random @@ -8,6 +9,7 @@ class CachedMultiHeadAttentionTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=CachedMultiHeadAttention, @@ -75,6 +77,7 @@ def call(outputs, cache): self.assertAllClose(output, no_loop_outputs) self.assertAllClose(output_cache, no_loop_cache) + @pytest.mark.requires_trainable_backend def test_training_propagation(self): batch_size = 2 seq_len = 5 diff --git a/keras_hub/src/layers/modeling/f_net_encoder_test.py b/keras_hub/src/layers/modeling/f_net_encoder_test.py index 0fe2d361dd..32848a82a3 100644 --- a/keras_hub/src/layers/modeling/f_net_encoder_test.py +++ b/keras_hub/src/layers/modeling/f_net_encoder_test.py @@ -1,3 +1,4 @@ +import pytest from keras import ops from keras import random @@ -6,6 +7,7 @@ class FNetEncoderTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=FNetEncoder, @@ -31,6 +33,7 @@ def test_value_error_when_invalid_kernel_initializer(self): kernel_initializer="Invalid", ) + @pytest.mark.requires_trainable_backend def test_training_propagation(self): x = random.uniform(shape=(2, 4, 6)) layer = FNetEncoder( diff --git a/keras_hub/src/layers/modeling/masked_lm_head_test.py b/keras_hub/src/layers/modeling/masked_lm_head_test.py index 28da6dc6f9..0fc26449a0 100644 --- a/keras_hub/src/layers/modeling/masked_lm_head_test.py +++ b/keras_hub/src/layers/modeling/masked_lm_head_test.py @@ -1,3 +1,4 @@ +import pytest from keras import random from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead @@ -8,6 +9,7 @@ class MaskedLMHeadTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=MaskedLMHead, @@ -27,6 +29,7 @@ def test_layer_behaviors(self): expected_num_trainable_weights=6, ) + @pytest.mark.requires_trainable_backend def test_layer_behaviors_with_embedding(self): embedding = ReversibleEmbedding(100, 16) embedding.build((4, 10)) diff --git a/keras_hub/src/layers/modeling/position_embedding_test.py b/keras_hub/src/layers/modeling/position_embedding_test.py index d6e577c66e..ad4f4f5ca4 100644 --- a/keras_hub/src/layers/modeling/position_embedding_test.py +++ b/keras_hub/src/layers/modeling/position_embedding_test.py @@ -1,5 +1,6 @@ import keras import numpy as np +import pytest from keras import ops from keras import random @@ -15,6 +16,7 @@ def custom_init(shape, dtype=None): class PositionEmbeddingTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=PositionEmbedding, @@ -26,6 +28,7 @@ def test_layer_behaviors(self): expected_num_trainable_weights=1, ) + @pytest.mark.requires_trainable_backend def test_layer_behaviors_4d(self): self.run_layer_test( cls=PositionEmbedding, diff --git a/keras_hub/src/layers/modeling/reversible_embedding_test.py b/keras_hub/src/layers/modeling/reversible_embedding_test.py index 4482854449..be7b69ccda 100644 --- a/keras_hub/src/layers/modeling/reversible_embedding_test.py +++ b/keras_hub/src/layers/modeling/reversible_embedding_test.py @@ -2,6 +2,7 @@ import keras import numpy as np +import pytest from absl.testing import parameterized from keras import ops from keras import random @@ -17,6 +18,7 @@ class ReversibleEmbeddingTest(TestCase): ("tie_weights", True), ("untie_weights", False), ) + @pytest.mark.requires_trainable_backend def test_layer_behaviors_tied(self, tie_weights): self.run_layer_test( cls=ReversibleEmbedding, diff --git a/keras_hub/src/layers/modeling/rotary_embedding_test.py b/keras_hub/src/layers/modeling/rotary_embedding_test.py index 35c2cc1356..5434e8b477 100644 --- a/keras_hub/src/layers/modeling/rotary_embedding_test.py +++ b/keras_hub/src/layers/modeling/rotary_embedding_test.py @@ -1,5 +1,6 @@ import keras import numpy as np +import pytest from keras import ops from keras import random @@ -8,6 +9,7 @@ class RotaryEmbeddingTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=RotaryEmbedding, @@ -21,6 +23,7 @@ def test_layer_behaviors(self): expected_output_shape=(2, 4, 6), ) + @pytest.mark.requires_trainable_backend def test_layer_behaviors_4d(self): self.run_layer_test( cls=RotaryEmbedding, diff --git a/keras_hub/src/layers/modeling/sine_position_encoding_test.py b/keras_hub/src/layers/modeling/sine_position_encoding_test.py index eb0aeb2ff3..fd29c51fc1 100644 --- a/keras_hub/src/layers/modeling/sine_position_encoding_test.py +++ b/keras_hub/src/layers/modeling/sine_position_encoding_test.py @@ -1,4 +1,5 @@ import keras +import pytest from keras import ops from keras import random @@ -9,6 +10,7 @@ class SinePositionEncodingTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=SinePositionEncoding, @@ -19,6 +21,7 @@ def test_layer_behaviors(self): expected_output_shape=(2, 4, 6), ) + @pytest.mark.requires_trainable_backend def test_layer_behaviors_4d(self): self.run_layer_test( cls=SinePositionEncoding, diff --git a/keras_hub/src/layers/modeling/token_and_position_embedding_test.py b/keras_hub/src/layers/modeling/token_and_position_embedding_test.py index f0ef202aed..fae2a73151 100644 --- a/keras_hub/src/layers/modeling/token_and_position_embedding_test.py +++ b/keras_hub/src/layers/modeling/token_and_position_embedding_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from keras import ops from keras import random @@ -9,6 +10,7 @@ class TokenAndPositionEmbeddingTest(TestCase): + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=TokenAndPositionEmbedding, diff --git a/keras_hub/src/layers/modeling/transformer_decoder_test.py b/keras_hub/src/layers/modeling/transformer_decoder_test.py index 7cbd32bed8..2af91292c2 100644 --- a/keras_hub/src/layers/modeling/transformer_decoder_test.py +++ b/keras_hub/src/layers/modeling/transformer_decoder_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import ops from keras import random @@ -11,6 +12,7 @@ class TransformerDecoderTest(TestCase): ("without_norm_first", False), ("with_norm_first", True), ) + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self, normalize_first): self.run_layer_test( cls=TransformerDecoder, @@ -34,6 +36,7 @@ def test_layer_behaviors(self, normalize_first): ("without_norm_first", False), ("with_norm_first", True), ) + @pytest.mark.requires_trainable_backend def test_layer_behaviors_with_cross_attention(self, normalize_first): self.run_layer_test( cls=TransformerDecoder, @@ -89,6 +92,7 @@ def test_value_error_when_invalid_kernel_inititalizer(self): kernel_initializer="Invalid", ) + @pytest.mark.requires_trainable_backend def test_training_propagation(self): decoder = TransformerDecoder( intermediate_dim=4, diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py index a682af157c..e9c58bc01c 100644 --- a/keras_hub/src/layers/modeling/transformer_encoder_test.py +++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py @@ -1,4 +1,5 @@ import keras +import pytest from absl.testing import parameterized from keras import ops from keras import random @@ -12,6 +13,7 @@ class TransformerEncoderTest(TestCase): ("without_norm_first", False), ("with_norm_first", True), ) + @pytest.mark.requires_trainable_backend def test_layer_behaviors(self, normalize_first): self.run_layer_test( cls=TransformerEncoder, @@ -69,6 +71,7 @@ def test_value_error_when_invalid_kernel_inititalizer(self): kernel_initializer="Invalid", ) + @pytest.mark.requires_trainable_backend def test_training_propagation(self): encoder = TransformerEncoder( intermediate_dim=4, diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index 0e31d2c5a2..30b8198eb5 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -132,7 +132,53 @@ def make_generate_function(self): return self.generate_function self.generate_function = self.generate_step - if keras.config.backend() == "torch": + if keras.config.backend() == "openvino": + import openvino as ov + import openvino.runtime.opset14 as ov_opset + + from keras_hub.src.utils.openvino_utils import get_outputs + from keras_hub.src.utils.openvino_utils import get_struct_outputs + + def ov_infer(inputs, stop_token_ids, fn): + struct_params, struct_outputs = get_struct_outputs( + inputs, stop_token_ids, fn + ) + # Try using the existing compiled model + if self.ov_compiled_model is not None: + try: + return get_outputs( + inputs, struct_outputs, self.ov_compiled_model + ) + except Exception: + # Fall through to recompilation if inference fails + pass + # Rebuild and compile the OpenVINO model + parameters = [ + p.output.get_node() for p in tree.flatten(struct_params) + ] + results = [ + ov_opset.result(r.output) + for r in tree.flatten(struct_outputs) + ] + ov_model = ov.Model(results=results, parameters=parameters) + for ov_input in ov_model.inputs: + rank = ov_input.get_partial_shape().rank.get_length() + ov_input.get_node().set_partial_shape( + ov.PartialShape([-1] * rank) + ) + ov_model.validate_nodes_and_infer_types() + core = ov.Core() + self.ov_compiled_model = core.compile_model(ov_model, "CPU") + return get_outputs( + inputs, struct_outputs, self.ov_compiled_model + ) + + def wrapped_generate_function(inputs, stop_token_ids=None): + inputs = tree.map_structure(ops.array, inputs) + return ov_infer(inputs, stop_token_ids, self.generate_step) + + self.generate_function = wrapped_generate_function + elif keras.config.backend() == "torch": import torch def wrapped_generate_function( diff --git a/keras_hub/src/models/gemma/gemma_backbone_test.py b/keras_hub/src/models/gemma/gemma_backbone_test.py index b5f8575332..7279f53323 100644 --- a/keras_hub/src/models/gemma/gemma_backbone_test.py +++ b/keras_hub/src/models/gemma/gemma_backbone_test.py @@ -23,6 +23,7 @@ def setUp(self): "padding_mask": ops.ones((2, 5), dtype="int32"), } + @pytest.mark.requires_trainable_backend def test_backbone_basics(self): self.run_backbone_test( cls=GemmaBackbone, @@ -180,6 +181,7 @@ def setUp(self): "padding_mask": ops.ones((2, 10), dtype="int32"), } + @pytest.mark.requires_trainable_backend def test_backbone_basics(self): self.run_backbone_test( cls=GemmaBackbone, diff --git a/keras_hub/src/models/gemma/gemma_causal_lm.py b/keras_hub/src/models/gemma/gemma_causal_lm.py index cf2c98c23c..07c6f1710c 100644 --- a/keras_hub/src/models/gemma/gemma_causal_lm.py +++ b/keras_hub/src/models/gemma/gemma_causal_lm.py @@ -258,6 +258,9 @@ def next(prompt, cache, index): cache_update_index = index - 1 batch_size = ops.shape(prompt)[0] prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + if keras.config.backend() == "openvino": + # Avoid returning dynamic shape by openvino slice + prompt = ops.reshape(prompt, [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, cache, diff --git a/keras_hub/src/models/gemma/gemma_causal_lm_test.py b/keras_hub/src/models/gemma/gemma_causal_lm_test.py index 7885d502cc..88ea2174ff 100644 --- a/keras_hub/src/models/gemma/gemma_causal_lm_test.py +++ b/keras_hub/src/models/gemma/gemma_causal_lm_test.py @@ -52,6 +52,7 @@ def setUp(self): self.train_data = (["the quick brown fox", "the quick brown fox"],) self.input_data = self.preprocessor(*self.train_data)[0] + @pytest.mark.requires_trainable_backend def test_causal_lm_basics(self): self.run_task_test( cls=GemmaCausalLM, @@ -60,6 +61,14 @@ def test_causal_lm_basics(self): expected_output_shape=(2, 8, 11), ) + # Note: To enable this test for OpenVINO, + # the issue causing long execution time must be resolved. + # See related discussion for details: + # https://github.com/openvinotoolkit/openvino/pull/31482 + @pytest.mark.skipif( + keras.config.backend() == "openvino", + reason="Skip for openvino it takes long time", + ) def test_cache_correctness(self): token_ids = self.input_data["token_ids"] padding_mask = ops.ones_like(self.input_data["padding_mask"]) diff --git a/keras_hub/src/models/gemma/gemma_lora_test.py b/keras_hub/src/models/gemma/gemma_lora_test.py index 256c4cf5fd..4d0da24523 100644 --- a/keras_hub/src/models/gemma/gemma_lora_test.py +++ b/keras_hub/src/models/gemma/gemma_lora_test.py @@ -1,11 +1,13 @@ import os import numpy as np +import pytest from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone from keras_hub.src.tests.test_case import TestCase +@pytest.mark.requires_trainable_backend class GemmaLoraTest(TestCase): def setUp(self): self._init_kwargs = { diff --git a/keras_hub/src/models/gpt2/gpt2_backbone_test.py b/keras_hub/src/models/gpt2/gpt2_backbone_test.py index 2ae887fbbc..6cf95ecc4c 100644 --- a/keras_hub/src/models/gpt2/gpt2_backbone_test.py +++ b/keras_hub/src/models/gpt2/gpt2_backbone_test.py @@ -20,6 +20,7 @@ def setUp(self): "padding_mask": ops.ones((2, 5), dtype="int32"), } + @pytest.mark.requires_trainable_backend def test_backbone_basics(self): self.run_backbone_test( cls=GPT2Backbone, diff --git a/keras_hub/src/models/gpt2/gpt2_causal_lm.py b/keras_hub/src/models/gpt2/gpt2_causal_lm.py index 7f29d4ebd8..dc5c4a8ba0 100644 --- a/keras_hub/src/models/gpt2/gpt2_causal_lm.py +++ b/keras_hub/src/models/gpt2/gpt2_causal_lm.py @@ -246,6 +246,9 @@ def next(prompt, cache, index): cache_update_index = index - 1 batch_size = ops.shape(prompt)[0] prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + if keras.config.backend() == "openvino": + # Avoid returning dynamic shape by openvino slice + prompt = ops.reshape(prompt, [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, cache, diff --git a/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py b/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py index 0f6315bea6..deee70d3d1 100644 --- a/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py +++ b/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py @@ -39,6 +39,7 @@ def setUp(self): self.train_data = ([" airplane at airport", " airplane at airport"],) self.input_data = self.preprocessor(*self.train_data)[0] + @pytest.mark.requires_trainable_backend def test_causal_lm_basics(self): self.run_task_test( cls=GPT2CausalLM, diff --git a/keras_hub/src/models/mistral/mistral_backbone_test.py b/keras_hub/src/models/mistral/mistral_backbone_test.py index ffb6e7ef20..d52784b85a 100644 --- a/keras_hub/src/models/mistral/mistral_backbone_test.py +++ b/keras_hub/src/models/mistral/mistral_backbone_test.py @@ -21,6 +21,7 @@ def setUp(self): "padding_mask": ops.ones((2, 5), dtype="int32"), } + @pytest.mark.requires_trainable_backend def test_backbone_basics(self): self.run_backbone_test( cls=MistralBackbone, diff --git a/keras_hub/src/models/mistral/mistral_causal_lm.py b/keras_hub/src/models/mistral/mistral_causal_lm.py index d28a7cad26..0cacb191fa 100644 --- a/keras_hub/src/models/mistral/mistral_causal_lm.py +++ b/keras_hub/src/models/mistral/mistral_causal_lm.py @@ -145,6 +145,9 @@ def next(prompt, cache, index): cache_update_index = index - 1 batch_size = ops.shape(prompt)[0] prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + if keras.config.backend() == "openvino": + # Avoid returning dynamic shape by openvino slice + prompt = ops.reshape(prompt, [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, cache, diff --git a/keras_hub/src/models/mistral/mistral_causal_lm_test.py b/keras_hub/src/models/mistral/mistral_causal_lm_test.py index 8a6bd42434..e1203b6641 100644 --- a/keras_hub/src/models/mistral/mistral_causal_lm_test.py +++ b/keras_hub/src/models/mistral/mistral_causal_lm_test.py @@ -39,6 +39,7 @@ def setUp(self): self.train_data = (["the quick brown fox", "the earth is round"],) self.input_data = self.preprocessor(*self.train_data)[0] + @pytest.mark.requires_trainable_backend def test_causal_lm_basics(self): self.run_task_test( cls=MistralCausalLM, diff --git a/keras_hub/src/samplers/sampler.py b/keras_hub/src/samplers/sampler.py index e3dd2627ee..44c4168375 100644 --- a/keras_hub/src/samplers/sampler.py +++ b/keras_hub/src/samplers/sampler.py @@ -92,16 +92,18 @@ def __call__( # `ops.while_loop` will not accept `None` as a value for `loop_vars`. cache = () if cache is None else cache - def cond(prompt, cache, index): + # OpenVINO requires all parameters to be passed in the body. + # So we pass `mask` as well. + def cond(prompt, cache, index, mask): if stop_token_ids is None: - return True + return ops.convert_to_tensor(True, dtype="bool") # Stop if all sequences have produced a *new* id from # stop_token_ids. end_tokens = any_equal(prompt, stop_token_ids, ~mask) prompt_done = ops.any(end_tokens, axis=-1) return ops.logical_not(ops.all(prompt_done)) - def body(prompt, cache, index): + def body(prompt, cache, index, mask): # Compute the softmax distribution for the next token. logits, _, cache = next(prompt, cache, index) probabilities = self.compute_probabilities(logits) @@ -115,12 +117,12 @@ def body(prompt, cache, index): prompt = ops.slice_update(prompt, [0, index], next_token) # Return the next prompt, cache and incremented index. - return (prompt, cache, index + 1) + return (prompt, cache, index + 1, mask) - prompt, _, _ = self.run_loop( + prompt, _, _, _ = self.run_loop( cond, body, - loop_vars=(prompt, cache, index), + loop_vars=(prompt, cache, index, mask), maximum_iterations=(max_length - index), model=model, ) diff --git a/keras_hub/src/utils/openvino_utils.py b/keras_hub/src/utils/openvino_utils.py new file mode 100644 index 0000000000..3633bdf8cf --- /dev/null +++ b/keras_hub/src/utils/openvino_utils.py @@ -0,0 +1,126 @@ +import numpy as np +import openvino as ov +import openvino.runtime.opset14 as ov_opset +from keras import ops +from keras import tree + +OPENVINO_DTYPES = { + "float16": ov.Type.f16, + "float32": ov.Type.f32, + "float64": ov.Type.f64, + "uint8": ov.Type.u8, + "uint16": ov.Type.u16, + "uint32": ov.Type.u32, + "uint64": ov.Type.u64, + "int8": ov.Type.i8, + "int16": ov.Type.i16, + "int32": ov.Type.i32, + "int64": ov.Type.i64, + "bfloat16": ov.Type.bf16, + "bool": ov.Type.boolean, + "float8_e4m3fn": ov.Type.f8e4m3, + "float8_e5m2": ov.Type.f8e5m2, + "string": ov.Type.string, +} + + +def unpack_singleton(x): + if isinstance(x, (list, tuple)) and len(x) == 1: + return x[0] + return x + + +def parameterize_inputs(inputs, prefix=""): + """ + Recursively converts input structures (dict, list, tuple, or scalars) into + OpenVINO Parameter nodes, preserving structure and assigning friendly names. + + Args: + inputs (Union[dict, list, tuple, np.ndarray, int, float]): + Input data structure or value to parameterize. + prefix (str): Prefix for naming OpenVINO parameter nodes. + + Returns: + Structure of the same form as `inputs`, but with each input replaced + by an OpenVINO-compatible tensor (converted parameter). + + Raises: + TypeError: If the input type is not supported. + """ + if isinstance(inputs, (list, tuple)): + return [ + parameterize_inputs(e, f"{prefix}{i}") for i, e in enumerate(inputs) + ] + elif isinstance(inputs, dict): + return {k: parameterize_inputs(v, k) for k, v in inputs.items()} + elif isinstance(inputs, np.ndarray): + ov_type = OPENVINO_DTYPES[str(inputs.dtype)] + ov_shape = list(inputs.shape) + param = ov_opset.parameter(shape=ov_shape, dtype=ov_type) + param.set_friendly_name(prefix) + return ops.convert_to_tensor(param.output(0)) + elif isinstance(inputs, (int, np.integer)): + param = ov_opset.parameter(shape=[], dtype=ov.Type.i32) + param.set_friendly_name(prefix) + return ops.convert_to_tensor(param.output(0)) + elif isinstance(inputs, (float, np.floating)): + param = ov_opset.parameter(shape=[], dtype=ov.Type.f32) + param.set_friendly_name(prefix) + return ops.convert_to_tensor(param.output(0)) + else: + raise TypeError(f"Unknown input type: {type(inputs)}") + + +def get_struct_outputs(inputs, stop_token_ids, fn): + """ + Prepares OpenVINO input parameters and calls the + user-defined generation function. + + Args: + inputs (dict or nested structure): Original input data + stop_token_ids (Any): Stop token information passed to + the model's generation step. + fn (Callable): A function representing a single generation + step that accepts parameterized inputs and returns structured outputs. + + Returns: + Tuple: (parameterized_inputs, struct_outputs) + - parameterized_inputs: OpenVINO parameter structure + for model compilation. + - struct_outputs: The output structure returned by + the generation function. + """ + struct_params = parameterize_inputs(inputs) + struct_outputs = fn(struct_params, stop_token_ids) + return struct_params, struct_outputs + + +def get_outputs(inputs, struct_outputs, compile_ov_model): + """ + Executes the OpenVINO compiled model with the given + inputs and reconstructs the output structure + to match `struct_outputs`. + + Args: + inputs (dict or nested structure): Original input data. + struct_outputs (Any): The structure that defines + how to reconstruct model outputs. + compile_ov_model (Callable): The compiled OpenVINO + model object with a `__call__` method. + + Returns: + The model output reconstructed to + match the structure of `struct_outputs`. + + Raises: + ValueError: If any of the inputs are still tensors. + """ + flatten_inputs = tree.flatten(inputs) + for input in flatten_inputs: + if ops.is_tensor(input): + raise ValueError("inputs should be numpy arrays") + outputs = compile_ov_model(flatten_inputs) + outputs = unpack_singleton( + tree.pack_sequence_as(struct_outputs, outputs.to_tuple()) + ) + return outputs diff --git a/keras_hub/src/utils/openvino_utils_test.py b/keras_hub/src/utils/openvino_utils_test.py new file mode 100644 index 0000000000..5d52c15282 --- /dev/null +++ b/keras_hub/src/utils/openvino_utils_test.py @@ -0,0 +1,192 @@ +import numpy as np +import openvino as ov +import pytest +from keras import backend +from keras import ops + +from keras_hub.src.tests.test_case import TestCase +from keras_hub.src.utils.openvino_utils import OPENVINO_DTYPES +from keras_hub.src.utils.openvino_utils import get_outputs +from keras_hub.src.utils.openvino_utils import get_struct_outputs +from keras_hub.src.utils.openvino_utils import parameterize_inputs +from keras_hub.src.utils.openvino_utils import unpack_singleton + + +@pytest.mark.skipif( + backend.backend() != "openvino", + reason="OpenVINO is required for these tests", +) +class TestOpenVinoUtils(TestCase): + def test_openvino_dtypes_mapping(self): + self.assertIn("float32", OPENVINO_DTYPES) + self.assertIn("int32", OPENVINO_DTYPES) + self.assertIn("bool", OPENVINO_DTYPES) + self.assertEqual(OPENVINO_DTYPES["float32"], ov.Type.f32) + self.assertEqual(OPENVINO_DTYPES["int32"], ov.Type.i32) + self.assertEqual(OPENVINO_DTYPES["bool"], ov.Type.boolean) + + def test_unpack_singleton_single_element_list(self): + result = unpack_singleton([42]) + self.assertEqual(result, 42) + + def test_unpack_singleton_single_element_tuple(self): + result = unpack_singleton((42,)) + self.assertEqual(result, 42) + + def test_unpack_singleton_multiple_elements(self): + input_list = [1, 2, 3] + result = unpack_singleton(input_list) + self.assertEqual(result, input_list) + + def test_unpack_singleton_empty_list(self): + input_list = [] + result = unpack_singleton(input_list) + self.assertEqual(result, input_list) + + def test_unpack_singleton_non_sequence(self): + result = unpack_singleton(42) + self.assertEqual(result, 42) + + def test_parameterize_inputs_numpy_array(self): + input_array = np.array([1, 2, 3], dtype=np.float32) + result = parameterize_inputs(input_array) + self.assertTrue(ops.is_tensor(result)) + + def test_parameterize_inputs_different_dtypes(self): + test_cases = [ + (np.array([1, 2, 3], dtype=np.int32), np.int32), + (np.array([1.0, 2.0, 3.0], dtype=np.float32), np.float32), + (np.array([1, 2, 3], dtype=np.int64), np.int64), + ] + + for input_array, expected_dtype in test_cases: + result = parameterize_inputs(input_array) + self.assertTrue(ops.is_tensor(result)) + + def test_parameterize_inputs_list(self): + input_list = [ + np.array([1, 2, 3], dtype=np.float32), + np.array([4, 5, 6], dtype=np.int32), + ] + result = parameterize_inputs(input_list) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertTrue(all(ops.is_tensor(r) for r in result)) + + def test_parameterize_inputs_tuple(self): + input_tuple = ( + np.array([1, 2, 3], dtype=np.float32), + np.array([4, 5, 6], dtype=np.int32), + ) + result = parameterize_inputs(input_tuple) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertTrue(all(ops.is_tensor(r) for r in result)) + + def test_parameterize_inputs_dict(self): + input_dict = { + "a": np.array([1, 2, 3], dtype=np.float32), + "b": np.array([4, 5, 6], dtype=np.int32), + } + result = parameterize_inputs(input_dict) + self.assertIsInstance(result, dict) + self.assertEqual(set(result.keys()), {"a", "b"}) + self.assertTrue(all(ops.is_tensor(v) for v in result.values())) + + def test_parameterize_inputs_integer(self): + result = parameterize_inputs(42) + self.assertTrue(ops.is_tensor(result)) + + def test_parameterize_inputs_numpy_integer(self): + result = parameterize_inputs(np.int32(42)) + self.assertTrue(ops.is_tensor(result)) + + def test_parameterize_inputs_float(self): + result = parameterize_inputs(3.14) + self.assertTrue(ops.is_tensor(result)) + + def test_parameterize_inputs_numpy_float(self): + result = parameterize_inputs(np.float32(3.14)) + self.assertTrue(ops.is_tensor(result)) + + def test_parameterize_inputs_unsupported_type(self): + with self.assertRaisesRegex(TypeError, "Unknown input type"): + parameterize_inputs("unsupported_string") + + def test_parameterize_inputs_nested_structure(self): + """Test parameterizing nested structures.""" + nested_input = { + "list": [np.array([1, 2], dtype=np.float32), 42], + "dict": {"nested": np.array([3, 4], dtype=np.int32)}, + } + result = parameterize_inputs(nested_input) + self.assertIsInstance(result, dict) + self.assertIsInstance(result["list"], list) + self.assertIsInstance(result["dict"], dict) + self.assertTrue(ops.is_tensor(result["list"][0])) + self.assertTrue(ops.is_tensor(result["list"][1])) + self.assertTrue(ops.is_tensor(result["dict"]["nested"])) + + def test_get_struct_outputs(self): + inputs = np.array([1, 2, 3], dtype=np.float32) + stop_token_ids = [0, 1] + + def mock_fn(params, stop_tokens): + return params # Simple mock that returns the params + + struct_params, struct_outputs = get_struct_outputs( + inputs, stop_token_ids, mock_fn + ) + self.assertTrue(ops.is_tensor(struct_params)) + self.assertTrue(ops.is_tensor(struct_outputs)) + + def test_get_outputs_with_tensor_input_raises_error(self): + inputs = [ops.convert_to_tensor(np.array([1, 2, 3]))] + struct_outputs = np.array([1, 2, 3]) + + def mock_compile_model(inputs): + class MockResult: + def to_tuple(self): + return (np.array([1, 2, 3]),) + + return MockResult() + + with self.assertRaisesRegex( + ValueError, "inputs should be numpy arrays" + ): + get_outputs(inputs, struct_outputs, mock_compile_model) + + def test_get_outputs_with_valid_inputs(self): + inputs = [np.array([1, 2, 3], dtype=np.float32)] + struct_outputs = np.array([1, 2, 3]) + + def mock_compile_model(inputs): + class MockResult: + def to_tuple(self): + return (np.array([4, 5, 6]),) + + return MockResult() + + result = get_outputs(inputs, struct_outputs, mock_compile_model) + self.assertIsInstance(result, np.ndarray) + self.assertAllClose(result, np.array([4, 5, 6])) + + def test_get_outputs_with_nested_inputs(self): + inputs = { + "a": np.array([1, 2], dtype=np.float32), + "b": [np.array([3, 4], dtype=np.int32)], + } + struct_outputs = [np.array([1, 2]), np.array([3, 4])] + + def mock_compile_model(inputs): + class MockResult: + def to_tuple(self): + return (np.array([5, 6]), np.array([7, 8])) + + return MockResult() + + result = get_outputs(inputs, struct_outputs, mock_compile_model) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertAllClose(result[0], np.array([5, 6])) + self.assertAllClose(result[1], np.array([7, 8])) diff --git a/keras_hub/src/utils/pipeline_model_test.py b/keras_hub/src/utils/pipeline_model_test.py index 101171597c..42178f54b8 100644 --- a/keras_hub/src/utils/pipeline_model_test.py +++ b/keras_hub/src/utils/pipeline_model_test.py @@ -2,6 +2,7 @@ import keras import numpy as np +import pytest import tensorflow as tf from keras_hub.src.tests.test_case import TestCase @@ -78,6 +79,7 @@ def from_config(cls, config): class TestNoopPipelineModel(TestCase): + @pytest.mark.requires_trainable_backend def test_fit(self): x = np.random.uniform(size=(8, 5)) y = np.random.uniform(size=(8, 1)) @@ -111,6 +113,7 @@ def test_predict(self): model.predict(x=x, batch_size=8) model.predict(tf.data.Dataset.from_tensor_slices(x).batch(8)) + @pytest.mark.requires_trainable_backend def test_on_batch(self): x = np.random.uniform(size=(8, 5)) y = np.random.uniform(size=(8, 1)) @@ -143,6 +146,7 @@ def test_saved_model(self): class TestFeaturePreprocessingModel(TestCase): + @pytest.mark.requires_trainable_backend def test_fit_with_preprocessing(self): x = tf.strings.as_string(np.random.uniform(size=(100, 5))) y = np.random.uniform(size=(100, 1)) @@ -176,6 +180,7 @@ def test_predict_with_preprocessing(self): model.predict(x=x, batch_size=8) model.predict(tf.data.Dataset.from_tensor_slices(x).batch(8)) + @pytest.mark.requires_trainable_backend def test_on_batch(self): x = tf.strings.as_string(np.random.uniform(size=(8, 5))) y = np.random.uniform(size=(8, 1)) @@ -208,6 +213,7 @@ def test_saved_model(self): class TestLabelPreprocessingModel(TestCase): + @pytest.mark.requires_trainable_backend def test_fit_with_preprocessing(self): x = np.random.uniform(size=(100, 5)) y = tf.strings.as_string(np.random.uniform(size=(100, 1))) @@ -273,6 +279,7 @@ def test_saved_model(self): class TestDataPreprocessingModel(TestCase): + @pytest.mark.requires_trainable_backend def test_fit_with_preprocessing(self): data = tf.strings.as_string(np.random.uniform(size=(100, 1))) model = DataPipeline() @@ -324,6 +331,7 @@ def test_saved_model(self): class TestFunctional(TestCase): + @pytest.mark.requires_trainable_backend def test_fit(self): x = tf.strings.as_string(np.random.uniform(size=(100, 5))) y = np.random.uniform(size=(100, 1)) @@ -355,6 +363,7 @@ def test_saved_model(self): self.assertAllClose(model_output, restored_output) +@pytest.mark.requires_trainable_backend class TestFitArguments(TestCase): def test_validation_data(self): x = tf.strings.as_string(np.random.uniform(size=(80, 5))) @@ -400,6 +409,7 @@ def test_error_dataset_and_invalid_arguments(self): model.fit(ds, sample_weight=sw) +@pytest.mark.requires_trainable_backend class TestInputErrors(TestCase): def test_unbatched_input_raises(self): model = FeaturePipeline() diff --git a/openvino_excluded_concrete_tests.txt b/openvino_excluded_concrete_tests.txt new file mode 100644 index 0000000000..0d1c851933 --- /dev/null +++ b/openvino_excluded_concrete_tests.txt @@ -0,0 +1,24 @@ +AnchorGeneratorTest::test_anchor_generator0 +BoxMatcherTest::test_box_matcher_batched +BoxMatcherTest::test_box_matcher_empty_gt_boxes +BoxMatcherTest::test_box_matcher_force_match +BoxMatcherTest::test_box_matcher_unbatched +CachedMultiHeadAttentionTest::test_cache_call_is_correct +CachedMultiHeadAttentionTest::test_layer_behaviors +GemmaCausalLMTest::test_score_loss +GPT2CausalLMTest::test_score_loss +MistralCausalLMTest::test_score_loss +NonMaxSupressionTest::test_confidence_threshold +NonMaxSupressionTest::test_max_detections +RandomSamplerTest::test_early_stopping +RandomSamplerTest::test_stateful_call +ReversibleEmbeddingTest::test_quantize_dtype_argument_untie_weights +ReversibleEmbeddingTest::test_quantize_dtype_argument_tie_weights +ReversibleEmbeddingTest::test_quantize_int8_tie_weights +ReversibleEmbeddingTest::test_quantize_int8_untie_weights +ReversibleEmbeddingTest::test_saving_tie_weights +ReversibleEmbeddingTest::test_saving_untie_weights +TestNoopPipelineModel::test_evaluate +TestFeaturePreprocessingModel::test_evaluate_with_preprocessing +TestLabelPreprocessingModel::test_evaluate_with_preprocessing +TestDataPreprocessingModel::test_evaluate_with_preprocessing diff --git a/openvino_excluded_tests.txt b/openvino_excluded_tests.txt new file mode 100644 index 0000000000..f33dd5def7 --- /dev/null +++ b/openvino_excluded_tests.txt @@ -0,0 +1,75 @@ +keras_hub/src/layers/modeling/transformer_decoder_test.py +keras_hub/src/layers/modeling/transformer_encoder_test.py +keras_hub/src/layers/preprocessing/image_converter_test.py +keras_hub/src/metrics/bleu_test.py +keras_hub/src/metrics/edit_distance_test.py +keras_hub/src/metrics/perplexity_test.py +keras_hub/src/metrics/rouge_l_test.py +keras_hub/src/metrics/rouge_n_test.py +keras_hub/src/models/albert +keras_hub/src/models/audio_to_text_preprocessor_test.py +keras_hub/src/models/backbone_test.py +keras_hub/src/models/bart +keras_hub/src/models/basnet +keras_hub/src/models/bert +keras_hub/src/models/bloom +keras_hub/src/models/causal_lm_preprocessor_test.py +keras_hub/src/models/clip +keras_hub/src/models/cspnet +keras_hub/src/models/deberta_v3 +keras_hub/src/models/deeplab_v3 +keras_hub/src/models/deit +keras_hub/src/models/densenet +keras_hub/src/models/dinov2 +keras_hub/src/models/distil_bert +keras_hub/src/models/efficientnet +keras_hub/src/models/electra +keras_hub/src/models/falcon +keras_hub/src/models/flux +keras_hub/src/models/f_net +keras_hub/src/models/gemma3 +keras_hub/src/models/gpt_neo_x +keras_hub/src/models/hgnetv2 +keras_hub/src/models/llama3 +keras_hub/src/models/llama +keras_hub/src/models/masked_lm_preprocessor_test.py +keras_hub/src/models/mit +keras_hub/src/models/mixtral +keras_hub/src/models/mobilenet +keras_hub/src/models/moonshine +keras_hub/src/models/opt +keras_hub/src/models/pali_gemma +keras_hub/src/models/phi3 +keras_hub/src/models/preprocessor_test.py +keras_hub/src/models/qwen3 +keras_hub/src/models/qwen_moe +keras_hub/src/models/qwen +keras_hub/src/models/resnet +keras_hub/src/models/retinanet +keras_hub/src/models/roberta +keras_hub/src/models/roformer_v2 +keras_hub/src/models/sam +keras_hub/src/models/segformer +keras_hub/src/models/seq_2_seq_lm_preprocessor_test.py +keras_hub/src/models/siglip +keras_hub/src/models/stable_diffusion_3 +keras_hub/src/models/t5 +keras_hub/src/models/task_test.py +keras_hub/src/models/text_classifier_preprocessor_test.py +keras_hub/src/models/text_to_image_preprocessor_test.py +keras_hub/src/models/vae +keras_hub/src/models/vgg +keras_hub/src/models/vit_det +keras_hub/src/models/vit +keras_hub/src/models/vit +keras_hub/src/models/whisper +keras_hub/src/models/xception +keras_hub/src/models/xlm_roberta +keras_hub/src/models/xlnet +keras_hub/src/samplers/beam_sampler_test.py +keras_hub/src/samplers/contrastive_sampler_test.py +keras_hub/src/samplers/greedy_sampler_test.py +keras_hub/src/samplers/top_k_sampler_test.py +keras_hub/src/samplers/top_p_sampler_test.py +keras_hub/src/utils/pipeline_model_test.py +keras_hub/src/utils/transformers/export/gemma_test.py \ No newline at end of file diff --git a/requirements-common.txt b/requirements-common.txt index a98ed71301..a258d1cd85 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -18,4 +18,5 @@ sentencepiece tensorflow-datasets safetensors pillow +openvino transformers From 9305b053fba4203720c6402935f40b53cf792093 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sun, 3 Aug 2025 14:36:15 +0300 Subject: [PATCH 02/17] enable test_cache test --- keras_hub/src/models/gemma/gemma_causal_lm_test.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/keras_hub/src/models/gemma/gemma_causal_lm_test.py b/keras_hub/src/models/gemma/gemma_causal_lm_test.py index 88ea2174ff..59a40d3b3b 100644 --- a/keras_hub/src/models/gemma/gemma_causal_lm_test.py +++ b/keras_hub/src/models/gemma/gemma_causal_lm_test.py @@ -61,14 +61,6 @@ def test_causal_lm_basics(self): expected_output_shape=(2, 8, 11), ) - # Note: To enable this test for OpenVINO, - # the issue causing long execution time must be resolved. - # See related discussion for details: - # https://github.com/openvinotoolkit/openvino/pull/31482 - @pytest.mark.skipif( - keras.config.backend() == "openvino", - reason="Skip for openvino it takes long time", - ) def test_cache_correctness(self): token_ids = self.input_data["token_ids"] padding_mask = ops.ones_like(self.input_data["padding_mask"]) From 91b478ff4b6937f4c6eeff34f85572ce2bc47a86 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sun, 3 Aug 2025 14:54:59 +0300 Subject: [PATCH 03/17] update conftest --- conftest.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/conftest.py b/conftest.py index f5613930f8..a719bbafdf 100644 --- a/conftest.py +++ b/conftest.py @@ -2,7 +2,6 @@ import keras import pytest -from keras.src.backend import backend def pytest_addoption(parser): @@ -116,7 +115,7 @@ def pytest_collection_modifyitems(config, items): item.add_marker(kaggle_key_required) openvino_skipped_tests = [] - if backend() == "openvino": + if keras.config.backend() == "openvino": from pathlib import Path workspace_root = Path(__file__).resolve().parents[0] @@ -127,8 +126,8 @@ def pytest_collection_modifyitems(config, items): ] requires_trainable_backend = pytest.mark.skipif( - backend() in ["openvino"], - reason="Trainer not implemented for OpenVINO backend.", + keras.config.backend() in ["openvino"], + reason="fit not implemented for OpenVINO backend.", ) for item in items: @@ -149,7 +148,9 @@ def pytest_collection_modifyitems(config, items): def skip_if_backend(given_backend, reason): - return pytest.mark.skipif(backend() == given_backend, reason=reason) + return pytest.mark.skipif( + keras.config.backend() == given_backend, reason=reason + ) # Disable traceback filtering for quicker debugging of tests failures. From 23e62e2bd9e830a77cdd0da3083c8ce856ce34b1 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sun, 3 Aug 2025 14:58:19 +0300 Subject: [PATCH 04/17] update causal.lm --- keras_hub/src/models/causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index 30b8198eb5..21433d08e2 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -178,7 +178,7 @@ def wrapped_generate_function(inputs, stop_token_ids=None): return ov_infer(inputs, stop_token_ids, self.generate_step) self.generate_function = wrapped_generate_function - elif keras.config.backend() == "torch": + if keras.config.backend() == "torch": import torch def wrapped_generate_function( From c9291e0fc4b248fb4f305807b3ff4e247c39a270 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sun, 3 Aug 2025 16:28:20 +0300 Subject: [PATCH 05/17] remove openvino_utils and handle device --- keras_hub/src/models/causal_lm.py | 29 +++- keras_hub/src/utils/openvino_utils.py | 126 -------------- keras_hub/src/utils/openvino_utils_test.py | 192 --------------------- 3 files changed, 22 insertions(+), 325 deletions(-) delete mode 100644 keras_hub/src/utils/openvino_utils.py delete mode 100644 keras_hub/src/utils/openvino_utils_test.py diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index 21433d08e2..12a7be775d 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -136,13 +136,23 @@ def make_generate_function(self): import openvino as ov import openvino.runtime.opset14 as ov_opset - from keras_hub.src.utils.openvino_utils import get_outputs - from keras_hub.src.utils.openvino_utils import get_struct_outputs - def ov_infer(inputs, stop_token_ids, fn): - struct_params, struct_outputs = get_struct_outputs( - inputs, stop_token_ids, fn - ) + def get_outputs(inputs, struct_outputs, compile_ov_model): + flatten_inputs = tree.flatten(inputs) + for input in flatten_inputs: + if ops.is_tensor(input): + raise ValueError("inputs should be numpy arrays") + outputs = compile_ov_model(flatten_inputs) + outputs = self._unpack_singleton( + tree.pack_sequence_as( + struct_outputs, outputs.to_tuple() + ) + ) + return outputs + + struct_params = self._parameterize_data(inputs) + struct_outputs = fn(struct_params, stop_token_ids) + # Try using the existing compiled model if self.ov_compiled_model is not None: try: @@ -150,8 +160,11 @@ def ov_infer(inputs, stop_token_ids, fn): inputs, struct_outputs, self.ov_compiled_model ) except Exception: + # Delete previous model, then # Fall through to recompilation if inference fails + del self.ov_compiled_model pass + # Rebuild and compile the OpenVINO model parameters = [ p.output.get_node() for p in tree.flatten(struct_params) @@ -168,7 +181,9 @@ def ov_infer(inputs, stop_token_ids, fn): ) ov_model.validate_nodes_and_infer_types() core = ov.Core() - self.ov_compiled_model = core.compile_model(ov_model, "CPU") + device = "CPU" + # OpenVINO supports only compiling with 'CPU' devices. + self.ov_compiled_model = core.compile_model(ov_model, device) return get_outputs( inputs, struct_outputs, self.ov_compiled_model ) diff --git a/keras_hub/src/utils/openvino_utils.py b/keras_hub/src/utils/openvino_utils.py deleted file mode 100644 index 3633bdf8cf..0000000000 --- a/keras_hub/src/utils/openvino_utils.py +++ /dev/null @@ -1,126 +0,0 @@ -import numpy as np -import openvino as ov -import openvino.runtime.opset14 as ov_opset -from keras import ops -from keras import tree - -OPENVINO_DTYPES = { - "float16": ov.Type.f16, - "float32": ov.Type.f32, - "float64": ov.Type.f64, - "uint8": ov.Type.u8, - "uint16": ov.Type.u16, - "uint32": ov.Type.u32, - "uint64": ov.Type.u64, - "int8": ov.Type.i8, - "int16": ov.Type.i16, - "int32": ov.Type.i32, - "int64": ov.Type.i64, - "bfloat16": ov.Type.bf16, - "bool": ov.Type.boolean, - "float8_e4m3fn": ov.Type.f8e4m3, - "float8_e5m2": ov.Type.f8e5m2, - "string": ov.Type.string, -} - - -def unpack_singleton(x): - if isinstance(x, (list, tuple)) and len(x) == 1: - return x[0] - return x - - -def parameterize_inputs(inputs, prefix=""): - """ - Recursively converts input structures (dict, list, tuple, or scalars) into - OpenVINO Parameter nodes, preserving structure and assigning friendly names. - - Args: - inputs (Union[dict, list, tuple, np.ndarray, int, float]): - Input data structure or value to parameterize. - prefix (str): Prefix for naming OpenVINO parameter nodes. - - Returns: - Structure of the same form as `inputs`, but with each input replaced - by an OpenVINO-compatible tensor (converted parameter). - - Raises: - TypeError: If the input type is not supported. - """ - if isinstance(inputs, (list, tuple)): - return [ - parameterize_inputs(e, f"{prefix}{i}") for i, e in enumerate(inputs) - ] - elif isinstance(inputs, dict): - return {k: parameterize_inputs(v, k) for k, v in inputs.items()} - elif isinstance(inputs, np.ndarray): - ov_type = OPENVINO_DTYPES[str(inputs.dtype)] - ov_shape = list(inputs.shape) - param = ov_opset.parameter(shape=ov_shape, dtype=ov_type) - param.set_friendly_name(prefix) - return ops.convert_to_tensor(param.output(0)) - elif isinstance(inputs, (int, np.integer)): - param = ov_opset.parameter(shape=[], dtype=ov.Type.i32) - param.set_friendly_name(prefix) - return ops.convert_to_tensor(param.output(0)) - elif isinstance(inputs, (float, np.floating)): - param = ov_opset.parameter(shape=[], dtype=ov.Type.f32) - param.set_friendly_name(prefix) - return ops.convert_to_tensor(param.output(0)) - else: - raise TypeError(f"Unknown input type: {type(inputs)}") - - -def get_struct_outputs(inputs, stop_token_ids, fn): - """ - Prepares OpenVINO input parameters and calls the - user-defined generation function. - - Args: - inputs (dict or nested structure): Original input data - stop_token_ids (Any): Stop token information passed to - the model's generation step. - fn (Callable): A function representing a single generation - step that accepts parameterized inputs and returns structured outputs. - - Returns: - Tuple: (parameterized_inputs, struct_outputs) - - parameterized_inputs: OpenVINO parameter structure - for model compilation. - - struct_outputs: The output structure returned by - the generation function. - """ - struct_params = parameterize_inputs(inputs) - struct_outputs = fn(struct_params, stop_token_ids) - return struct_params, struct_outputs - - -def get_outputs(inputs, struct_outputs, compile_ov_model): - """ - Executes the OpenVINO compiled model with the given - inputs and reconstructs the output structure - to match `struct_outputs`. - - Args: - inputs (dict or nested structure): Original input data. - struct_outputs (Any): The structure that defines - how to reconstruct model outputs. - compile_ov_model (Callable): The compiled OpenVINO - model object with a `__call__` method. - - Returns: - The model output reconstructed to - match the structure of `struct_outputs`. - - Raises: - ValueError: If any of the inputs are still tensors. - """ - flatten_inputs = tree.flatten(inputs) - for input in flatten_inputs: - if ops.is_tensor(input): - raise ValueError("inputs should be numpy arrays") - outputs = compile_ov_model(flatten_inputs) - outputs = unpack_singleton( - tree.pack_sequence_as(struct_outputs, outputs.to_tuple()) - ) - return outputs diff --git a/keras_hub/src/utils/openvino_utils_test.py b/keras_hub/src/utils/openvino_utils_test.py deleted file mode 100644 index 5d52c15282..0000000000 --- a/keras_hub/src/utils/openvino_utils_test.py +++ /dev/null @@ -1,192 +0,0 @@ -import numpy as np -import openvino as ov -import pytest -from keras import backend -from keras import ops - -from keras_hub.src.tests.test_case import TestCase -from keras_hub.src.utils.openvino_utils import OPENVINO_DTYPES -from keras_hub.src.utils.openvino_utils import get_outputs -from keras_hub.src.utils.openvino_utils import get_struct_outputs -from keras_hub.src.utils.openvino_utils import parameterize_inputs -from keras_hub.src.utils.openvino_utils import unpack_singleton - - -@pytest.mark.skipif( - backend.backend() != "openvino", - reason="OpenVINO is required for these tests", -) -class TestOpenVinoUtils(TestCase): - def test_openvino_dtypes_mapping(self): - self.assertIn("float32", OPENVINO_DTYPES) - self.assertIn("int32", OPENVINO_DTYPES) - self.assertIn("bool", OPENVINO_DTYPES) - self.assertEqual(OPENVINO_DTYPES["float32"], ov.Type.f32) - self.assertEqual(OPENVINO_DTYPES["int32"], ov.Type.i32) - self.assertEqual(OPENVINO_DTYPES["bool"], ov.Type.boolean) - - def test_unpack_singleton_single_element_list(self): - result = unpack_singleton([42]) - self.assertEqual(result, 42) - - def test_unpack_singleton_single_element_tuple(self): - result = unpack_singleton((42,)) - self.assertEqual(result, 42) - - def test_unpack_singleton_multiple_elements(self): - input_list = [1, 2, 3] - result = unpack_singleton(input_list) - self.assertEqual(result, input_list) - - def test_unpack_singleton_empty_list(self): - input_list = [] - result = unpack_singleton(input_list) - self.assertEqual(result, input_list) - - def test_unpack_singleton_non_sequence(self): - result = unpack_singleton(42) - self.assertEqual(result, 42) - - def test_parameterize_inputs_numpy_array(self): - input_array = np.array([1, 2, 3], dtype=np.float32) - result = parameterize_inputs(input_array) - self.assertTrue(ops.is_tensor(result)) - - def test_parameterize_inputs_different_dtypes(self): - test_cases = [ - (np.array([1, 2, 3], dtype=np.int32), np.int32), - (np.array([1.0, 2.0, 3.0], dtype=np.float32), np.float32), - (np.array([1, 2, 3], dtype=np.int64), np.int64), - ] - - for input_array, expected_dtype in test_cases: - result = parameterize_inputs(input_array) - self.assertTrue(ops.is_tensor(result)) - - def test_parameterize_inputs_list(self): - input_list = [ - np.array([1, 2, 3], dtype=np.float32), - np.array([4, 5, 6], dtype=np.int32), - ] - result = parameterize_inputs(input_list) - self.assertIsInstance(result, list) - self.assertEqual(len(result), 2) - self.assertTrue(all(ops.is_tensor(r) for r in result)) - - def test_parameterize_inputs_tuple(self): - input_tuple = ( - np.array([1, 2, 3], dtype=np.float32), - np.array([4, 5, 6], dtype=np.int32), - ) - result = parameterize_inputs(input_tuple) - self.assertIsInstance(result, list) - self.assertEqual(len(result), 2) - self.assertTrue(all(ops.is_tensor(r) for r in result)) - - def test_parameterize_inputs_dict(self): - input_dict = { - "a": np.array([1, 2, 3], dtype=np.float32), - "b": np.array([4, 5, 6], dtype=np.int32), - } - result = parameterize_inputs(input_dict) - self.assertIsInstance(result, dict) - self.assertEqual(set(result.keys()), {"a", "b"}) - self.assertTrue(all(ops.is_tensor(v) for v in result.values())) - - def test_parameterize_inputs_integer(self): - result = parameterize_inputs(42) - self.assertTrue(ops.is_tensor(result)) - - def test_parameterize_inputs_numpy_integer(self): - result = parameterize_inputs(np.int32(42)) - self.assertTrue(ops.is_tensor(result)) - - def test_parameterize_inputs_float(self): - result = parameterize_inputs(3.14) - self.assertTrue(ops.is_tensor(result)) - - def test_parameterize_inputs_numpy_float(self): - result = parameterize_inputs(np.float32(3.14)) - self.assertTrue(ops.is_tensor(result)) - - def test_parameterize_inputs_unsupported_type(self): - with self.assertRaisesRegex(TypeError, "Unknown input type"): - parameterize_inputs("unsupported_string") - - def test_parameterize_inputs_nested_structure(self): - """Test parameterizing nested structures.""" - nested_input = { - "list": [np.array([1, 2], dtype=np.float32), 42], - "dict": {"nested": np.array([3, 4], dtype=np.int32)}, - } - result = parameterize_inputs(nested_input) - self.assertIsInstance(result, dict) - self.assertIsInstance(result["list"], list) - self.assertIsInstance(result["dict"], dict) - self.assertTrue(ops.is_tensor(result["list"][0])) - self.assertTrue(ops.is_tensor(result["list"][1])) - self.assertTrue(ops.is_tensor(result["dict"]["nested"])) - - def test_get_struct_outputs(self): - inputs = np.array([1, 2, 3], dtype=np.float32) - stop_token_ids = [0, 1] - - def mock_fn(params, stop_tokens): - return params # Simple mock that returns the params - - struct_params, struct_outputs = get_struct_outputs( - inputs, stop_token_ids, mock_fn - ) - self.assertTrue(ops.is_tensor(struct_params)) - self.assertTrue(ops.is_tensor(struct_outputs)) - - def test_get_outputs_with_tensor_input_raises_error(self): - inputs = [ops.convert_to_tensor(np.array([1, 2, 3]))] - struct_outputs = np.array([1, 2, 3]) - - def mock_compile_model(inputs): - class MockResult: - def to_tuple(self): - return (np.array([1, 2, 3]),) - - return MockResult() - - with self.assertRaisesRegex( - ValueError, "inputs should be numpy arrays" - ): - get_outputs(inputs, struct_outputs, mock_compile_model) - - def test_get_outputs_with_valid_inputs(self): - inputs = [np.array([1, 2, 3], dtype=np.float32)] - struct_outputs = np.array([1, 2, 3]) - - def mock_compile_model(inputs): - class MockResult: - def to_tuple(self): - return (np.array([4, 5, 6]),) - - return MockResult() - - result = get_outputs(inputs, struct_outputs, mock_compile_model) - self.assertIsInstance(result, np.ndarray) - self.assertAllClose(result, np.array([4, 5, 6])) - - def test_get_outputs_with_nested_inputs(self): - inputs = { - "a": np.array([1, 2], dtype=np.float32), - "b": [np.array([3, 4], dtype=np.int32)], - } - struct_outputs = [np.array([1, 2]), np.array([3, 4])] - - def mock_compile_model(inputs): - class MockResult: - def to_tuple(self): - return (np.array([5, 6]), np.array([7, 8])) - - return MockResult() - - result = get_outputs(inputs, struct_outputs, mock_compile_model) - self.assertIsInstance(result, list) - self.assertEqual(len(result), 2) - self.assertAllClose(result[0], np.array([5, 6])) - self.assertAllClose(result[1], np.array([7, 8])) From af2ae3315be098a8186a952a33400d72b033da35 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sun, 3 Aug 2025 17:34:17 +0300 Subject: [PATCH 06/17] fix typo --- keras_hub/src/models/causal_lm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index 12a7be775d..16cfa4f51a 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -137,12 +137,12 @@ def make_generate_function(self): import openvino.runtime.opset14 as ov_opset def ov_infer(inputs, stop_token_ids, fn): - def get_outputs(inputs, struct_outputs, compile_ov_model): + def get_outputs(inputs, struct_outputs, compiled_ov_model): flatten_inputs = tree.flatten(inputs) for input in flatten_inputs: if ops.is_tensor(input): raise ValueError("inputs should be numpy arrays") - outputs = compile_ov_model(flatten_inputs) + outputs = compiled_ov_model(flatten_inputs) outputs = self._unpack_singleton( tree.pack_sequence_as( struct_outputs, outputs.to_tuple() From 792273ee083ba85e64958f636be80e164bc17b23 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Mon, 4 Aug 2025 00:30:12 +0300 Subject: [PATCH 07/17] remove unnecessary check --- keras_hub/src/models/causal_lm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index 16cfa4f51a..eab3a187cb 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -139,9 +139,6 @@ def make_generate_function(self): def ov_infer(inputs, stop_token_ids, fn): def get_outputs(inputs, struct_outputs, compiled_ov_model): flatten_inputs = tree.flatten(inputs) - for input in flatten_inputs: - if ops.is_tensor(input): - raise ValueError("inputs should be numpy arrays") outputs = compiled_ov_model(flatten_inputs) outputs = self._unpack_singleton( tree.pack_sequence_as( From 8baea81c446b5cf422edff8290e4e17659edd3ee Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Mon, 4 Aug 2025 15:31:31 +0300 Subject: [PATCH 08/17] update causal.lm --- keras_hub/src/models/causal_lm.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index eab3a187cb..be7a8ec982 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -133,9 +133,14 @@ def make_generate_function(self): self.generate_function = self.generate_step if keras.config.backend() == "openvino": + import os + + os.environ["OV_ENABLE_EINSUM_DECOMPOSITION"] = "1" import openvino as ov import openvino.runtime.opset14 as ov_opset + from keras_hub.src.utils.keras_utils import print_msg + def ov_infer(inputs, stop_token_ids, fn): def get_outputs(inputs, struct_outputs, compiled_ov_model): flatten_inputs = tree.flatten(inputs) @@ -147,28 +152,33 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model): ) return outputs - struct_params = self._parameterize_data(inputs) - struct_outputs = fn(struct_params, stop_token_ids) - # Try using the existing compiled model if self.ov_compiled_model is not None: try: return get_outputs( - inputs, struct_outputs, self.ov_compiled_model + inputs, self.struct_outputs, self.ov_compiled_model ) - except Exception: - # Delete previous model, then + except RuntimeError as e: + # Delete previous model and struct outputs, then # Fall through to recompilation if inference fails + print_msg( + "WARNING: OpenVINO inference \033[1mFAILED\033[0m, " + "so we'll Rebuild and compile the model then " + f"try again.\n{e}" + ) del self.ov_compiled_model + del self.struct_outputs pass # Rebuild and compile the OpenVINO model + struct_params = self._parameterize_data(inputs) + self.struct_outputs = fn(struct_params, stop_token_ids) parameters = [ p.output.get_node() for p in tree.flatten(struct_params) ] results = [ ov_opset.result(r.output) - for r in tree.flatten(struct_outputs) + for r in tree.flatten(self.struct_outputs) ] ov_model = ov.Model(results=results, parameters=parameters) for ov_input in ov_model.inputs: @@ -182,10 +192,11 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model): # OpenVINO supports only compiling with 'CPU' devices. self.ov_compiled_model = core.compile_model(ov_model, device) return get_outputs( - inputs, struct_outputs, self.ov_compiled_model + inputs, self.struct_outputs, self.ov_compiled_model ) def wrapped_generate_function(inputs, stop_token_ids=None): + # ops.array converts yo numpy in openvino backend inputs = tree.map_structure(ops.array, inputs) return ov_infer(inputs, stop_token_ids, self.generate_step) From 07985d68027cafe2c399fae5f5112f2311649905 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Wed, 6 Aug 2025 15:22:03 +0300 Subject: [PATCH 09/17] finalize PR --- .github/workflows/actions.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index e3620d50c2..da03b406c8 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -24,7 +24,6 @@ jobs: - backend: jax version: keras-nightly - backend: openvino - version: keras-stable python-version: '3.10' runs-on: ubuntu-latest env: @@ -53,7 +52,7 @@ jobs: pip install --no-deps -e "." --progress-bar off if [[ "${{ matrix.backend }}" == "openvino" ]]; then pip uninstall -y keras - pip install git+https://github.com/Mohamed-Ashraf273/keras.git@gsoc2025 --upgrade --force-reinstall --progress-bar off + pip install git+https://github.com/keras-team/keras.git --upgrade --force-reinstall --progress-bar off fi - name: Pin Keras 3.5 if: ${{ matrix.version == 'keras-3.5'}} From 57a92486e377d4c7e56af3e1938703c8480e035b Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Wed, 6 Aug 2025 16:38:33 +0300 Subject: [PATCH 10/17] optimize memory allocation inference --- keras_hub/src/models/causal_lm.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index be7a8ec982..fa34f8c826 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -134,21 +134,42 @@ def make_generate_function(self): self.generate_function = self.generate_step if keras.config.backend() == "openvino": import os + from multiprocessing import Pipe + from multiprocessing import Process - os.environ["OV_ENABLE_EINSUM_DECOMPOSITION"] = "1" import openvino as ov import openvino.runtime.opset14 as ov_opset from keras_hub.src.utils.keras_utils import print_msg + os.environ["OV_ENABLE_EINSUM_DECOMPOSITION"] = "1" + def ov_infer(inputs, stop_token_ids, fn): + def isolated_infer(pipe, compiled_model, flat_inputs): + infer_request = compiled_model.create_infer_request() + outputs = infer_request.infer(flat_inputs) + numpy_outputs = outputs.to_tuple() + pipe.send(numpy_outputs) + pipe.close() + del infer_request + def get_outputs(inputs, struct_outputs, compiled_ov_model): flatten_inputs = tree.flatten(inputs) - outputs = compiled_ov_model(flatten_inputs) + parent_conn, child_conn = Pipe() + # Running inference in a separate process to avoid + # allocating unnecessary memory in the main process + # by OpenVINO inference_request. + # This saves 0.5 GB to 1 GB (depends on the model) + # of memory, but this makes latency increases by 1-2 seconds + p = Process( + target=isolated_infer, + args=(child_conn, compiled_ov_model, flatten_inputs), + ) + p.start() + outputs = parent_conn.recv() + p.join() outputs = self._unpack_singleton( - tree.pack_sequence_as( - struct_outputs, outputs.to_tuple() - ) + tree.pack_sequence_as(struct_outputs, outputs) ) return outputs From d2544feb0eb35eb7d6f34afdc54460997704cb6f Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Wed, 6 Aug 2025 22:53:02 +0300 Subject: [PATCH 11/17] optimize mem usage --- keras_hub/src/models/causal_lm.py | 59 +++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index fa34f8c826..eefd7c47e4 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -139,6 +139,7 @@ def make_generate_function(self): import openvino as ov import openvino.runtime.opset14 as ov_opset + import psutil from keras_hub.src.utils.keras_utils import print_msg @@ -146,28 +147,50 @@ def make_generate_function(self): def ov_infer(inputs, stop_token_ids, fn): def isolated_infer(pipe, compiled_model, flat_inputs): - infer_request = compiled_model.create_infer_request() - outputs = infer_request.infer(flat_inputs) - numpy_outputs = outputs.to_tuple() - pipe.send(numpy_outputs) + outputs = compiled_model(flat_inputs) + outputs = outputs.to_tuple() + pipe.send(outputs) pipe.close() - del infer_request def get_outputs(inputs, struct_outputs, compiled_ov_model): flatten_inputs = tree.flatten(inputs) - parent_conn, child_conn = Pipe() - # Running inference in a separate process to avoid - # allocating unnecessary memory in the main process - # by OpenVINO inference_request. - # This saves 0.5 GB to 1 GB (depends on the model) - # of memory, but this makes latency increases by 1-2 seconds - p = Process( - target=isolated_infer, - args=(child_conn, compiled_ov_model, flatten_inputs), + free_mem = psutil.virtual_memory().available / (1024**3) + # On average OpenVINO needs about 2 GB to run + # an inference, also it is wrapped by an env var, + # to be tuned. + threshold = float( + os.getenv("OV_INFER_FREE_MEM_THRESHOLD", 2) ) - p.start() - outputs = parent_conn.recv() - p.join() + if free_mem > threshold: + """Run inference in a separate process only if + free memory usage is above a certain threshold. + This threshold is calculated to ensure that + swap memory won't be triggered. When swap is + likely to be used, fallback to normal inference + to avoid severe performance degradation. + Running inference in a subprocess prevents OpenVINO from + allocating extra memory in the main process during its + internal infer request creation. This can reduce memory + usage by 0.5–2 GB depending on the model size. + However, using a subprocess introduces an extra + overhead, increasing latency by around 1–2 seconds + per inference. + """ + parent_conn, child_conn = Pipe() + p = Process( + target=isolated_infer, + args=( + child_conn, + compiled_ov_model, + flatten_inputs, + ), + ) + p.start() + outputs = parent_conn.recv() + p.join() + else: + outputs = compiled_ov_model(flatten_inputs) + outputs = outputs.to_tuple() outputs = self._unpack_singleton( tree.pack_sequence_as(struct_outputs, outputs) ) @@ -217,7 +240,7 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model): ) def wrapped_generate_function(inputs, stop_token_ids=None): - # ops.array converts yo numpy in openvino backend + # ops.array converts to numpy in openvino backend inputs = tree.map_structure(ops.array, inputs) return ov_infer(inputs, stop_token_ids, self.generate_step) From cc3d6479756b05f8ce4affb22f74da5d6cb322b6 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sat, 9 Aug 2025 02:27:56 +0300 Subject: [PATCH 12/17] remove env --- keras_hub/src/models/causal_lm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index eefd7c47e4..e2827b7a61 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -143,8 +143,6 @@ def make_generate_function(self): from keras_hub.src.utils.keras_utils import print_msg - os.environ["OV_ENABLE_EINSUM_DECOMPOSITION"] = "1" - def ov_infer(inputs, stop_token_ids, fn): def isolated_infer(pipe, compiled_model, flat_inputs): outputs = compiled_model(flat_inputs) From 873716b9d2291ea19531b4ffea78a659cfd93e7a Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sun, 10 Aug 2025 16:37:51 +0300 Subject: [PATCH 13/17] update causal.lm --- keras_hub/src/models/causal_lm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index e2827b7a61..b55a0d2a3b 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -210,7 +210,6 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model): ) del self.ov_compiled_model del self.struct_outputs - pass # Rebuild and compile the OpenVINO model struct_params = self._parameterize_data(inputs) From c186f3359ce12a214cf10beaef3a7ee2490b22e9 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Wed, 13 Aug 2025 23:03:10 +0300 Subject: [PATCH 14/17] fix errors --- .github/workflows/actions.yml | 2 - keras_hub/src/models/causal_lm.py | 73 +++++++++---------------------- openvino_excluded_tests.txt | 1 + 3 files changed, 21 insertions(+), 55 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 05b11093a9..8801a74064 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -23,8 +23,6 @@ jobs: version: keras-3.8 - backend: jax version: keras-nightly - - backend: openvino - python-version: '3.10' runs-on: ubuntu-latest env: KERAS_BACKEND: ${{ matrix.backend }} diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index b55a0d2a3b..70267fad4e 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -133,69 +133,29 @@ def make_generate_function(self): self.generate_function = self.generate_step if keras.config.backend() == "openvino": - import os - from multiprocessing import Pipe - from multiprocessing import Process - import openvino as ov import openvino.runtime.opset14 as ov_opset - import psutil from keras_hub.src.utils.keras_utils import print_msg def ov_infer(inputs, stop_token_ids, fn): - def isolated_infer(pipe, compiled_model, flat_inputs): - outputs = compiled_model(flat_inputs) - outputs = outputs.to_tuple() - pipe.send(outputs) - pipe.close() - def get_outputs(inputs, struct_outputs, compiled_ov_model): flatten_inputs = tree.flatten(inputs) - free_mem = psutil.virtual_memory().available / (1024**3) - # On average OpenVINO needs about 2 GB to run - # an inference, also it is wrapped by an env var, - # to be tuned. - threshold = float( - os.getenv("OV_INFER_FREE_MEM_THRESHOLD", 2) - ) - if free_mem > threshold: - """Run inference in a separate process only if - free memory usage is above a certain threshold. - This threshold is calculated to ensure that - swap memory won't be triggered. When swap is - likely to be used, fallback to normal inference - to avoid severe performance degradation. - Running inference in a subprocess prevents OpenVINO from - allocating extra memory in the main process during its - internal infer request creation. This can reduce memory - usage by 0.5–2 GB depending on the model size. - However, using a subprocess introduces an extra - overhead, increasing latency by around 1–2 seconds - per inference. - """ - parent_conn, child_conn = Pipe() - p = Process( - target=isolated_infer, - args=( - child_conn, - compiled_ov_model, - flatten_inputs, - ), - ) - p.start() - outputs = parent_conn.recv() - p.join() - else: - outputs = compiled_ov_model(flatten_inputs) - outputs = outputs.to_tuple() + outputs = compiled_ov_model(flatten_inputs).to_tuple() outputs = self._unpack_singleton( tree.pack_sequence_as(struct_outputs, outputs) ) return outputs + core = ov.Core() + device = "GPU" if "GPU" in core.available_devices else "CPU" + # Try using the existing compiled model - if self.ov_compiled_model is not None: + if ( + self.ov_compiled_model is not None + and getattr(self, "ov_device", None) is not None + and device == self.ov_device + ): try: return get_outputs( inputs, self.struct_outputs, self.ov_compiled_model @@ -228,10 +188,17 @@ def get_outputs(inputs, struct_outputs, compiled_ov_model): ov.PartialShape([-1] * rank) ) ov_model.validate_nodes_and_infer_types() - core = ov.Core() - device = "CPU" - # OpenVINO supports only compiling with 'CPU' devices. - self.ov_compiled_model = core.compile_model(ov_model, device) + + self.ov_device = device + model_dtype = ( + "f16" + if self.dtype == "float16" or self.dtype == "bfloat16" + else "f32" + ) + config = {"INFERENCE_PRECISION_HINT": model_dtype} + self.ov_compiled_model = core.compile_model( + ov_model, device, config + ) return get_outputs( inputs, self.struct_outputs, self.ov_compiled_model ) diff --git a/openvino_excluded_tests.txt b/openvino_excluded_tests.txt index f33dd5def7..877d0a0ba3 100644 --- a/openvino_excluded_tests.txt +++ b/openvino_excluded_tests.txt @@ -24,6 +24,7 @@ keras_hub/src/models/dinov2 keras_hub/src/models/distil_bert keras_hub/src/models/efficientnet keras_hub/src/models/electra +keras_hub/src/models/esm keras_hub/src/models/falcon keras_hub/src/models/flux keras_hub/src/models/f_net From 9bac18f57a2afc250d21d6add83798eac9db5791 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Tue, 19 Aug 2025 20:40:04 +0300 Subject: [PATCH 15/17] update PR --- .github/workflows/actions.yml | 16 +- conftest.py | 67 ++- integration_tests/basic_usage_test.py | 4 - .../src/layers/modeling/alibi_bias_test.py | 2 - .../layers/modeling/anchor_generator_test.py | 1 - .../cached_multi_head_attention_test.py | 3 - .../src/layers/modeling/f_net_encoder_test.py | 3 - .../layers/modeling/masked_lm_head_test.py | 3 - .../modeling/position_embedding_test.py | 3 - .../modeling/reversible_embedding_test.py | 2 - .../layers/modeling/rotary_embedding_test.py | 3 - .../modeling/sine_position_encoding_test.py | 3 - .../token_and_position_embedding_test.py | 2 - .../modeling/transformer_decoder_test.py | 4 - .../modeling/transformer_encoder_test.py | 3 - keras_hub/src/models/causal_lm.py | 76 +--- .../src/models/gemma/gemma_backbone_test.py | 2 - keras_hub/src/models/gemma/gemma_causal_lm.py | 3 - .../src/models/gemma/gemma_causal_lm_test.py | 1 - keras_hub/src/models/gemma/gemma_lora_test.py | 2 - .../src/models/gpt2/gpt2_backbone_test.py | 1 - keras_hub/src/models/gpt2/gpt2_causal_lm.py | 3 - .../src/models/gpt2/gpt2_causal_lm_test.py | 1 - .../models/mistral/mistral_backbone_test.py | 1 - .../src/models/mistral/mistral_causal_lm.py | 3 - .../models/mistral/mistral_causal_lm_test.py | 1 - keras_hub/src/utils/openvino_utils.py | 311 ++++++++++++++ keras_hub/src/utils/openvino_utils_test.py | 386 ++++++++++++++++++ keras_hub/src/utils/pipeline_model_test.py | 10 - openvino_excluded_concrete_tests.txt | 24 -- openvino_excluded_tests.txt | 76 ---- openvino_supported_tests.txt | 21 + 32 files changed, 755 insertions(+), 286 deletions(-) create mode 100644 keras_hub/src/utils/openvino_utils.py create mode 100644 keras_hub/src/utils/openvino_utils_test.py delete mode 100644 openvino_excluded_concrete_tests.txt delete mode 100644 openvino_excluded_tests.txt create mode 100644 openvino_supported_tests.txt diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 8801a74064..33a3c5708b 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -16,13 +16,15 @@ jobs: strategy: fail-fast: false matrix: - backend: [tensorflow, jax, torch, openvino] + backend: [tensorflow, jax, torch] version: [keras-stable] include: - backend: jax version: keras-3.8 - backend: jax version: keras-nightly + - backend: openvino + version: keras-nightly runs-on: ubuntu-latest env: KERAS_BACKEND: ${{ matrix.backend }} @@ -48,10 +50,6 @@ jobs: run: | pip install -r requirements.txt --progress-bar off pip install --no-deps -e "." --progress-bar off - if [[ "${{ matrix.backend }}" == "openvino" ]]; then - pip uninstall -y keras - pip install git+https://github.com/keras-team/keras.git --upgrade --force-reinstall --progress-bar off - fi - name: Pin Keras 3.8 if: ${{ matrix.version == 'keras-3.8'}} run: | @@ -65,13 +63,7 @@ jobs: pip install keras-nightly --progress-bar off - name: Test with pytest run: | - if [ "${{ matrix.backend }}" == "openvino" ]; then - IGNORE_FILE="openvino_excluded_tests.txt" - IGNORE_ARGS=$(awk '{print "--ignore=" $0}' "$IGNORE_FILE") - else - IGNORE_ARGS="" - fi - pytest keras_hub/ $IGNORE_ARGS + pytest keras_hub/ - name: Run integration tests run: | python pip_build.py --install diff --git a/conftest.py b/conftest.py index a719bbafdf..3dd9323a04 100644 --- a/conftest.py +++ b/conftest.py @@ -1,8 +1,12 @@ import os +from pathlib import Path import keras import pytest +from keras_hub.src.utils.openvino_utils import get_openvino_skip_reason +from keras_hub.src.utils.openvino_utils import setup_openvino_test_config + def pytest_addoption(parser): parser.addoption( @@ -29,6 +33,13 @@ def pytest_addoption(parser): default=False, help="fail if a gpu is not present", ) + parser.addoption( + "--auto_skip_training", + action="store_true", + default=True, + help="automatically skip tests with " + "training methods on non-trainable backends", + ) def pytest_configure(config): @@ -70,16 +81,15 @@ def pytest_configure(config): "markers", "kaggle_key_required: mark test needing a kaggle key", ) - config.addinivalue_line( - "markers", - "requires_trainable_backend: mark test for trainable backend only", - ) def pytest_collection_modifyitems(config, items): + openvino_supported_paths = None + run_extra_large_tests = config.getoption("--run_extra_large") # Run large tests for --run_extra_large or --run_large. run_large_tests = config.getoption("--run_large") or run_extra_large_tests + auto_skip_training = config.getoption("--auto_skip_training") # Messages to annotate skipped tests with. skip_large = pytest.mark.skipif( @@ -114,43 +124,22 @@ def pytest_collection_modifyitems(config, items): if "kaggle_key_required" in item.keywords: item.add_marker(kaggle_key_required) - openvino_skipped_tests = [] - if keras.config.backend() == "openvino": - from pathlib import Path - - workspace_root = Path(__file__).resolve().parents[0] - file_path = workspace_root / "openvino_excluded_concrete_tests.txt" - with open(file_path, "r") as file: - openvino_skipped_tests = [ - line.strip() for line in file if line.strip() - ] - - requires_trainable_backend = pytest.mark.skipif( - keras.config.backend() in ["openvino"], - reason="fit not implemented for OpenVINO backend.", - ) - - for item in items: - if "requires_trainable_backend" in item.keywords: - item.add_marker(requires_trainable_backend) - # also, skip concrete tests for openvino, listed in the special file - # this is more granular mechanism to exclude tests rather - # than using --ignore option - for skipped_test in openvino_skipped_tests: - if skipped_test in item.nodeid: + # OpenVINO-specific skipping logic - whitelist-based approach + if keras.config.backend() == "openvino": + # OpenVINO backend configuration + if openvino_supported_paths is None: + openvino_supported_paths = setup_openvino_test_config( + str(Path(__file__).parent) + ) + skip_reason = get_openvino_skip_reason( + item, + openvino_supported_paths, + auto_skip_training, + ) + if skip_reason: item.add_marker( - skip_if_backend( - "openvino", - "Not supported operation by openvino backend", - ) + pytest.mark.skipif(True, reason=f"OpenVINO: {skip_reason}") ) - break - - -def skip_if_backend(given_backend, reason): - return pytest.mark.skipif( - keras.config.backend() == given_backend, reason=reason - ) # Disable traceback filtering for quicker debugging of tests failures. diff --git a/integration_tests/basic_usage_test.py b/integration_tests/basic_usage_test.py index 75af52d577..7fd73bb9e5 100644 --- a/integration_tests/basic_usage_test.py +++ b/integration_tests/basic_usage_test.py @@ -6,10 +6,6 @@ import keras_hub -@unittest.skipIf( - keras.backend.backend() == "openvino", - "Skip for non-trainable backends like OpenVINO", -) class BasicUsageTest(unittest.TestCase): def test_transformer(self): # Tokenize some inputs with a binary label. diff --git a/keras_hub/src/layers/modeling/alibi_bias_test.py b/keras_hub/src/layers/modeling/alibi_bias_test.py index 6e9d454472..627cede5a2 100644 --- a/keras_hub/src/layers/modeling/alibi_bias_test.py +++ b/keras_hub/src/layers/modeling/alibi_bias_test.py @@ -1,5 +1,4 @@ import keras -import pytest from keras import ops from keras import random @@ -8,7 +7,6 @@ class AlibiBiasTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): alibi_bias_max = 8 batch_size = 4 diff --git a/keras_hub/src/layers/modeling/anchor_generator_test.py b/keras_hub/src/layers/modeling/anchor_generator_test.py index 4fd67a732e..e5918cdfda 100644 --- a/keras_hub/src/layers/modeling/anchor_generator_test.py +++ b/keras_hub/src/layers/modeling/anchor_generator_test.py @@ -14,7 +14,6 @@ reason="Bbox utils are not supported before keras < 3.8.0", ) class AnchorGeneratorTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): images_shape = (8, 128, 128, 3) self.run_layer_test( diff --git a/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py b/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py index 7d41589a3c..6690667ade 100644 --- a/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py +++ b/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py @@ -1,4 +1,3 @@ -import pytest from keras import ops from keras import random @@ -9,7 +8,6 @@ class CachedMultiHeadAttentionTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=CachedMultiHeadAttention, @@ -77,7 +75,6 @@ def call(outputs, cache): self.assertAllClose(output, no_loop_outputs) self.assertAllClose(output_cache, no_loop_cache) - @pytest.mark.requires_trainable_backend def test_training_propagation(self): batch_size = 2 seq_len = 5 diff --git a/keras_hub/src/layers/modeling/f_net_encoder_test.py b/keras_hub/src/layers/modeling/f_net_encoder_test.py index 32848a82a3..0fe2d361dd 100644 --- a/keras_hub/src/layers/modeling/f_net_encoder_test.py +++ b/keras_hub/src/layers/modeling/f_net_encoder_test.py @@ -1,4 +1,3 @@ -import pytest from keras import ops from keras import random @@ -7,7 +6,6 @@ class FNetEncoderTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=FNetEncoder, @@ -33,7 +31,6 @@ def test_value_error_when_invalid_kernel_initializer(self): kernel_initializer="Invalid", ) - @pytest.mark.requires_trainable_backend def test_training_propagation(self): x = random.uniform(shape=(2, 4, 6)) layer = FNetEncoder( diff --git a/keras_hub/src/layers/modeling/masked_lm_head_test.py b/keras_hub/src/layers/modeling/masked_lm_head_test.py index 0fc26449a0..28da6dc6f9 100644 --- a/keras_hub/src/layers/modeling/masked_lm_head_test.py +++ b/keras_hub/src/layers/modeling/masked_lm_head_test.py @@ -1,4 +1,3 @@ -import pytest from keras import random from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead @@ -9,7 +8,6 @@ class MaskedLMHeadTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=MaskedLMHead, @@ -29,7 +27,6 @@ def test_layer_behaviors(self): expected_num_trainable_weights=6, ) - @pytest.mark.requires_trainable_backend def test_layer_behaviors_with_embedding(self): embedding = ReversibleEmbedding(100, 16) embedding.build((4, 10)) diff --git a/keras_hub/src/layers/modeling/position_embedding_test.py b/keras_hub/src/layers/modeling/position_embedding_test.py index ad4f4f5ca4..d6e577c66e 100644 --- a/keras_hub/src/layers/modeling/position_embedding_test.py +++ b/keras_hub/src/layers/modeling/position_embedding_test.py @@ -1,6 +1,5 @@ import keras import numpy as np -import pytest from keras import ops from keras import random @@ -16,7 +15,6 @@ def custom_init(shape, dtype=None): class PositionEmbeddingTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=PositionEmbedding, @@ -28,7 +26,6 @@ def test_layer_behaviors(self): expected_num_trainable_weights=1, ) - @pytest.mark.requires_trainable_backend def test_layer_behaviors_4d(self): self.run_layer_test( cls=PositionEmbedding, diff --git a/keras_hub/src/layers/modeling/reversible_embedding_test.py b/keras_hub/src/layers/modeling/reversible_embedding_test.py index be7b69ccda..4482854449 100644 --- a/keras_hub/src/layers/modeling/reversible_embedding_test.py +++ b/keras_hub/src/layers/modeling/reversible_embedding_test.py @@ -2,7 +2,6 @@ import keras import numpy as np -import pytest from absl.testing import parameterized from keras import ops from keras import random @@ -18,7 +17,6 @@ class ReversibleEmbeddingTest(TestCase): ("tie_weights", True), ("untie_weights", False), ) - @pytest.mark.requires_trainable_backend def test_layer_behaviors_tied(self, tie_weights): self.run_layer_test( cls=ReversibleEmbedding, diff --git a/keras_hub/src/layers/modeling/rotary_embedding_test.py b/keras_hub/src/layers/modeling/rotary_embedding_test.py index 5434e8b477..35c2cc1356 100644 --- a/keras_hub/src/layers/modeling/rotary_embedding_test.py +++ b/keras_hub/src/layers/modeling/rotary_embedding_test.py @@ -1,6 +1,5 @@ import keras import numpy as np -import pytest from keras import ops from keras import random @@ -9,7 +8,6 @@ class RotaryEmbeddingTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=RotaryEmbedding, @@ -23,7 +21,6 @@ def test_layer_behaviors(self): expected_output_shape=(2, 4, 6), ) - @pytest.mark.requires_trainable_backend def test_layer_behaviors_4d(self): self.run_layer_test( cls=RotaryEmbedding, diff --git a/keras_hub/src/layers/modeling/sine_position_encoding_test.py b/keras_hub/src/layers/modeling/sine_position_encoding_test.py index fd29c51fc1..eb0aeb2ff3 100644 --- a/keras_hub/src/layers/modeling/sine_position_encoding_test.py +++ b/keras_hub/src/layers/modeling/sine_position_encoding_test.py @@ -1,5 +1,4 @@ import keras -import pytest from keras import ops from keras import random @@ -10,7 +9,6 @@ class SinePositionEncodingTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=SinePositionEncoding, @@ -21,7 +19,6 @@ def test_layer_behaviors(self): expected_output_shape=(2, 4, 6), ) - @pytest.mark.requires_trainable_backend def test_layer_behaviors_4d(self): self.run_layer_test( cls=SinePositionEncoding, diff --git a/keras_hub/src/layers/modeling/token_and_position_embedding_test.py b/keras_hub/src/layers/modeling/token_and_position_embedding_test.py index fae2a73151..f0ef202aed 100644 --- a/keras_hub/src/layers/modeling/token_and_position_embedding_test.py +++ b/keras_hub/src/layers/modeling/token_and_position_embedding_test.py @@ -1,5 +1,4 @@ import numpy as np -import pytest from keras import ops from keras import random @@ -10,7 +9,6 @@ class TokenAndPositionEmbeddingTest(TestCase): - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self): self.run_layer_test( cls=TokenAndPositionEmbedding, diff --git a/keras_hub/src/layers/modeling/transformer_decoder_test.py b/keras_hub/src/layers/modeling/transformer_decoder_test.py index 2af91292c2..7cbd32bed8 100644 --- a/keras_hub/src/layers/modeling/transformer_decoder_test.py +++ b/keras_hub/src/layers/modeling/transformer_decoder_test.py @@ -1,4 +1,3 @@ -import pytest from absl.testing import parameterized from keras import ops from keras import random @@ -12,7 +11,6 @@ class TransformerDecoderTest(TestCase): ("without_norm_first", False), ("with_norm_first", True), ) - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self, normalize_first): self.run_layer_test( cls=TransformerDecoder, @@ -36,7 +34,6 @@ def test_layer_behaviors(self, normalize_first): ("without_norm_first", False), ("with_norm_first", True), ) - @pytest.mark.requires_trainable_backend def test_layer_behaviors_with_cross_attention(self, normalize_first): self.run_layer_test( cls=TransformerDecoder, @@ -92,7 +89,6 @@ def test_value_error_when_invalid_kernel_inititalizer(self): kernel_initializer="Invalid", ) - @pytest.mark.requires_trainable_backend def test_training_propagation(self): decoder = TransformerDecoder( intermediate_dim=4, diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py index e9c58bc01c..a682af157c 100644 --- a/keras_hub/src/layers/modeling/transformer_encoder_test.py +++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py @@ -1,5 +1,4 @@ import keras -import pytest from absl.testing import parameterized from keras import ops from keras import random @@ -13,7 +12,6 @@ class TransformerEncoderTest(TestCase): ("without_norm_first", False), ("with_norm_first", True), ) - @pytest.mark.requires_trainable_backend def test_layer_behaviors(self, normalize_first): self.run_layer_test( cls=TransformerEncoder, @@ -71,7 +69,6 @@ def test_value_error_when_invalid_kernel_inititalizer(self): kernel_initializer="Invalid", ) - @pytest.mark.requires_trainable_backend def test_training_propagation(self): encoder = TransformerEncoder( intermediate_dim=4, diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index 70267fad4e..a12ac33303 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -133,80 +133,14 @@ def make_generate_function(self): self.generate_function = self.generate_step if keras.config.backend() == "openvino": - import openvino as ov - import openvino.runtime.opset14 as ov_opset - - from keras_hub.src.utils.keras_utils import print_msg - - def ov_infer(inputs, stop_token_ids, fn): - def get_outputs(inputs, struct_outputs, compiled_ov_model): - flatten_inputs = tree.flatten(inputs) - outputs = compiled_ov_model(flatten_inputs).to_tuple() - outputs = self._unpack_singleton( - tree.pack_sequence_as(struct_outputs, outputs) - ) - return outputs - - core = ov.Core() - device = "GPU" if "GPU" in core.available_devices else "CPU" - - # Try using the existing compiled model - if ( - self.ov_compiled_model is not None - and getattr(self, "ov_device", None) is not None - and device == self.ov_device - ): - try: - return get_outputs( - inputs, self.struct_outputs, self.ov_compiled_model - ) - except RuntimeError as e: - # Delete previous model and struct outputs, then - # Fall through to recompilation if inference fails - print_msg( - "WARNING: OpenVINO inference \033[1mFAILED\033[0m, " - "so we'll Rebuild and compile the model then " - f"try again.\n{e}" - ) - del self.ov_compiled_model - del self.struct_outputs - - # Rebuild and compile the OpenVINO model - struct_params = self._parameterize_data(inputs) - self.struct_outputs = fn(struct_params, stop_token_ids) - parameters = [ - p.output.get_node() for p in tree.flatten(struct_params) - ] - results = [ - ov_opset.result(r.output) - for r in tree.flatten(self.struct_outputs) - ] - ov_model = ov.Model(results=results, parameters=parameters) - for ov_input in ov_model.inputs: - rank = ov_input.get_partial_shape().rank.get_length() - ov_input.get_node().set_partial_shape( - ov.PartialShape([-1] * rank) - ) - ov_model.validate_nodes_and_infer_types() - - self.ov_device = device - model_dtype = ( - "f16" - if self.dtype == "float16" or self.dtype == "bfloat16" - else "f32" - ) - config = {"INFERENCE_PRECISION_HINT": model_dtype} - self.ov_compiled_model = core.compile_model( - ov_model, device, config - ) - return get_outputs( - inputs, self.struct_outputs, self.ov_compiled_model - ) + from keras_hub.src.utils.openvino_utils import ov_infer def wrapped_generate_function(inputs, stop_token_ids=None): - # ops.array converts to numpy in openvino backend + # Convert to numpy for OpenVINO backend inputs = tree.map_structure(ops.array, inputs) - return ov_infer(inputs, stop_token_ids, self.generate_step) + return ov_infer( + self, inputs, stop_token_ids, self.generate_step + ) self.generate_function = wrapped_generate_function if keras.config.backend() == "torch": diff --git a/keras_hub/src/models/gemma/gemma_backbone_test.py b/keras_hub/src/models/gemma/gemma_backbone_test.py index 7279f53323..b5f8575332 100644 --- a/keras_hub/src/models/gemma/gemma_backbone_test.py +++ b/keras_hub/src/models/gemma/gemma_backbone_test.py @@ -23,7 +23,6 @@ def setUp(self): "padding_mask": ops.ones((2, 5), dtype="int32"), } - @pytest.mark.requires_trainable_backend def test_backbone_basics(self): self.run_backbone_test( cls=GemmaBackbone, @@ -181,7 +180,6 @@ def setUp(self): "padding_mask": ops.ones((2, 10), dtype="int32"), } - @pytest.mark.requires_trainable_backend def test_backbone_basics(self): self.run_backbone_test( cls=GemmaBackbone, diff --git a/keras_hub/src/models/gemma/gemma_causal_lm.py b/keras_hub/src/models/gemma/gemma_causal_lm.py index 07c6f1710c..cf2c98c23c 100644 --- a/keras_hub/src/models/gemma/gemma_causal_lm.py +++ b/keras_hub/src/models/gemma/gemma_causal_lm.py @@ -258,9 +258,6 @@ def next(prompt, cache, index): cache_update_index = index - 1 batch_size = ops.shape(prompt)[0] prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - if keras.config.backend() == "openvino": - # Avoid returning dynamic shape by openvino slice - prompt = ops.reshape(prompt, [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, cache, diff --git a/keras_hub/src/models/gemma/gemma_causal_lm_test.py b/keras_hub/src/models/gemma/gemma_causal_lm_test.py index 59a40d3b3b..7885d502cc 100644 --- a/keras_hub/src/models/gemma/gemma_causal_lm_test.py +++ b/keras_hub/src/models/gemma/gemma_causal_lm_test.py @@ -52,7 +52,6 @@ def setUp(self): self.train_data = (["the quick brown fox", "the quick brown fox"],) self.input_data = self.preprocessor(*self.train_data)[0] - @pytest.mark.requires_trainable_backend def test_causal_lm_basics(self): self.run_task_test( cls=GemmaCausalLM, diff --git a/keras_hub/src/models/gemma/gemma_lora_test.py b/keras_hub/src/models/gemma/gemma_lora_test.py index 4d0da24523..256c4cf5fd 100644 --- a/keras_hub/src/models/gemma/gemma_lora_test.py +++ b/keras_hub/src/models/gemma/gemma_lora_test.py @@ -1,13 +1,11 @@ import os import numpy as np -import pytest from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone from keras_hub.src.tests.test_case import TestCase -@pytest.mark.requires_trainable_backend class GemmaLoraTest(TestCase): def setUp(self): self._init_kwargs = { diff --git a/keras_hub/src/models/gpt2/gpt2_backbone_test.py b/keras_hub/src/models/gpt2/gpt2_backbone_test.py index 6cf95ecc4c..2ae887fbbc 100644 --- a/keras_hub/src/models/gpt2/gpt2_backbone_test.py +++ b/keras_hub/src/models/gpt2/gpt2_backbone_test.py @@ -20,7 +20,6 @@ def setUp(self): "padding_mask": ops.ones((2, 5), dtype="int32"), } - @pytest.mark.requires_trainable_backend def test_backbone_basics(self): self.run_backbone_test( cls=GPT2Backbone, diff --git a/keras_hub/src/models/gpt2/gpt2_causal_lm.py b/keras_hub/src/models/gpt2/gpt2_causal_lm.py index dc5c4a8ba0..7f29d4ebd8 100644 --- a/keras_hub/src/models/gpt2/gpt2_causal_lm.py +++ b/keras_hub/src/models/gpt2/gpt2_causal_lm.py @@ -246,9 +246,6 @@ def next(prompt, cache, index): cache_update_index = index - 1 batch_size = ops.shape(prompt)[0] prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - if keras.config.backend() == "openvino": - # Avoid returning dynamic shape by openvino slice - prompt = ops.reshape(prompt, [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, cache, diff --git a/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py b/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py index deee70d3d1..0f6315bea6 100644 --- a/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py +++ b/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py @@ -39,7 +39,6 @@ def setUp(self): self.train_data = ([" airplane at airport", " airplane at airport"],) self.input_data = self.preprocessor(*self.train_data)[0] - @pytest.mark.requires_trainable_backend def test_causal_lm_basics(self): self.run_task_test( cls=GPT2CausalLM, diff --git a/keras_hub/src/models/mistral/mistral_backbone_test.py b/keras_hub/src/models/mistral/mistral_backbone_test.py index d52784b85a..ffb6e7ef20 100644 --- a/keras_hub/src/models/mistral/mistral_backbone_test.py +++ b/keras_hub/src/models/mistral/mistral_backbone_test.py @@ -21,7 +21,6 @@ def setUp(self): "padding_mask": ops.ones((2, 5), dtype="int32"), } - @pytest.mark.requires_trainable_backend def test_backbone_basics(self): self.run_backbone_test( cls=MistralBackbone, diff --git a/keras_hub/src/models/mistral/mistral_causal_lm.py b/keras_hub/src/models/mistral/mistral_causal_lm.py index 0cacb191fa..d28a7cad26 100644 --- a/keras_hub/src/models/mistral/mistral_causal_lm.py +++ b/keras_hub/src/models/mistral/mistral_causal_lm.py @@ -145,9 +145,6 @@ def next(prompt, cache, index): cache_update_index = index - 1 batch_size = ops.shape(prompt)[0] prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) - if keras.config.backend() == "openvino": - # Avoid returning dynamic shape by openvino slice - prompt = ops.reshape(prompt, [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, cache, diff --git a/keras_hub/src/models/mistral/mistral_causal_lm_test.py b/keras_hub/src/models/mistral/mistral_causal_lm_test.py index e1203b6641..8a6bd42434 100644 --- a/keras_hub/src/models/mistral/mistral_causal_lm_test.py +++ b/keras_hub/src/models/mistral/mistral_causal_lm_test.py @@ -39,7 +39,6 @@ def setUp(self): self.train_data = (["the quick brown fox", "the earth is round"],) self.input_data = self.preprocessor(*self.train_data)[0] - @pytest.mark.requires_trainable_backend def test_causal_lm_basics(self): self.run_task_test( cls=MistralCausalLM, diff --git a/keras_hub/src/utils/openvino_utils.py b/keras_hub/src/utils/openvino_utils.py new file mode 100644 index 0000000000..55a36ff764 --- /dev/null +++ b/keras_hub/src/utils/openvino_utils.py @@ -0,0 +1,311 @@ +import ast +import functools +from pathlib import Path + +from keras import tree + +from keras_hub.src.utils.keras_utils import print_msg + +try: + import openvino as ov + import openvino.opset14 as ov_opset + from openvino import Core + + core = Core() +except ImportError: + ov = None + ov_opset = None + core = None + + +def load_openvino_supported_tools(config_file_path): + """Load OpenVINO supported models from whitelist file. + + Args: + config_file_path: Path to whitelist file. + + Returns: + list: Supported model paths. + """ + try: + with open(config_file_path, "r") as f: + return [ + line.strip() + for line in f + if line.strip() and not line.strip().startswith("#") + ] + except FileNotFoundError: + return [] + + +def setup_openvino_test_config(config_file_path): + """Setup OpenVINO test configuration with whitelist approach. + + Args: + config_file_path: Path to the config file directory. + + Returns: + list: Supported paths (whitelist) for OpenVINO testing. + """ + return load_openvino_supported_tools( + Path(config_file_path) / "openvino_supported_tests.txt" + ) + + +@functools.lru_cache(maxsize=256) +def _contains_training_methods(file_path, test_name): + """Check if a test function contains training methods. + + Args: + file_path: Path to the test file. + test_name: Name of the test function. + + Returns: + bool: True if training methods found, False otherwise. + """ + training_methods = { + "fit", + "fit_generator", + "train_on_batch", + "compile", + "train_step", + "train", + "backward", + "zero_grad", + "step", + } + + training_keywords = { + "optimizer", + "loss", + "epochs", + "batch_size", + "learning_rate", + } + + training_test_methods = { + "run_layer_test", + "run_training_step", + "run_build_asserts", + "run_task_test", + "run_preprocessing_layer_test", + } + + class TrainingMethodDetector(ast.NodeVisitor): + def __init__(self): + self.has_training_methods = False + + def visit_Call(self, node): + if ( + hasattr(node.func, "attr") + and node.func.attr in training_methods + ): + self.has_training_methods = True + + if ( + hasattr(node.func, "attr") + and node.func.attr in training_test_methods + ): + self.has_training_methods = True + + if ( + hasattr(node.func, "value") + and hasattr(node.func.value, "id") + and node.func.value.id == "self" + and hasattr(node.func, "attr") + and node.func.attr in training_test_methods + ): + self.has_training_methods = True + + self.generic_visit(node) + + def visit_keyword(self, node): + """Visit keyword arguments to detect training keywords.""" + if node.arg in training_keywords: + self.has_training_methods = True + self.generic_visit(node) + + try: + with open(file_path, "r", encoding="utf-8") as f: + source = f.read() + tree = ast.parse(source) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == test_name: + detector = TrainingMethodDetector() + detector.visit(node) + return detector.has_training_methods + return False + except (OSError, SyntaxError): + return True + + +def should_auto_skip_training_test(item): + """Check if test should be auto-skipped for OpenVINO training ops. + + Args: + item: Pytest test item. + + Returns: + bool: True if should skip, False otherwise. + """ + if not str(item.fspath).endswith(".py"): + return False + test_name = item.name.split("[")[0] + return _contains_training_methods(str(item.fspath), test_name) + + +def get_openvino_skip_reason(item, supported_paths, auto_skip_training=True): + """Whitelist-based OpenVINO test skip checker. + + Only tests files/directories in supported_paths, skips everything else. + + Args: + item: Pytest test item. + supported_paths: List of supported file/directory paths (whitelist). + auto_skip_training: Whether to auto-skip training tests. + + Returns: + str or None: Skip reason if should skip, None otherwise. + """ + test_name = item.name.split("[")[0] + test_path = str(item.fspath) + + # Priority 1: Skip specific problematic test methods + SPECIFIC_SKIPPING_TESTS = { + "test_backbone_basics": "Requires trainable backend", + "test_score_loss": "Non-implemented roll operation", + "test_layer_behaviors": "Requires trainable backend", + } + if test_name in SPECIFIC_SKIPPING_TESTS: + return SPECIFIC_SKIPPING_TESTS[test_name] + + # Priority 2: Skip training operations (if enabled) + if auto_skip_training and should_auto_skip_training_test(item): + return "Training operations not supported" + + # Priority 3: Whitelist-based approach - only test supported paths + if supported_paths: + parts = test_path.replace("\\", "/").split("/") + try: + keras_hub_idx = parts.index("keras_hub") + relative_test_path = "/".join(parts[keras_hub_idx:]) + except ValueError: + relative_test_path = test_path # fall back to absolute + + for supported_path in supported_paths: + if ( + relative_test_path == supported_path + or relative_test_path.startswith(supported_path + "/") + ): + return None # in whitelist + + return "File/directory not in OpenVINO whitelist" + + return None + + +def get_device(): + """Detect and return the best available OpenVINO device. + + Returns: + str: "GPU" if available, otherwise "CPU". + """ + return "GPU" if "GPU" in core.available_devices else "CPU" + + +def compile_model(struct_params, struct_outputs, device, model_dtype): + """Compile OpenVINO model with dynamic shapes and precision hints. + + Args: + struct_params: Model parameters structure. + struct_outputs: Model outputs structure. + device: Target device ("GPU" or "CPU"). + model_dtype: Model precision ("f16" or "f32"). + + Returns: + Compiled OpenVINO model ready for inference. + """ + flat_params = tree.flatten(struct_params) + flat_outputs = tree.flatten(struct_outputs) + parameters = [p.output.get_node() for p in flat_params] + results = [ov_opset.result(r.output) for r in flat_outputs] + ov_model = ov.Model(results=results, parameters=parameters) + for ov_input in ov_model.inputs: + rank = ov_input.get_partial_shape().rank.get_length() + ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank)) + ov_model.validate_nodes_and_infer_types() + config = {"INFERENCE_PRECISION_HINT": model_dtype} + return core.compile_model(ov_model, device, config) + + +def get_outputs(inputs, struct_outputs, compiled_ov_model, unpack_singleton): + """Execute compiled OpenVINO model and return structured outputs. + + Args: + inputs: Input tensors for inference. + struct_outputs: Expected output structure. + compiled_ov_model: Compiled OpenVINO model. + unpack_singleton: Function to unpack singleton outputs. + + Returns: + Structured model outputs matching expected format. + """ + flatten_inputs = tree.flatten(inputs) + raw = compiled_ov_model(flatten_inputs).to_tuple() + packed = tree.pack_sequence_as(struct_outputs, raw) + return unpack_singleton(packed) + + +def ov_infer(model, inputs, stop_token_ids, fn): + """High-level OpenVINO inference with model reuse and compilation. + + This function manages OpenVINO model compilation and caching. It reuses + existing compiled models when possible, or compiles new ones as needed. + Handles device detection and automatic precision selection. + + Args: + model: Keras model with OpenVINO backend support. + inputs: Input tensors for inference. + stop_token_ids: Token IDs that should stop generation. + fn: Function to execute with the parameterized inputs. + + Returns: + Model outputs from OpenVINO inference. + """ + device = get_device() + + # Try to use existing compiled model for the same device + if ( + getattr(model, "ov_compiled_model", None) is not None + and getattr(model, "ov_device", None) is not None + and device == model.ov_device + ): + try: + return get_outputs( + inputs, + model.struct_outputs, + model.ov_compiled_model, + model._unpack_singleton, + ) + except RuntimeError as e: + print_msg( + "WARNING: OpenVINO inference \033[1mFAILED\033[0m, " + "recompiling model and trying again.\n" + str(e) + ) + model.ov_compiled_model = None + model.struct_outputs = None + + # Compile a new model + struct_params = model._parameterize_data(inputs) + model.struct_outputs = fn(struct_params, stop_token_ids) + model.ov_device = device + model_dtype = "f16" if model.dtype in ("float16", "bfloat16") else "f32" + model.ov_compiled_model = compile_model( + struct_params, model.struct_outputs, device, model_dtype + ) + return get_outputs( + inputs, + model.struct_outputs, + model.ov_compiled_model, + model._unpack_singleton, + ) diff --git a/keras_hub/src/utils/openvino_utils_test.py b/keras_hub/src/utils/openvino_utils_test.py new file mode 100644 index 0000000000..bb62bbff6e --- /dev/null +++ b/keras_hub/src/utils/openvino_utils_test.py @@ -0,0 +1,386 @@ +import os +import tempfile +import unittest.mock + +import keras +import numpy as np +import pytest + +from keras_hub.src.tests.test_case import TestCase + +try: + import openvino as ov + from openvino import Core + + from keras_hub.src.utils.openvino_utils import _contains_training_methods + from keras_hub.src.utils.openvino_utils import compile_model + from keras_hub.src.utils.openvino_utils import get_device + from keras_hub.src.utils.openvino_utils import get_openvino_skip_reason + from keras_hub.src.utils.openvino_utils import get_outputs + from keras_hub.src.utils.openvino_utils import load_openvino_supported_tools + from keras_hub.src.utils.openvino_utils import ov_infer + from keras_hub.src.utils.openvino_utils import setup_openvino_test_config + from keras_hub.src.utils.openvino_utils import ( + should_auto_skip_training_test, + ) +except ImportError: + ov = None + Core = None + + +# --- shared test helpers --- +class _MockParam: + @property + def output(self): + return self + + def get_node(self): + return unittest.mock.MagicMock() + + +class _MockOutput: + @property + def output(self): + return unittest.mock.MagicMock() + + +class _MockFspath: + def __init__(self, path): + import os + + self.path = path + self.basename = os.path.basename(path) + + def __str__(self): + return self.path + + +class _MockItem: + def __init__(self, fspath, name): + self.fspath = ( + _MockFspath(fspath) if not hasattr(fspath, "basename") else fspath + ) + self.name = name + + +@pytest.mark.skipif( + keras.config.backend() != "openvino", + reason="OpenVINO tests only run with OpenVINO backend", +) +class OpenVINOUtilsTest(TestCase): + def setUp(self): + super().setUp() + if ov is None: + self.skipTest("OpenVINO not available") + + def test_get_device_returns_valid_device(self): + device = get_device() + self.assertIn(device, ["GPU", "CPU"]) + + core = Core() + self.assertIn(device, core.available_devices) + + def test_get_device_consistency(self): + device1 = get_device() + device2 = get_device() + self.assertEqual(device1, device2) + + def test_compile_model_basic_and_precision_hints(self): + with ( + unittest.mock.patch( + "keras_hub.src.utils.openvino_utils.ov.Model" + ) as mock_model_class, + unittest.mock.patch( + "keras_hub.src.utils.openvino_utils.core" + ) as mock_core, + ): + mock_model_class.return_value = unittest.mock.MagicMock() + mock_core.compile_model.return_value = unittest.mock.MagicMock() + + struct_params = [_MockParam(), _MockParam()] + struct_outputs = [_MockOutput()] + device = "CPU" + + for dtype in ("f32", "f16"): + with self.subTest(dtype=dtype): + result = compile_model( + struct_params, struct_outputs, device, dtype + ) + self.assertIsNotNone(result) + + self.assertEqual(mock_core.compile_model.call_count, 2) + + def test_get_outputs_basic_functionality(self): + class MockResult: + def __init__(self, data): + self.data = data + + def to_tuple(self): + return (self.data,) + + class MockCompiledModel: + def __init__(self): + self.inputs = ["input"] + self.outputs = ["output"] + + def __call__(self, flatten_inputs): + input_data = flatten_inputs[0] + output_data = np.maximum(input_data, 0.0) + return MockResult(output_data) + + class MockOutput: + def get_node(self): + return "mock_relu_node" + + compiled_model = MockCompiledModel() + struct_outputs = [MockOutput()] + + test_input = np.array([[-1.0, 0.0, 1.0]], dtype=np.float32) + inputs = [test_input] + + def mock_unpack_singleton(x): + return x[0] if len(x) == 1 else x + + outputs = get_outputs( + inputs, struct_outputs, compiled_model, mock_unpack_singleton + ) + expected = np.array([[0.0, 0.0, 1.0]], dtype=np.float32) + np.testing.assert_array_almost_equal(outputs, expected) + + def test_ov_infer_model_caching(self): + current_device = get_device() + + class MockModel: + def __init__(self): + self.dtype = "float32" + self.ov_compiled_model = unittest.mock.MagicMock() + self.ov_device = current_device + self.struct_outputs = ["mock_output"] + + def _parameterize_data(self, inputs): + return ["mock_param"] + + def _unpack_singleton(self, x): + return x[0] if len(x) == 1 else x + + def mock_fn(struct_params, stop_token_ids): + return ["mock_output"] + + model = MockModel() + test_input = [np.array([[1.0, 2.0, 3.0]], dtype=np.float32)] + cached_model = model.ov_compiled_model + + with unittest.mock.patch( + "keras_hub.src.utils.openvino_utils.get_outputs" + ) as mock_get_outputs: + mock_get_outputs.return_value = np.array( + [[2.0, 4.0, 6.0]], dtype=np.float32 + ) + result = ov_infer(model, test_input, None, mock_fn) + + self.assertIs(model.ov_compiled_model, cached_model) + self.assertIsNotNone(result) + + def test_ov_infer_dtype_selection(self): + class MockModel: + def __init__(self, dtype): + self.dtype = dtype + self.ov_compiled_model = None + self.ov_device = None + self.struct_outputs = None + + def _parameterize_data(self, inputs): + return ["mock_param"] + + def _unpack_singleton(self, x): + return x[0] if len(x) == 1 else x + + def mock_fn(struct_params, stop_token_ids): + return ["mock_output"] + + test_cases = [ + ("float32", "f32"), + ("float16", "f16"), + ("bfloat16", "f16"), + ] + for model_dtype, expected_ov_dtype in test_cases: + with self.subTest(dtype=model_dtype): + model = MockModel(model_dtype) + test_input = [np.array([[1.0, 2.0]], dtype=np.float32)] + with ( + unittest.mock.patch( + "keras_hub.src.utils.openvino_utils.compile_model" + ) as mock_compile, + unittest.mock.patch( + "keras_hub.src.utils.openvino_utils.get_outputs" + ) as mock_get_outputs, + ): + mock_compile.return_value = "mock_compiled_model" + mock_get_outputs.return_value = np.array( + [[1.0, 2.0]], dtype=np.float32 + ) + ov_infer(model, test_input, None, mock_fn) + args, kwargs = mock_compile.call_args + self.assertEqual(args[3], expected_ov_dtype) + + def test_load_openvino_supported_tools_valid_file(self): + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt" + ) as f: + f.write("keras_hub/src/models/gemma\n") + f.write("keras_hub/src/models/gpt2\n") + f.write("keras_hub/src/layers/modeling\n") + temp_file = f.name + + try: + result = load_openvino_supported_tools(temp_file) + expected = [ + "keras_hub/src/models/gemma", + "keras_hub/src/models/gpt2", + "keras_hub/src/layers/modeling", + ] + self.assertEqual(result, expected) + finally: + os.unlink(temp_file) + + def test_load_openvino_supported_tools_nonexistent_file(self): + result = load_openvino_supported_tools("/nonexistent/file.txt") + self.assertEqual(result, []) + + def test_load_openvino_supported_tools_empty_file(self): + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt" + ) as f: + temp_file = f.name + try: + result = load_openvino_supported_tools(temp_file) + self.assertEqual(result, []) + finally: + os.unlink(temp_file) + + def test_setup_openvino_test_config_openvino_backend(self): + with tempfile.TemporaryDirectory() as temp_dir: + config_file = os.path.join(temp_dir, "openvino_supported_tests.txt") + with open(config_file, "w") as f: + f.write("keras_hub/src/models/gemma\n") + f.write("keras_hub/src/tokenizers\n") + + result = setup_openvino_test_config(temp_dir) + expected = [ + "keras_hub/src/models/gemma", + "keras_hub/src/tokenizers", + ] + self.assertEqual(result, expected) + + def test_contains_training_methods_with_training_code(self): + training_code = """ + import keras + def test_training_method(): + model = keras.Model() + model.fit(x, y) + return model + def test_other_method(): + model.compile(optimizer='adam') + return model + """ + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".py" + ) as f: + f.write(training_code) + temp_file = f.name + try: + result = _contains_training_methods( + temp_file, "test_training_method" + ) + self.assertTrue(result) + finally: + os.unlink(temp_file) + + def test_contains_training_methods_nonexistent_file(self): + result = _contains_training_methods( + "/nonexistent/file.py", "test_method" + ) + self.assertTrue(result) + + def test_should_auto_skip_training_test_non_python_file(self): + class _SimpleItem: + def __init__(self, fspath): + self.fspath = type("MockPath", (), {"basename": fspath})() + self.name = "test_method" + + item = _SimpleItem("test_file.txt") + result = should_auto_skip_training_test(item) + self.assertFalse(result) + + def test_should_auto_skip_training_test_with_training_methods(self): + training_code = """ + def test_fit_method(): + model.fit(x, y) + return model + """ + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".py" + ) as f: + f.write(training_code) + temp_file = f.name + try: + item = _MockItem(temp_file, "test_fit_method") + result = should_auto_skip_training_test(item) + self.assertTrue(result) + finally: + os.unlink(temp_file) + + def test_get_openvino_skip_reason_specific_test_skip(self): + class MockItem: + def __init__(self, test_name): + self.name = test_name + self.fspath = type("MockPath", (), {})() + setattr(self.fspath, "__str__", lambda: "test_file.py") + + expected_reasons = { + "test_backbone_basics": "Requires trainable backend", + "test_score_loss": "Non-implemented roll operation", + "test_layer_behaviors": "Requires trainable backend", + } + for test_name, expected_reason in expected_reasons.items(): + item = MockItem(test_name) + result = get_openvino_skip_reason(item, [], True) + self.assertEqual(result, expected_reason) + + def test_get_openvino_skip_reason_training_skip(self): + training_code = """ + def test_training_method(): + model.fit(x, y) + return model + """ + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".py" + ) as f: + f.write(training_code) + temp_file = f.name + try: + item = _MockItem(temp_file, "test_training_method") + result = get_openvino_skip_reason(item, [], True) + self.assertEqual(result, "Training operations not supported") + finally: + os.unlink(temp_file) + + def test_get_openvino_skip_reason_whitelist_supported(self): + test_path = "/some/path/keras_hub/src/models/gemma/gemma_test.py" + supported_paths = ["keras_hub/src/models/gemma"] + item = _MockItem(test_path, "test_inference") + result = get_openvino_skip_reason(item, supported_paths, False) + self.assertIsNone(result) + + def test_get_openvino_skip_reason_whitelist_not_supported(self): + test_path = "/some/path/keras_hub/src/models/gemma3/gemma3_test.py" + supported_paths = ["keras_hub/src/models/gemma"] + item = _MockItem(test_path, "test_inference") + result = get_openvino_skip_reason(item, supported_paths, False) + self.assertEqual(result, "File/directory not in OpenVINO whitelist") + + def test_get_openvino_skip_reason_no_whitelist(self): + test_path = "/some/path/keras_hub/src/models/gemma/gemma_test.py" + item = _MockItem(test_path, "test_inference") + result = get_openvino_skip_reason(item, [], False) + self.assertIsNone(result) diff --git a/keras_hub/src/utils/pipeline_model_test.py b/keras_hub/src/utils/pipeline_model_test.py index 42178f54b8..101171597c 100644 --- a/keras_hub/src/utils/pipeline_model_test.py +++ b/keras_hub/src/utils/pipeline_model_test.py @@ -2,7 +2,6 @@ import keras import numpy as np -import pytest import tensorflow as tf from keras_hub.src.tests.test_case import TestCase @@ -79,7 +78,6 @@ def from_config(cls, config): class TestNoopPipelineModel(TestCase): - @pytest.mark.requires_trainable_backend def test_fit(self): x = np.random.uniform(size=(8, 5)) y = np.random.uniform(size=(8, 1)) @@ -113,7 +111,6 @@ def test_predict(self): model.predict(x=x, batch_size=8) model.predict(tf.data.Dataset.from_tensor_slices(x).batch(8)) - @pytest.mark.requires_trainable_backend def test_on_batch(self): x = np.random.uniform(size=(8, 5)) y = np.random.uniform(size=(8, 1)) @@ -146,7 +143,6 @@ def test_saved_model(self): class TestFeaturePreprocessingModel(TestCase): - @pytest.mark.requires_trainable_backend def test_fit_with_preprocessing(self): x = tf.strings.as_string(np.random.uniform(size=(100, 5))) y = np.random.uniform(size=(100, 1)) @@ -180,7 +176,6 @@ def test_predict_with_preprocessing(self): model.predict(x=x, batch_size=8) model.predict(tf.data.Dataset.from_tensor_slices(x).batch(8)) - @pytest.mark.requires_trainable_backend def test_on_batch(self): x = tf.strings.as_string(np.random.uniform(size=(8, 5))) y = np.random.uniform(size=(8, 1)) @@ -213,7 +208,6 @@ def test_saved_model(self): class TestLabelPreprocessingModel(TestCase): - @pytest.mark.requires_trainable_backend def test_fit_with_preprocessing(self): x = np.random.uniform(size=(100, 5)) y = tf.strings.as_string(np.random.uniform(size=(100, 1))) @@ -279,7 +273,6 @@ def test_saved_model(self): class TestDataPreprocessingModel(TestCase): - @pytest.mark.requires_trainable_backend def test_fit_with_preprocessing(self): data = tf.strings.as_string(np.random.uniform(size=(100, 1))) model = DataPipeline() @@ -331,7 +324,6 @@ def test_saved_model(self): class TestFunctional(TestCase): - @pytest.mark.requires_trainable_backend def test_fit(self): x = tf.strings.as_string(np.random.uniform(size=(100, 5))) y = np.random.uniform(size=(100, 1)) @@ -363,7 +355,6 @@ def test_saved_model(self): self.assertAllClose(model_output, restored_output) -@pytest.mark.requires_trainable_backend class TestFitArguments(TestCase): def test_validation_data(self): x = tf.strings.as_string(np.random.uniform(size=(80, 5))) @@ -409,7 +400,6 @@ def test_error_dataset_and_invalid_arguments(self): model.fit(ds, sample_weight=sw) -@pytest.mark.requires_trainable_backend class TestInputErrors(TestCase): def test_unbatched_input_raises(self): model = FeaturePipeline() diff --git a/openvino_excluded_concrete_tests.txt b/openvino_excluded_concrete_tests.txt deleted file mode 100644 index 0d1c851933..0000000000 --- a/openvino_excluded_concrete_tests.txt +++ /dev/null @@ -1,24 +0,0 @@ -AnchorGeneratorTest::test_anchor_generator0 -BoxMatcherTest::test_box_matcher_batched -BoxMatcherTest::test_box_matcher_empty_gt_boxes -BoxMatcherTest::test_box_matcher_force_match -BoxMatcherTest::test_box_matcher_unbatched -CachedMultiHeadAttentionTest::test_cache_call_is_correct -CachedMultiHeadAttentionTest::test_layer_behaviors -GemmaCausalLMTest::test_score_loss -GPT2CausalLMTest::test_score_loss -MistralCausalLMTest::test_score_loss -NonMaxSupressionTest::test_confidence_threshold -NonMaxSupressionTest::test_max_detections -RandomSamplerTest::test_early_stopping -RandomSamplerTest::test_stateful_call -ReversibleEmbeddingTest::test_quantize_dtype_argument_untie_weights -ReversibleEmbeddingTest::test_quantize_dtype_argument_tie_weights -ReversibleEmbeddingTest::test_quantize_int8_tie_weights -ReversibleEmbeddingTest::test_quantize_int8_untie_weights -ReversibleEmbeddingTest::test_saving_tie_weights -ReversibleEmbeddingTest::test_saving_untie_weights -TestNoopPipelineModel::test_evaluate -TestFeaturePreprocessingModel::test_evaluate_with_preprocessing -TestLabelPreprocessingModel::test_evaluate_with_preprocessing -TestDataPreprocessingModel::test_evaluate_with_preprocessing diff --git a/openvino_excluded_tests.txt b/openvino_excluded_tests.txt deleted file mode 100644 index 877d0a0ba3..0000000000 --- a/openvino_excluded_tests.txt +++ /dev/null @@ -1,76 +0,0 @@ -keras_hub/src/layers/modeling/transformer_decoder_test.py -keras_hub/src/layers/modeling/transformer_encoder_test.py -keras_hub/src/layers/preprocessing/image_converter_test.py -keras_hub/src/metrics/bleu_test.py -keras_hub/src/metrics/edit_distance_test.py -keras_hub/src/metrics/perplexity_test.py -keras_hub/src/metrics/rouge_l_test.py -keras_hub/src/metrics/rouge_n_test.py -keras_hub/src/models/albert -keras_hub/src/models/audio_to_text_preprocessor_test.py -keras_hub/src/models/backbone_test.py -keras_hub/src/models/bart -keras_hub/src/models/basnet -keras_hub/src/models/bert -keras_hub/src/models/bloom -keras_hub/src/models/causal_lm_preprocessor_test.py -keras_hub/src/models/clip -keras_hub/src/models/cspnet -keras_hub/src/models/deberta_v3 -keras_hub/src/models/deeplab_v3 -keras_hub/src/models/deit -keras_hub/src/models/densenet -keras_hub/src/models/dinov2 -keras_hub/src/models/distil_bert -keras_hub/src/models/efficientnet -keras_hub/src/models/electra -keras_hub/src/models/esm -keras_hub/src/models/falcon -keras_hub/src/models/flux -keras_hub/src/models/f_net -keras_hub/src/models/gemma3 -keras_hub/src/models/gpt_neo_x -keras_hub/src/models/hgnetv2 -keras_hub/src/models/llama3 -keras_hub/src/models/llama -keras_hub/src/models/masked_lm_preprocessor_test.py -keras_hub/src/models/mit -keras_hub/src/models/mixtral -keras_hub/src/models/mobilenet -keras_hub/src/models/moonshine -keras_hub/src/models/opt -keras_hub/src/models/pali_gemma -keras_hub/src/models/phi3 -keras_hub/src/models/preprocessor_test.py -keras_hub/src/models/qwen3 -keras_hub/src/models/qwen_moe -keras_hub/src/models/qwen -keras_hub/src/models/resnet -keras_hub/src/models/retinanet -keras_hub/src/models/roberta -keras_hub/src/models/roformer_v2 -keras_hub/src/models/sam -keras_hub/src/models/segformer -keras_hub/src/models/seq_2_seq_lm_preprocessor_test.py -keras_hub/src/models/siglip -keras_hub/src/models/stable_diffusion_3 -keras_hub/src/models/t5 -keras_hub/src/models/task_test.py -keras_hub/src/models/text_classifier_preprocessor_test.py -keras_hub/src/models/text_to_image_preprocessor_test.py -keras_hub/src/models/vae -keras_hub/src/models/vgg -keras_hub/src/models/vit_det -keras_hub/src/models/vit -keras_hub/src/models/vit -keras_hub/src/models/whisper -keras_hub/src/models/xception -keras_hub/src/models/xlm_roberta -keras_hub/src/models/xlnet -keras_hub/src/samplers/beam_sampler_test.py -keras_hub/src/samplers/contrastive_sampler_test.py -keras_hub/src/samplers/greedy_sampler_test.py -keras_hub/src/samplers/top_k_sampler_test.py -keras_hub/src/samplers/top_p_sampler_test.py -keras_hub/src/utils/pipeline_model_test.py -keras_hub/src/utils/transformers/export/gemma_test.py \ No newline at end of file diff --git a/openvino_supported_tests.txt b/openvino_supported_tests.txt new file mode 100644 index 0000000000..2495a19a40 --- /dev/null +++ b/openvino_supported_tests.txt @@ -0,0 +1,21 @@ +keras-hub/integration_tests +keras_hub/src/layers/modeling/alibi_bias_test.py +keras_hub/src/layers/modeling/masked_lm_head_test.py +keras_hub/src/layers/modeling/position_embedding_test.py +keras_hub/src/layers/modeling/rotary_embedding_test.py +keras_hub/src/layers/modeling/sine_position_encoding_test.py +keras_hub/src/layers/modeling/token_and_position_embedding_test.py +keras_hub/src/layers/modeling/transformer_layer_utils_test.py +keras_hub/src/layers/preprocessing/audio_converter_test.py +keras_hub/src/layers/preprocessing/masked_lm_mask_generator_test.py +keras_hub/src/layers/preprocessing/multi_segment_packer_test.py +keras_hub/src/layers/preprocessing/random_deletion_test.py +keras_hub/src/layers/preprocessing/random_swap_test.py +keras_hub/src/layers/preprocessing/start_end_packer_test.py +keras_hub/src/models/gemma +keras_hub/src/models/gpt2 +keras_hub/src/models/mistral +keras_hub/src/samplers/serialization_test.py +keras_hub/src/tests/doc_tests/docstring_test.py +keras_hub/src/tokenizers +keras_hub/src/utils \ No newline at end of file From 06a2a8e153c1adf2feb3ac227af56d4c27acd9d9 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Wed, 3 Sep 2025 03:57:05 +0300 Subject: [PATCH 16/17] add suggested updates --- conftest.py | 91 +++++--- keras_hub/src/samplers/beam_sampler.py | 12 +- keras_hub/src/utils/openvino_utils.py | 204 ++---------------- keras_hub/src/utils/openvino_utils_test.py | 239 ++------------------- openvino_supported_tests.txt | 21 -- 5 files changed, 108 insertions(+), 459 deletions(-) delete mode 100644 openvino_supported_tests.txt diff --git a/conftest.py b/conftest.py index 3dd9323a04..5c9e541459 100644 --- a/conftest.py +++ b/conftest.py @@ -1,12 +1,8 @@ import os -from pathlib import Path import keras import pytest -from keras_hub.src.utils.openvino_utils import get_openvino_skip_reason -from keras_hub.src.utils.openvino_utils import setup_openvino_test_config - def pytest_addoption(parser): parser.addoption( @@ -33,16 +29,27 @@ def pytest_addoption(parser): default=False, help="fail if a gpu is not present", ) - parser.addoption( - "--auto_skip_training", - action="store_true", - default=True, - help="automatically skip tests with " - "training methods on non-trainable backends", - ) def pytest_configure(config): + # Monkey-patch training methods for OpenVINO backend + if keras.config.backend() == "openvino": + # Store original methods in case we need to restore them + if not hasattr(keras.Model, "_original_compile"): + keras.Model._original_compile = keras.Model.compile + keras.Model._original_fit = keras.Model.fit + keras.Model._original_train_on_batch = keras.Model.train_on_batch + + keras.Model.compile = lambda *args, **kwargs: pytest.skip( + "Model.compile() not supported on OpenVINO backend" + ) + keras.Model.fit = lambda *args, **kwargs: pytest.skip( + "Model.fit() not supported on OpenVINO backend" + ) + keras.Model.train_on_batch = lambda *args, **kwargs: pytest.skip( + "Model.train_on_batch() not supported on OpenVINO backend" + ) + # Verify that device has GPU and detected by backend if config.getoption("--check_gpu"): found_gpu = False @@ -84,12 +91,9 @@ def pytest_configure(config): def pytest_collection_modifyitems(config, items): - openvino_supported_paths = None - run_extra_large_tests = config.getoption("--run_extra_large") # Run large tests for --run_extra_large or --run_large. run_large_tests = config.getoption("--run_large") or run_extra_large_tests - auto_skip_training = config.getoption("--auto_skip_training") # Messages to annotate skipped tests with. skip_large = pytest.mark.skipif( @@ -124,21 +128,58 @@ def pytest_collection_modifyitems(config, items): if "kaggle_key_required" in item.keywords: item.add_marker(kaggle_key_required) - # OpenVINO-specific skipping logic - whitelist-based approach + # OpenVINO-specific test skipping if keras.config.backend() == "openvino": - # OpenVINO backend configuration - if openvino_supported_paths is None: - openvino_supported_paths = setup_openvino_test_config( - str(Path(__file__).parent) + test_name = item.name.split("[")[0] + test_path = str(item.fspath) + + # OpenVINO supported test paths + openvino_supported_paths = [ + "keras-hub/integration_tests", + "keras_hub/src/models/gemma", + "keras_hub/src/models/gpt2", + "keras_hub/src/models/mistral", + "keras_hub/src/samplers/serialization_test.py", + "keras_hub/src/tests/doc_tests/docstring_test.py", + "keras_hub/src/tokenizers", + "keras_hub/src/utils", + ] + + # Skip specific problematic test methods + specific_skipping_tests = { + "test_backbone_basics": "Requires trainable backend", + "test_score_loss": "Non-implemented roll operation", + "test_layer_behaviors": "Requires trainable backend", + } + + if test_name in specific_skipping_tests: + item.add_marker( + pytest.mark.skipif( + True, + reason="OpenVINO: " + f"{specific_skipping_tests[test_name]}", + ) ) - skip_reason = get_openvino_skip_reason( - item, - openvino_supported_paths, - auto_skip_training, + continue + + parts = test_path.replace("\\", "/").split("/") + try: + keras_hub_idx = parts.index("keras_hub") + relative_test_path = "/".join(parts[keras_hub_idx:]) + except ValueError: + relative_test_path = test_path + + is_whitelisted = any( + relative_test_path == supported_path + or relative_test_path.startswith(supported_path + "/") + for supported_path in openvino_supported_paths ) - if skip_reason: + + if not is_whitelisted: item.add_marker( - pytest.mark.skipif(True, reason=f"OpenVINO: {skip_reason}") + pytest.mark.skipif( + True, reason="OpenVINO: File/directory not in whitelist" + ) ) diff --git a/keras_hub/src/samplers/beam_sampler.py b/keras_hub/src/samplers/beam_sampler.py index 26941e9f3a..c2e605b234 100644 --- a/keras_hub/src/samplers/beam_sampler.py +++ b/keras_hub/src/samplers/beam_sampler.py @@ -95,15 +95,15 @@ def unflatten_beams(x): ) log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0)) - def cond(prompt, cache, index, log_probs): + def cond(prompt, cache, index, mask, log_probs): if stop_token_ids is None: - return True + return ops.convert_to_tensor(True, dtype="bool") # Stop if all sequences have produced a *new* stop token. end_tokens = any_equal(prompt, stop_token_ids, ~mask) prompt_done = ops.any(end_tokens, axis=-1) return ops.logical_not(ops.all(prompt_done)) - def body(prompt, cache, index, log_probs): + def body(prompt, cache, index, mask, log_probs): # Compute the softmax distribution for the next token. logits, _, cache = next(prompt, cache, index) vocab_size = ops.shape(logits)[-1] @@ -150,12 +150,12 @@ def gather_beams(x): next_token = next_token[:, None] prompt = ops.slice_update(prompt, [0, index], next_token) # Return the iteration of the loop state. - return (prompt, cache, index + 1, log_probs) + return (prompt, cache, index + 1, mask, log_probs) - prompt, _, _, log_probs = self.run_loop( + prompt, _, _, _, log_probs = self.run_loop( cond=cond, body=body, - loop_vars=(prompt, cache, index, log_probs), + loop_vars=(prompt, cache, index, mask, log_probs), maximum_iterations=(max_length - index), model=model, ) diff --git a/keras_hub/src/utils/openvino_utils.py b/keras_hub/src/utils/openvino_utils.py index 55a36ff764..0d32579a26 100644 --- a/keras_hub/src/utils/openvino_utils.py +++ b/keras_hub/src/utils/openvino_utils.py @@ -1,7 +1,3 @@ -import ast -import functools -from pathlib import Path - from keras import tree from keras_hub.src.utils.keras_utils import print_msg @@ -11,197 +7,25 @@ import openvino.opset14 as ov_opset from openvino import Core - core = Core() + _core = None except ImportError: ov = None ov_opset = None - core = None - + Core = None + _core = None -def load_openvino_supported_tools(config_file_path): - """Load OpenVINO supported models from whitelist file. - Args: - config_file_path: Path to whitelist file. +def get_core(): + """Get or create OpenVINO Core instance. Returns: - list: Supported model paths. + openvino.Core: OpenVINO Core instance, + or None if OpenVINO not available. """ - try: - with open(config_file_path, "r") as f: - return [ - line.strip() - for line in f - if line.strip() and not line.strip().startswith("#") - ] - except FileNotFoundError: - return [] - - -def setup_openvino_test_config(config_file_path): - """Setup OpenVINO test configuration with whitelist approach. - - Args: - config_file_path: Path to the config file directory. - - Returns: - list: Supported paths (whitelist) for OpenVINO testing. - """ - return load_openvino_supported_tools( - Path(config_file_path) / "openvino_supported_tests.txt" - ) - - -@functools.lru_cache(maxsize=256) -def _contains_training_methods(file_path, test_name): - """Check if a test function contains training methods. - - Args: - file_path: Path to the test file. - test_name: Name of the test function. - - Returns: - bool: True if training methods found, False otherwise. - """ - training_methods = { - "fit", - "fit_generator", - "train_on_batch", - "compile", - "train_step", - "train", - "backward", - "zero_grad", - "step", - } - - training_keywords = { - "optimizer", - "loss", - "epochs", - "batch_size", - "learning_rate", - } - - training_test_methods = { - "run_layer_test", - "run_training_step", - "run_build_asserts", - "run_task_test", - "run_preprocessing_layer_test", - } - - class TrainingMethodDetector(ast.NodeVisitor): - def __init__(self): - self.has_training_methods = False - - def visit_Call(self, node): - if ( - hasattr(node.func, "attr") - and node.func.attr in training_methods - ): - self.has_training_methods = True - - if ( - hasattr(node.func, "attr") - and node.func.attr in training_test_methods - ): - self.has_training_methods = True - - if ( - hasattr(node.func, "value") - and hasattr(node.func.value, "id") - and node.func.value.id == "self" - and hasattr(node.func, "attr") - and node.func.attr in training_test_methods - ): - self.has_training_methods = True - - self.generic_visit(node) - - def visit_keyword(self, node): - """Visit keyword arguments to detect training keywords.""" - if node.arg in training_keywords: - self.has_training_methods = True - self.generic_visit(node) - - try: - with open(file_path, "r", encoding="utf-8") as f: - source = f.read() - tree = ast.parse(source) - for node in ast.walk(tree): - if isinstance(node, ast.FunctionDef) and node.name == test_name: - detector = TrainingMethodDetector() - detector.visit(node) - return detector.has_training_methods - return False - except (OSError, SyntaxError): - return True - - -def should_auto_skip_training_test(item): - """Check if test should be auto-skipped for OpenVINO training ops. - - Args: - item: Pytest test item. - - Returns: - bool: True if should skip, False otherwise. - """ - if not str(item.fspath).endswith(".py"): - return False - test_name = item.name.split("[")[0] - return _contains_training_methods(str(item.fspath), test_name) - - -def get_openvino_skip_reason(item, supported_paths, auto_skip_training=True): - """Whitelist-based OpenVINO test skip checker. - - Only tests files/directories in supported_paths, skips everything else. - - Args: - item: Pytest test item. - supported_paths: List of supported file/directory paths (whitelist). - auto_skip_training: Whether to auto-skip training tests. - - Returns: - str or None: Skip reason if should skip, None otherwise. - """ - test_name = item.name.split("[")[0] - test_path = str(item.fspath) - - # Priority 1: Skip specific problematic test methods - SPECIFIC_SKIPPING_TESTS = { - "test_backbone_basics": "Requires trainable backend", - "test_score_loss": "Non-implemented roll operation", - "test_layer_behaviors": "Requires trainable backend", - } - if test_name in SPECIFIC_SKIPPING_TESTS: - return SPECIFIC_SKIPPING_TESTS[test_name] - - # Priority 2: Skip training operations (if enabled) - if auto_skip_training and should_auto_skip_training_test(item): - return "Training operations not supported" - - # Priority 3: Whitelist-based approach - only test supported paths - if supported_paths: - parts = test_path.replace("\\", "/").split("/") - try: - keras_hub_idx = parts.index("keras_hub") - relative_test_path = "/".join(parts[keras_hub_idx:]) - except ValueError: - relative_test_path = test_path # fall back to absolute - - for supported_path in supported_paths: - if ( - relative_test_path == supported_path - or relative_test_path.startswith(supported_path + "/") - ): - return None # in whitelist - - return "File/directory not in OpenVINO whitelist" - - return None + global _core + if _core is None and Core is not None: + _core = Core() + return _core def get_device(): @@ -210,6 +34,9 @@ def get_device(): Returns: str: "GPU" if available, otherwise "CPU". """ + core = get_core() + if core is None: + return "CPU" return "GPU" if "GPU" in core.available_devices else "CPU" @@ -235,6 +62,9 @@ def compile_model(struct_params, struct_outputs, device, model_dtype): ov_input.get_node().set_partial_shape(ov.PartialShape([-1] * rank)) ov_model.validate_nodes_and_infer_types() config = {"INFERENCE_PRECISION_HINT": model_dtype} + core = get_core() + if core is None: + raise RuntimeError("OpenVINO not available") return core.compile_model(ov_model, device, config) diff --git a/keras_hub/src/utils/openvino_utils_test.py b/keras_hub/src/utils/openvino_utils_test.py index bb62bbff6e..8fe8acc830 100644 --- a/keras_hub/src/utils/openvino_utils_test.py +++ b/keras_hub/src/utils/openvino_utils_test.py @@ -1,66 +1,16 @@ -import os -import tempfile import unittest.mock import keras import numpy as np +import openvino as ov import pytest +from openvino import Core from keras_hub.src.tests.test_case import TestCase - -try: - import openvino as ov - from openvino import Core - - from keras_hub.src.utils.openvino_utils import _contains_training_methods - from keras_hub.src.utils.openvino_utils import compile_model - from keras_hub.src.utils.openvino_utils import get_device - from keras_hub.src.utils.openvino_utils import get_openvino_skip_reason - from keras_hub.src.utils.openvino_utils import get_outputs - from keras_hub.src.utils.openvino_utils import load_openvino_supported_tools - from keras_hub.src.utils.openvino_utils import ov_infer - from keras_hub.src.utils.openvino_utils import setup_openvino_test_config - from keras_hub.src.utils.openvino_utils import ( - should_auto_skip_training_test, - ) -except ImportError: - ov = None - Core = None - - -# --- shared test helpers --- -class _MockParam: - @property - def output(self): - return self - - def get_node(self): - return unittest.mock.MagicMock() - - -class _MockOutput: - @property - def output(self): - return unittest.mock.MagicMock() - - -class _MockFspath: - def __init__(self, path): - import os - - self.path = path - self.basename = os.path.basename(path) - - def __str__(self): - return self.path - - -class _MockItem: - def __init__(self, fspath, name): - self.fspath = ( - _MockFspath(fspath) if not hasattr(fspath, "basename") else fspath - ) - self.name = name +from keras_hub.src.utils.openvino_utils import compile_model +from keras_hub.src.utils.openvino_utils import get_device +from keras_hub.src.utils.openvino_utils import get_outputs +from keras_hub.src.utils.openvino_utils import ov_infer @pytest.mark.skipif( @@ -86,15 +36,26 @@ def test_get_device_consistency(self): self.assertEqual(device1, device2) def test_compile_model_basic_and_precision_hints(self): + class _MockParam: + def __init__(self): + self.output = unittest.mock.MagicMock() + self.output.get_node.return_value = unittest.mock.MagicMock() + + class _MockOutput: + def __init__(self): + self.output = unittest.mock.MagicMock() + with ( unittest.mock.patch( "keras_hub.src.utils.openvino_utils.ov.Model" ) as mock_model_class, unittest.mock.patch( - "keras_hub.src.utils.openvino_utils.core" - ) as mock_core, + "keras_hub.src.utils.openvino_utils.get_core" + ) as mock_get_core, ): mock_model_class.return_value = unittest.mock.MagicMock() + mock_core = unittest.mock.MagicMock() + mock_get_core.return_value = mock_core mock_core.compile_model.return_value = unittest.mock.MagicMock() struct_params = [_MockParam(), _MockParam()] @@ -222,165 +183,3 @@ def mock_fn(struct_params, stop_token_ids): ov_infer(model, test_input, None, mock_fn) args, kwargs = mock_compile.call_args self.assertEqual(args[3], expected_ov_dtype) - - def test_load_openvino_supported_tools_valid_file(self): - with tempfile.NamedTemporaryFile( - mode="w", delete=False, suffix=".txt" - ) as f: - f.write("keras_hub/src/models/gemma\n") - f.write("keras_hub/src/models/gpt2\n") - f.write("keras_hub/src/layers/modeling\n") - temp_file = f.name - - try: - result = load_openvino_supported_tools(temp_file) - expected = [ - "keras_hub/src/models/gemma", - "keras_hub/src/models/gpt2", - "keras_hub/src/layers/modeling", - ] - self.assertEqual(result, expected) - finally: - os.unlink(temp_file) - - def test_load_openvino_supported_tools_nonexistent_file(self): - result = load_openvino_supported_tools("/nonexistent/file.txt") - self.assertEqual(result, []) - - def test_load_openvino_supported_tools_empty_file(self): - with tempfile.NamedTemporaryFile( - mode="w", delete=False, suffix=".txt" - ) as f: - temp_file = f.name - try: - result = load_openvino_supported_tools(temp_file) - self.assertEqual(result, []) - finally: - os.unlink(temp_file) - - def test_setup_openvino_test_config_openvino_backend(self): - with tempfile.TemporaryDirectory() as temp_dir: - config_file = os.path.join(temp_dir, "openvino_supported_tests.txt") - with open(config_file, "w") as f: - f.write("keras_hub/src/models/gemma\n") - f.write("keras_hub/src/tokenizers\n") - - result = setup_openvino_test_config(temp_dir) - expected = [ - "keras_hub/src/models/gemma", - "keras_hub/src/tokenizers", - ] - self.assertEqual(result, expected) - - def test_contains_training_methods_with_training_code(self): - training_code = """ - import keras - def test_training_method(): - model = keras.Model() - model.fit(x, y) - return model - def test_other_method(): - model.compile(optimizer='adam') - return model - """ - with tempfile.NamedTemporaryFile( - mode="w", delete=False, suffix=".py" - ) as f: - f.write(training_code) - temp_file = f.name - try: - result = _contains_training_methods( - temp_file, "test_training_method" - ) - self.assertTrue(result) - finally: - os.unlink(temp_file) - - def test_contains_training_methods_nonexistent_file(self): - result = _contains_training_methods( - "/nonexistent/file.py", "test_method" - ) - self.assertTrue(result) - - def test_should_auto_skip_training_test_non_python_file(self): - class _SimpleItem: - def __init__(self, fspath): - self.fspath = type("MockPath", (), {"basename": fspath})() - self.name = "test_method" - - item = _SimpleItem("test_file.txt") - result = should_auto_skip_training_test(item) - self.assertFalse(result) - - def test_should_auto_skip_training_test_with_training_methods(self): - training_code = """ - def test_fit_method(): - model.fit(x, y) - return model - """ - with tempfile.NamedTemporaryFile( - mode="w", delete=False, suffix=".py" - ) as f: - f.write(training_code) - temp_file = f.name - try: - item = _MockItem(temp_file, "test_fit_method") - result = should_auto_skip_training_test(item) - self.assertTrue(result) - finally: - os.unlink(temp_file) - - def test_get_openvino_skip_reason_specific_test_skip(self): - class MockItem: - def __init__(self, test_name): - self.name = test_name - self.fspath = type("MockPath", (), {})() - setattr(self.fspath, "__str__", lambda: "test_file.py") - - expected_reasons = { - "test_backbone_basics": "Requires trainable backend", - "test_score_loss": "Non-implemented roll operation", - "test_layer_behaviors": "Requires trainable backend", - } - for test_name, expected_reason in expected_reasons.items(): - item = MockItem(test_name) - result = get_openvino_skip_reason(item, [], True) - self.assertEqual(result, expected_reason) - - def test_get_openvino_skip_reason_training_skip(self): - training_code = """ - def test_training_method(): - model.fit(x, y) - return model - """ - with tempfile.NamedTemporaryFile( - mode="w", delete=False, suffix=".py" - ) as f: - f.write(training_code) - temp_file = f.name - try: - item = _MockItem(temp_file, "test_training_method") - result = get_openvino_skip_reason(item, [], True) - self.assertEqual(result, "Training operations not supported") - finally: - os.unlink(temp_file) - - def test_get_openvino_skip_reason_whitelist_supported(self): - test_path = "/some/path/keras_hub/src/models/gemma/gemma_test.py" - supported_paths = ["keras_hub/src/models/gemma"] - item = _MockItem(test_path, "test_inference") - result = get_openvino_skip_reason(item, supported_paths, False) - self.assertIsNone(result) - - def test_get_openvino_skip_reason_whitelist_not_supported(self): - test_path = "/some/path/keras_hub/src/models/gemma3/gemma3_test.py" - supported_paths = ["keras_hub/src/models/gemma"] - item = _MockItem(test_path, "test_inference") - result = get_openvino_skip_reason(item, supported_paths, False) - self.assertEqual(result, "File/directory not in OpenVINO whitelist") - - def test_get_openvino_skip_reason_no_whitelist(self): - test_path = "/some/path/keras_hub/src/models/gemma/gemma_test.py" - item = _MockItem(test_path, "test_inference") - result = get_openvino_skip_reason(item, [], False) - self.assertIsNone(result) diff --git a/openvino_supported_tests.txt b/openvino_supported_tests.txt deleted file mode 100644 index 2495a19a40..0000000000 --- a/openvino_supported_tests.txt +++ /dev/null @@ -1,21 +0,0 @@ -keras-hub/integration_tests -keras_hub/src/layers/modeling/alibi_bias_test.py -keras_hub/src/layers/modeling/masked_lm_head_test.py -keras_hub/src/layers/modeling/position_embedding_test.py -keras_hub/src/layers/modeling/rotary_embedding_test.py -keras_hub/src/layers/modeling/sine_position_encoding_test.py -keras_hub/src/layers/modeling/token_and_position_embedding_test.py -keras_hub/src/layers/modeling/transformer_layer_utils_test.py -keras_hub/src/layers/preprocessing/audio_converter_test.py -keras_hub/src/layers/preprocessing/masked_lm_mask_generator_test.py -keras_hub/src/layers/preprocessing/multi_segment_packer_test.py -keras_hub/src/layers/preprocessing/random_deletion_test.py -keras_hub/src/layers/preprocessing/random_swap_test.py -keras_hub/src/layers/preprocessing/start_end_packer_test.py -keras_hub/src/models/gemma -keras_hub/src/models/gpt2 -keras_hub/src/models/mistral -keras_hub/src/samplers/serialization_test.py -keras_hub/src/tests/doc_tests/docstring_test.py -keras_hub/src/tokenizers -keras_hub/src/utils \ No newline at end of file From 9e54481e1b65f05ad1d42897621f7fe9dc56b6d3 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sat, 6 Sep 2025 14:49:02 +0300 Subject: [PATCH 17/17] update conftest.py & openvino utils --- conftest.py | 65 ++++++++++----------------- keras_hub/src/utils/openvino_utils.py | 6 +-- 2 files changed, 26 insertions(+), 45 deletions(-) diff --git a/conftest.py b/conftest.py index 5c9e541459..71c2e4ed9c 100644 --- a/conftest.py +++ b/conftest.py @@ -3,6 +3,22 @@ import keras import pytest +# OpenVINO supported test paths +OPENVINO_SUPPORTED_PATHS = [ + "keras-hub/integration_tests", + "keras_hub/src/models/gemma", + "keras_hub/src/models/gpt2", + "keras_hub/src/models/mistral", + "keras_hub/src/tokenizers", +] + +# OpenVINO specific test skips +OPENVINO_SPECIFIC_SKIPPING_TESTS = { + "test_backbone_basics": "bfloat16 dtype not supported", + "test_score_loss": "Non-implemented roll operation", + "test_causal_lm_basics": "Missing ops and requires trainable backend", +} + def pytest_addoption(parser): parser.addoption( @@ -34,15 +50,6 @@ def pytest_addoption(parser): def pytest_configure(config): # Monkey-patch training methods for OpenVINO backend if keras.config.backend() == "openvino": - # Store original methods in case we need to restore them - if not hasattr(keras.Model, "_original_compile"): - keras.Model._original_compile = keras.Model.compile - keras.Model._original_fit = keras.Model.fit - keras.Model._original_train_on_batch = keras.Model.train_on_batch - - keras.Model.compile = lambda *args, **kwargs: pytest.skip( - "Model.compile() not supported on OpenVINO backend" - ) keras.Model.fit = lambda *args, **kwargs: pytest.skip( "Model.fit() not supported on OpenVINO backend" ) @@ -131,48 +138,22 @@ def pytest_collection_modifyitems(config, items): # OpenVINO-specific test skipping if keras.config.backend() == "openvino": test_name = item.name.split("[")[0] - test_path = str(item.fspath) - - # OpenVINO supported test paths - openvino_supported_paths = [ - "keras-hub/integration_tests", - "keras_hub/src/models/gemma", - "keras_hub/src/models/gpt2", - "keras_hub/src/models/mistral", - "keras_hub/src/samplers/serialization_test.py", - "keras_hub/src/tests/doc_tests/docstring_test.py", - "keras_hub/src/tokenizers", - "keras_hub/src/utils", - ] - - # Skip specific problematic test methods - specific_skipping_tests = { - "test_backbone_basics": "Requires trainable backend", - "test_score_loss": "Non-implemented roll operation", - "test_layer_behaviors": "Requires trainable backend", - } - - if test_name in specific_skipping_tests: + + if test_name in OPENVINO_SPECIFIC_SKIPPING_TESTS: item.add_marker( pytest.mark.skipif( True, reason="OpenVINO: " - f"{specific_skipping_tests[test_name]}", + f"{OPENVINO_SPECIFIC_SKIPPING_TESTS[test_name]}", ) ) continue - parts = test_path.replace("\\", "/").split("/") - try: - keras_hub_idx = parts.index("keras_hub") - relative_test_path = "/".join(parts[keras_hub_idx:]) - except ValueError: - relative_test_path = test_path - is_whitelisted = any( - relative_test_path == supported_path - or relative_test_path.startswith(supported_path + "/") - for supported_path in openvino_supported_paths + item.nodeid.startswith(supported_path + "/") + or item.nodeid.startswith(supported_path + "::") + or item.nodeid == supported_path + for supported_path in OPENVINO_SUPPORTED_PATHS ) if not is_whitelisted: diff --git a/keras_hub/src/utils/openvino_utils.py b/keras_hub/src/utils/openvino_utils.py index 0d32579a26..68570e0d15 100644 --- a/keras_hub/src/utils/openvino_utils.py +++ b/keras_hub/src/utils/openvino_utils.py @@ -6,13 +6,13 @@ import openvino as ov import openvino.opset14 as ov_opset from openvino import Core - - _core = None except ImportError: ov = None ov_opset = None Core = None - _core = None + + +_core = None def get_core():