From a2c54cd5f81c1f667d0d6ca2c753fdc24bd6a152 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 27 Aug 2025 16:40:39 -0700 Subject: [PATCH] Reverts 402a78661747fc1fde0bc37200996b26eea1c9da PiperOrigin-RevId: 800219883 --- jax/_src/lax/control_flow/loops.py | 4 ---- jax/_src/lax/lax.py | 4 ++-- jax/_src/lax/linalg.py | 3 ++- jax/_src/lax/slicing.py | 24 ++++++------------------ jax/_src/lax/utils.py | 11 ++++------- tests/memories_test.py | 17 ++++------------- 6 files changed, 18 insertions(+), 45 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 6def17df06ed..3286adbb688c 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -540,10 +540,6 @@ def _empty_array(prefix, length_spec, aval): # empty = lax.empty((*prefix, *aval.shape), aval.dtype, out_sharding=sharding) # return core.pvary(empty, tuple(aval.vma)) empty = core.pvary(lax.empty2(aval.dtype), tuple(aval.vma)) - # TODO(yashkatariya): Make this more general by passing aval.memory_space to - # lax.broadcast and then remove this hack? - if aval.memory_space != core.typeof(empty).memory_space: - empty = api.device_put(empty, aval.memory_space) return lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding) eval_jaxpr_p = core.Primitive('eval_jaxpr') diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f7a006976177..808f1cb750f5 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4966,7 +4966,7 @@ def _convert_element_type_bind_with_trace(trace, args, params): _convert_element_type_weak_type_rule, _convert_element_type_sharding_rule, partial(core.standard_vma_rule, convert_element_type_p.name), - None, None)) + None)) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule @@ -6679,7 +6679,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, sharding=sharding) new_vma = core.standard_vma_rule('broadcast_in_dim', x) return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, - vma=new_vma, memory_space=x.memory_space) + vma=new_vma) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 62eedaacd8cc..54e5a1bfeac4 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -775,7 +775,8 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, sharding_rule, - partial(core.standard_vma_rule, name), None, None)) + partial(core.standard_vma_rule, name), + None)) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 3f1ba26b374e..4a8a54b2e368 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -2694,12 +2694,6 @@ def _scatter_spec_computation( return None -def _scatter_memory_space_rule( - operand, indices, updates, *, update_jaxpr, update_consts, - dimension_numbers, indices_are_sorted, unique_indices, mode): - return operand.memory_space - - def _scatter_sharding_rule( operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): @@ -2911,8 +2905,7 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *, scatter_add_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add', weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, - vma_rule=partial(core.standard_vma_rule, 'scatter_add'), - memory_space_rule=_scatter_memory_space_rule) + vma_rule=partial(core.standard_vma_rule, 'scatter_add')) ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p) ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p) batching.fancy_primitive_batchers[scatter_add_p] = partial(_scatter_batching_rule, scatter_add_p) @@ -2921,8 +2914,7 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *, scatter_sub_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-sub', weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, - vma_rule=partial(core.standard_vma_rule, 'scatter_sub'), - memory_space_rule=_scatter_memory_space_rule + vma_rule=partial(core.standard_vma_rule, 'scatter_sub') ) ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p) ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p) @@ -2933,8 +2925,7 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *, scatter_mul_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul', weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, - vma_rule=partial(core.standard_vma_rule, 'scatter_mul'), - memory_space_rule=_scatter_memory_space_rule) + vma_rule=partial(core.standard_vma_rule, 'scatter_mul')) def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, indices_are_sorted, unique_indices, mode, **kw): @@ -3065,8 +3056,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_min_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min', weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, - vma_rule=partial(core.standard_vma_rule, 'scatter_min'), - memory_space_rule=_scatter_memory_space_rule) + vma_rule=partial(core.standard_vma_rule, 'scatter_min')) batching.fancy_primitive_batchers[scatter_min_p] = ( partial(_scatter_batching_rule, scatter_min_p)) batching.skippable_batchers[scatter_min_p] = lambda _: () @@ -3075,8 +3065,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_max_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max', weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, - vma_rule=partial(core.standard_vma_rule, 'scatter_max'), - memory_space_rule=_scatter_memory_space_rule) + vma_rule=partial(core.standard_vma_rule, 'scatter_max')) batching.fancy_primitive_batchers[scatter_max_p] = ( partial(_scatter_batching_rule, scatter_max_p)) batching.skippable_batchers[scatter_max_p] = lambda _: () @@ -3236,8 +3225,7 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, scatter_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter', weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule, - vma_rule=partial(core.standard_vma_rule, 'scatter'), - memory_space_rule=_scatter_memory_space_rule) + vma_rule=partial(core.standard_vma_rule, 'scatter')) ad.primitive_jvps[scatter_p] = _scatter_jvp ad.primitive_transposes[scatter_p] = _scatter_transpose_rule batching.fancy_primitive_batchers[scatter_p] = ( diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 2b61d2a4e986..5471f1b5b571 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -38,14 +38,13 @@ def _argnum_weak_type(*argnums): def standard_primitive(shape_rule, dtype_rule, name, weak_type_rule=None, sharding_rule=None, vma_rule=None, - unreduced_rule=None, memory_space_rule=None): + unreduced_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule, sharding_rule, vma_rule, unreduced_rule, - memory_space_rule)) + weak_type_rule, sharding_rule, vma_rule, unreduced_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level @@ -141,7 +140,7 @@ def multi_mem_space_rule(prim, num_out, *avals, **kwargs): def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule, unreduced_rule, - memory_space_rule, *avals, **kwargs): + *avals, **kwargs): for a in avals: if isinstance(a, state.AbstractRef): raise ValueError(f'Attempting to pass a Ref {a} to a primitive: ' @@ -159,9 +158,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, prim, shape_rule, dtype_rule, sharding_rule, unreduced_rule, False, *avals, **kwargs) out_vma = vma_rule(*avals, **kwargs) - out_mem_space = (_default_memory_space_rule(prim, *avals, **kwargs) - if memory_space_rule is None else - memory_space_rule(*avals, **kwargs)) + out_mem_space = _default_memory_space_rule(prim, *avals, **kwargs) out_aval = core.ShapedArray( out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, vma=out_vma, memory_space=out_mem_space) diff --git a/tests/memories_test.py b/tests/memories_test.py index 1213d52a7114..671f237ba702 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1644,19 +1644,6 @@ def test_fn(x_in, y_in): self.assertArraysEqual(x_out, x1 * x1) self.assertArraysEqual(y_out, y1 + y1) - def test_indexing_on_host(self): - @jax.jit - @compute_on("device_host") - def fn2(x): - x = jax.device_put(x, jax.memory.Space.Host) - y = jnp.ones((2, 1, 4)) - y = jax.device_put(y, jax.memory.Space.Host) - z = x.at[:, 1:2, :].set(y) - return z - - x_host = jax.device_put(jnp.ones((2,3,4)), jax.memory.Space.Host) - fn2(x_host) # doesn't crash - def test_compute_on_cache_miss(self): @jax.jit def f(x): @@ -1924,6 +1911,8 @@ def g(ys, _): compiled_text = compiled_f.as_text() if compiled_text is not None: self.assertIn('S(5)', compiled_text) + self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") + self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: @@ -1961,6 +1950,8 @@ def g(ys, _): compiled_text = compiled_f.as_text() if compiled_text is not None: self.assertIn('S(5)', compiled_text) + self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") + self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)") self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")