Skip to content

Commit 4327cc8

Browse files
avoid core dumps by jax
1 parent 28f3efc commit 4327cc8

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

keras/src/export/openvino.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from keras.src import backend
24
from keras.src import tree
35
from keras.src.export.export_utils import convert_spec_to_tensor
@@ -108,10 +110,24 @@ def parameterize_inputs(inputs, prefix=""):
108110
inputs = tree.map_structure(make_tf_tensor_spec, input_signature)
109111
decorated_fn = get_concrete_fn(model, inputs, **kwargs)
110112
ov_model = ov.convert_model(decorated_fn)
113+
elif backend.backend() == "torch":
114+
import torch
115+
116+
sample_inputs = tree.map_structure(
117+
lambda x: convert_spec_to_tensor(x, replace_none_number=1),
118+
input_signature,
119+
)
120+
sample_inputs = tuple(sample_inputs)
121+
if hasattr(model, "eval"):
122+
model.eval()
123+
with warnings.catch_warnings():
124+
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
125+
traced = torch.jit.trace(model, sample_inputs)
126+
ov_model = ov.convert_model(traced)
111127
else:
112128
raise NotImplementedError(
113129
"`export_openvino` is only compatible with OpenVINO, "
114-
"TensorFlow and JAX backends."
130+
"TensorFlow, JAX and Torch backends."
115131
)
116132

117133
ov.serialize(ov_model, filepath)
@@ -147,7 +163,6 @@ def _check_jax_kwargs(kwargs):
147163

148164

149165
def get_concrete_fn(model, input_signature, **kwargs):
150-
"""Get the `tf.function` associated with the model."""
151166
if backend.backend() == "jax":
152167
kwargs = _check_jax_kwargs(kwargs)
153168
export_archive = ExportArchive()

keras/src/export/openvino_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,13 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
7373

7474
@pytest.mark.skipif(ov is None, reason="OpenVINO is not installed")
7575
@pytest.mark.skipif(
76-
backend.backend() not in ("tensorflow", "openvino", "jax"),
76+
backend.backend() not in ("tensorflow", "openvino", "jax", "torch"),
7777
reason=(
7878
"`export_openvino` only currently supports"
79-
"the tensorflow and openvino backends."
79+
"the tensorflow, jax, torch and openvino backends."
8080
),
8181
)
82+
@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI")
8283
@pytest.mark.skipif(
8384
testing.tensorflow_uses_gpu(), reason="Leads to core dumps on CI"
8485
)

0 commit comments

Comments
 (0)