Skip to content

Commit d3b1b39

Browse files
gneculaGoogle-ML-Automation
authored andcommitted
[jax2tf] Remove most of the code for jax2tf with non-native serialization
We keep only the `enable_xla` and `native_serialization` parameters to `jax2tf.convert`. They are now deprecated and will be removed in a future version of JAX. PiperOrigin-RevId: 800466991
1 parent 9495be7 commit d3b1b39

13 files changed

+43
-7604
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2929
`False`, but note this is a temporary config that will be removed in a
3030
future release.
3131

32+
* Deprecations:
33+
* The parameters `enable_xla` and `native_serialization` for `jax2tf.convert`
34+
are deprecated and will be removed in a future versionof JAX. These were
35+
used for jax2tf with non-native serialization, which has been now removed.
36+
3237
## JAX 0.7.1 (August 20, 2025)
3338

3439
* New features

jax/experimental/jax2tf/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ py_library(
3838
name = "jax2tf_internal",
3939
srcs = [
4040
"call_tf.py",
41-
"impl_no_xla.py",
4241
"jax2tf.py",
4342
],
4443
# TODO: b/255503696: enable pytype

jax/experimental/jax2tf/call_tf.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,7 @@ def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> T
303303
# that tf.ensure_shape did this, but it can only take shapes that contain None
304304
# not computed shapes. However, in eager mode we should be able to resolve
305305
# the declared shapes to constants and we get better checking.
306-
if tf.executing_eagerly():
307-
r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
308-
else:
309-
r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
306+
r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
310307
# We do as much checking as we can here, instead of relying on tf.ensure_shape
311308
# because the latter gives different errors in eager vs. compiled mode.
312309
# TODO(b/279454591): This strange error is from TF. Eager function suppose
@@ -646,16 +643,6 @@ def _register_call_lowering(platform):
646643
for platform in ("cpu", "cuda", "tpu"):
647644
_register_call_lowering(platform)
648645

649-
# Support the call_tf under jax2tf.convert in eager mode
650-
def _jax2tf_call_tf(*args: TfVal,
651-
callable_flat_tf: Callable,
652-
**_) -> TfVal:
653-
with jax2tf_internal.inside_call_tf():
654-
res_tf_flat = callable_flat_tf(*args)
655-
return res_tf_flat
656-
657-
jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf
658-
659646

660647
def emit_tf_embedded_graph_custom_call(
661648
ctx: mlir.LoweringRuleContext,

jax/experimental/jax2tf/examples/saved_model_lib.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def convert_and_save_model(
4141
input_signatures: Sequence[tf.TensorSpec],
4242
polymorphic_shapes: str | None = None,
4343
with_gradient: bool = False,
44-
enable_xla: bool = True,
4544
compile_model: bool = True,
4645
saved_model_options: tf.saved_model.SaveOptions | None = None):
4746
"""Convert a JAX function and saves a SavedModel.
@@ -80,8 +79,6 @@ def convert_and_save_model(
8079
corresponding input shapes.
8180
with_gradient: the value to use for the `with_gradient` parameter for
8281
`jax2tf.convert`.
83-
enable_xla: the value to use for the `enable_xla` parameter for
84-
`jax2tf.convert`.
8582
compile_model: use TensorFlow jit_compiler on the SavedModel. This
8683
is needed if the SavedModel will be used for TensorFlow serving.
8784
polymorphic_shapes: if given then it will be used as the
@@ -99,8 +96,7 @@ def convert_and_save_model(
9996
tf_fn = jax2tf.convert(
10097
jax_fn,
10198
with_gradient=with_gradient,
102-
polymorphic_shapes=[None, polymorphic_shapes],
103-
enable_xla=enable_xla)
99+
polymorphic_shapes=[None, polymorphic_shapes])
104100

105101
# Create tf.Variables for the parameters. If you want more useful variable
106102
# names, you can use `tree.map_structure_with_path` from the `dm-tree` package

jax/experimental/jax2tf/g3doc/no_xla_limitations.md

Lines changed: 0 additions & 205 deletions
This file was deleted.

0 commit comments

Comments
 (0)