Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `arr.view(dtype=None)` now returns the array unchanged, matching NumPy's
semantics. Previously it returned the array with a float dtype.

* Deprecations:
* The parameters `enable_xla` and `native_serialization` for `jax2tf.convert`
are deprecated and will be removed in a future versionof JAX. These were
used for jax2tf with non-native serialization, which has been now removed.

## JAX 0.7.1 (August 20, 2025)

* New features
Expand Down
1 change: 0 additions & 1 deletion jax/experimental/jax2tf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ py_library(
name = "jax2tf_internal",
srcs = [
"call_tf.py",
"impl_no_xla.py",
"jax2tf.py",
],
# TODO: b/255503696: enable pytype
Expand Down
5 changes: 1 addition & 4 deletions jax/experimental/jax2tf/call_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,7 @@ def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> T
# that tf.ensure_shape did this, but it can only take shapes that contain None
# not computed shapes. However, in eager mode we should be able to resolve
# the declared shapes to constants and we get better checking.
if tf.executing_eagerly():
r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
else:
r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
# We do as much checking as we can here, instead of relying on tf.ensure_shape
# because the latter gives different errors in eager vs. compiled mode.
# TODO(b/279454591): This strange error is from TF. Eager function suppose
Expand Down
6 changes: 1 addition & 5 deletions jax/experimental/jax2tf/examples/saved_model_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def convert_and_save_model(
input_signatures: Sequence[tf.TensorSpec],
polymorphic_shapes: str | None = None,
with_gradient: bool = False,
enable_xla: bool = True,
compile_model: bool = True,
saved_model_options: tf.saved_model.SaveOptions | None = None):
"""Convert a JAX function and saves a SavedModel.
Expand Down Expand Up @@ -80,8 +79,6 @@ def convert_and_save_model(
corresponding input shapes.
with_gradient: the value to use for the `with_gradient` parameter for
`jax2tf.convert`.
enable_xla: the value to use for the `enable_xla` parameter for
`jax2tf.convert`.
compile_model: use TensorFlow jit_compiler on the SavedModel. This
is needed if the SavedModel will be used for TensorFlow serving.
polymorphic_shapes: if given then it will be used as the
Expand All @@ -99,8 +96,7 @@ def convert_and_save_model(
tf_fn = jax2tf.convert(
jax_fn,
with_gradient=with_gradient,
polymorphic_shapes=[None, polymorphic_shapes],
enable_xla=enable_xla)
polymorphic_shapes=[None, polymorphic_shapes])

# Create tf.Variables for the parameters. If you want more useful variable
# names, you can use `tree.map_structure_with_path` from the `dm-tree` package
Expand Down
205 changes: 0 additions & 205 deletions jax/experimental/jax2tf/g3doc/no_xla_limitations.md

This file was deleted.

Loading
Loading