JAX v0.7.0
New features:
- Added `jax.P`, which is an alias for `jax.sharding.PartitionSpec`.
- Added `jax.tree.reduce_associative`.
Breaking changes:
- JAX is migrating from GSPMD to Shardy by default. See the
  migration guide for more information.
- JAX autodiff is switching to direct linearization by default (instead of
  implementing linearization via JVP and partial eval).
  See the migration guide for more information.
- `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
- `jax.jit` now requires `fun` to be passed by position, and additional
  arguments to be passed by keyword. Doing otherwise will result in an error
  starting in v0.7.x. This raised a `DeprecationWarning` in v0.6.x.
- The minimum Python version is now 3.11. 3.11 will remain the minimum
  supported version until July 2026.
- Layout API renames:
  - `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been
    renamed to `Format`, `.format`, `.input_formats` and `.output_formats`.
  - `DeviceLocalLayout` and `.device_local_layout` have been renamed to
    `Layout` and `.layout`.
- The `jax.experimental.shard` module has been deleted and all of its APIs have
  been moved to the `jax.sharding` endpoint. Use `jax.sharding.reshard`,
  `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of their
  experimental endpoints.
- `lax.infeed` and `lax.outfeed` were removed after being deprecated in
  JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods were
  also removed from the `Device` objects.
- The `jax.extend.core.primitives.pjit_p` primitive has been renamed to
  `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`.
  This affects the string representations of jaxprs. The same primitive is no
  longer exported from the `jax.experimental.pjit` module.
- The (undocumented) function `jax.extend.backend.add_clear_backends_callback`
  has been removed. Users should use `jax.extend.backend.register_backend_cache`
  instead.
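The switch to direct linearization changes how JAX builds linearized functions internally; `jax.linearize` and `jax.jvp` should still compute the same derivatives. A quick sanity check when migrating, sketched with an arbitrary example function `f`, is to compare the two:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.float32(1.5)
t = jnp.ones_like(x)  # tangent with the same shape and dtype as x

# jax.linearize returns the primal output and a linear function of the tangent.
y, f_lin = jax.linearize(f, x)
# jax.jvp evaluates the same derivative for one specific tangent.
y_jvp, t_jvp = jax.jvp(f, (x,), (t,))

print(jnp.allclose(y, y_jvp), jnp.allclose(f_lin(t), t_jvp))  # True True
```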
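The `jax.jit` calling convention in practice: the function goes first by position, and every other option is a keyword. The functions `scale` and `power` are examples; the rejected forms are left as comments so the sketch still runs.

```python
from functools import partial

import jax
import jax.numpy as jnp

def scale(x, factor):
    return x * factor

# fun is passed by position; every other jax.jit argument is passed by keyword.
scale_jit = jax.jit(scale, static_argnames="factor")

# The same rule applies to the common decorator pattern.
@partial(jax.jit, static_argnums=1)
def power(x, n):
    return x ** n

# Rejected from v0.7.x:
#   jax.jit(fun=scale)    # fun passed by keyword
#   jax.jit(scale, None)  # a further argument passed by position

print(scale_jit(jnp.arange(3.0), factor=2.0))  # [0. 2. 4.]
print(power(jnp.float32(3.0), 2))              # 9.0
```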
Deprecations:
- {obj}`jax.dlpack.SUPPORTED_DTYPES` is deprecated; please use the new
  `jax.dlpack.is_supported_dtype` function.
- `jax.scipy.special.sph_harm` has been deprecated following a similar
  deprecation in SciPy; use `jax.scipy.special.sph_harm_y` instead.
- From {mod}`jax.interpreters.xla`, the previously deprecated symbols
  `abstractify` and `pytype_aval_mappings` have been removed.
- `jax.interpreters.xla.canonicalize_dtype` is deprecated. For
  canonicalizing dtypes, prefer `jax.dtypes.canonicalize_dtype`.
  For checking whether an object is a valid JAX input, prefer
  `jax.core.valid_jaxtype`.
- From {mod}`jax.core`, the previously deprecated symbols `AxisName`,
  `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`,
  `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been
  removed.
- From {mod}`jax.lib.xla_client`, the previously deprecated symbols
  `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version`
  have been removed.
- `jax.extend.ffi` was removed after being deprecated in v0.5.0.
  Use {mod}`jax.ffi` instead.
- `jax.lib.xla_bridge.get_compile_options` is deprecated, and replaced by
  `jax.extend.backend.get_compile_options`.
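The two suggested replacements for `jax.interpreters.xla.canonicalize_dtype`, sketched side by side. Which dtype `float64` canonicalizes to depends on whether 64-bit mode (`jax_enable_x64`) is enabled.

```python
import jax.numpy as jnp
import numpy as np
from jax import core, dtypes

# Replacement for jax.interpreters.xla.canonicalize_dtype.
canonical = dtypes.canonicalize_dtype(np.float64)
print(canonical)  # float32 unless 64-bit mode (jax_enable_x64) is on

# Replacement for "is this a valid JAX input?" checks.
print(core.valid_jaxtype(jnp.ones(3)))       # True
print(core.valid_jaxtype("not an array"))    # False
```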
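Code that caught `jax.core.ConcretizationTypeError` can catch the public `jax.errors.ConcretizationTypeError` instead. In this sketch, `positive_part` is an example function that uses Python control flow on a traced value, which raises a subclass of that error under `jax.jit`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def positive_part(x):
    # A Python `if` on a traced value cannot be resolved at trace time.
    if x > 0:
        return x
    return jnp.zeros_like(x)

try:
    positive_part(jnp.array(1.0))
    caught = False
except jax.errors.ConcretizationTypeError:  # was jax.core.ConcretizationTypeError
    caught = True

print(caught)  # True
```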