|
| 1 | +import warnings |
| 2 | + |
1 | 3 | from keras.src import backend
|
2 | 4 | from keras.src import tree
|
3 | 5 | from keras.src.export.export_utils import convert_spec_to_tensor
|
@@ -108,10 +110,24 @@ def parameterize_inputs(inputs, prefix=""):
|
108 | 110 | inputs = tree.map_structure(make_tf_tensor_spec, input_signature)
|
109 | 111 | decorated_fn = get_concrete_fn(model, inputs, **kwargs)
|
110 | 112 | ov_model = ov.convert_model(decorated_fn)
|
| 113 | + elif backend.backend() == "torch": |
| 114 | + import torch |
| 115 | + |
| 116 | + sample_inputs = tree.map_structure( |
| 117 | + lambda x: convert_spec_to_tensor(x, replace_none_number=1), |
| 118 | + input_signature, |
| 119 | + ) |
| 120 | + sample_inputs = tuple(sample_inputs) |
| 121 | + if hasattr(model, "eval"): |
| 122 | + model.eval() |
| 123 | + with warnings.catch_warnings(): |
| 124 | + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) |
| 125 | + traced = torch.jit.trace(model, sample_inputs) |
| 126 | + ov_model = ov.convert_model(traced) |
111 | 127 | else:
|
112 | 128 | raise NotImplementedError(
|
113 | 129 | "`export_openvino` is only compatible with OpenVINO, "
|
114 |
| - "TensorFlow and JAX backends." |
| 130 | + "TensorFlow, JAX and Torch backends." |
115 | 131 | )
|
116 | 132 |
|
117 | 133 | ov.serialize(ov_model, filepath)
|
@@ -147,7 +163,6 @@ def _check_jax_kwargs(kwargs):
|
147 | 163 |
|
148 | 164 |
|
149 | 165 | def get_concrete_fn(model, input_signature, **kwargs):
|
150 |
| - """Get the `tf.function` associated with the model.""" |
151 | 166 | if backend.backend() == "jax":
|
152 | 167 | kwargs = _check_jax_kwargs(kwargs)
|
153 | 168 | export_archive = ExportArchive()
|
|
0 commit comments