Skip to content

Commit 895fc99

Browse files
yashk2810 authored and Google-ML-Automation committed
Reverts 6a0d8a4
PiperOrigin-RevId: 800597204
1 parent a27eb30 commit 895fc99

File tree

6 files changed: +47 -19 lines changed

6 files changed: +47 -19 lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,12 @@ def _empty_array(prefix, length_spec, aval):
540540
# empty = lax.empty((*prefix, *aval.shape), aval.dtype, out_sharding=sharding)
541541
# return core.pvary(empty, tuple(aval.vma))
542542
empty = core.pvary(lax.empty2(aval.dtype), tuple(aval.vma))
543-
return lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding)
543+
out = lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding)
544+
# TODO(yashkatariya): Maybe make this more general by passing
545+
# aval.memory_space to lax.broadcast and then remove this hack?
546+
if aval.memory_space != core.typeof(out).memory_space:
547+
out = api.device_put(out, aval.memory_space)
548+
return out
544549

545550
eval_jaxpr_p = core.Primitive('eval_jaxpr')
546551
eval_jaxpr_p.multiple_results = True

jax/_src/lax/lax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4967,7 +4967,7 @@ def _convert_element_type_bind_with_trace(trace, args, params):
49674967
_convert_element_type_weak_type_rule,
49684968
_convert_element_type_sharding_rule,
49694969
partial(core.standard_vma_rule, convert_element_type_p.name),
4970-
None))
4970+
None, None))
49714971
ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule)
49724972
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
49734973

@@ -6680,7 +6680,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
66806680
sharding=sharding)
66816681
new_vma = core.standard_vma_rule('broadcast_in_dim', x)
66826682
return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding,
6683-
vma=new_vma)
6683+
vma=new_vma, memory_space=x.memory_space)
66846684
# If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
66856685
# (even if x is a ShapedArray)
66866686
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code

jax/_src/lax/linalg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -775,8 +775,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name,
775775
prim.def_abstract_eval(
776776
partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule,
777777
lax_utils._standard_weak_type_rule, sharding_rule,
778-
partial(core.standard_vma_rule, name),
779-
None))
778+
partial(core.standard_vma_rule, name), None, None))
780779
if supports_batching:
781780
batching.primitive_batchers[prim] = partial(
782781
batching.expand_dims_batcher, prim)

jax/_src/lax/slicing.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2694,6 +2694,12 @@ def _scatter_spec_computation(
26942694
return None
26952695

26962696

2697+
def _scatter_memory_space_rule(
2698+
operand, indices, updates, *, update_jaxpr, update_consts,
2699+
dimension_numbers, indices_are_sorted, unique_indices, mode):
2700+
return operand.memory_space
2701+
2702+
26972703
def _scatter_sharding_rule(
26982704
operand, indices, updates, *, update_jaxpr, update_consts,
26992705
dimension_numbers, indices_are_sorted, unique_indices, mode):
@@ -2905,7 +2911,8 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
29052911
scatter_add_p = standard_primitive(
29062912
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
29072913
weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
2908-
vma_rule=partial(core.standard_vma_rule, 'scatter_add'))
2914+
vma_rule=partial(core.standard_vma_rule, 'scatter_add'),
2915+
memory_space_rule=_scatter_memory_space_rule)
29092916
ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p)
29102917
ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p)
29112918
batching.fancy_primitive_batchers[scatter_add_p] = partial(_scatter_batching_rule, scatter_add_p)
@@ -2914,7 +2921,8 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
29142921
scatter_sub_p = standard_primitive(
29152922
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-sub',
29162923
weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
2917-
vma_rule=partial(core.standard_vma_rule, 'scatter_sub')
2924+
vma_rule=partial(core.standard_vma_rule, 'scatter_sub'),
2925+
memory_space_rule=_scatter_memory_space_rule
29182926
)
29192927
ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p)
29202928
ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p)
@@ -2925,7 +2933,8 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
29252933
scatter_mul_p = standard_primitive(
29262934
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
29272935
weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
2928-
vma_rule=partial(core.standard_vma_rule, 'scatter_mul'))
2936+
vma_rule=partial(core.standard_vma_rule, 'scatter_mul'),
2937+
memory_space_rule=_scatter_memory_space_rule)
29292938

29302939
def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers,
29312940
indices_are_sorted, unique_indices, mode, **kw):
@@ -3056,7 +3065,8 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
30563065
scatter_min_p = standard_primitive(
30573066
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
30583067
weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
3059-
vma_rule=partial(core.standard_vma_rule, 'scatter_min'))
3068+
vma_rule=partial(core.standard_vma_rule, 'scatter_min'),
3069+
memory_space_rule=_scatter_memory_space_rule)
30603070
batching.fancy_primitive_batchers[scatter_min_p] = (
30613071
partial(_scatter_batching_rule, scatter_min_p))
30623072
batching.skippable_batchers[scatter_min_p] = lambda _: ()
@@ -3065,7 +3075,8 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
30653075
scatter_max_p = standard_primitive(
30663076
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
30673077
weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
3068-
vma_rule=partial(core.standard_vma_rule, 'scatter_max'))
3078+
vma_rule=partial(core.standard_vma_rule, 'scatter_max'),
3079+
memory_space_rule=_scatter_memory_space_rule)
30693080
batching.fancy_primitive_batchers[scatter_max_p] = (
30703081
partial(_scatter_batching_rule, scatter_max_p))
30713082
batching.skippable_batchers[scatter_max_p] = lambda _: ()
@@ -3225,7 +3236,8 @@ def _scatter_transpose_rule(t, operand, indices, updates, *,
32253236
scatter_p = standard_primitive(
32263237
_scatter_shape_rule, _scatter_dtype_rule, 'scatter',
32273238
weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
3228-
vma_rule=partial(core.standard_vma_rule, 'scatter'))
3239+
vma_rule=partial(core.standard_vma_rule, 'scatter'),
3240+
memory_space_rule=_scatter_memory_space_rule)
32293241
ad.primitive_jvps[scatter_p] = _scatter_jvp
32303242
ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
32313243
batching.fancy_primitive_batchers[scatter_p] = (

jax/_src/lax/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,14 @@ def _argnum_weak_type(*argnums):
3838

3939
def standard_primitive(shape_rule, dtype_rule, name,
4040
weak_type_rule=None, sharding_rule=None, vma_rule=None,
41-
unreduced_rule=None):
41+
unreduced_rule=None, memory_space_rule=None):
4242
weak_type_rule = weak_type_rule or _standard_weak_type_rule
4343
prim = core.Primitive(name)
4444
prim.def_impl(partial(dispatch.apply_primitive, prim))
4545
prim.def_abstract_eval(
4646
partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
47-
weak_type_rule, sharding_rule, vma_rule, unreduced_rule))
47+
weak_type_rule, sharding_rule, vma_rule, unreduced_rule,
48+
memory_space_rule))
4849
return prim
4950

5051
def _get_array_abstraction_level(a): return a.array_abstraction_level
@@ -140,7 +141,7 @@ def multi_mem_space_rule(prim, num_out, *avals, **kwargs):
140141

141142
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
142143
sharding_rule, vma_rule, unreduced_rule,
143-
*avals, **kwargs):
144+
memory_space_rule, *avals, **kwargs):
144145
for a in avals:
145146
if isinstance(a, state.AbstractRef):
146147
raise ValueError(f'Attempting to pass a Ref {a} to a primitive: '
@@ -158,7 +159,9 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
158159
prim, shape_rule, dtype_rule, sharding_rule, unreduced_rule, False,
159160
*avals, **kwargs)
160161
out_vma = vma_rule(*avals, **kwargs)
161-
out_mem_space = _default_memory_space_rule(prim, *avals, **kwargs)
162+
out_mem_space = (_default_memory_space_rule(prim, *avals, **kwargs)
163+
if memory_space_rule is None else
164+
memory_space_rule(*avals, **kwargs))
162165
out_aval = core.ShapedArray(
163166
out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding,
164167
vma=out_vma, memory_space=out_mem_space)

tests/memories_test.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,19 @@ def test_fn(x_in, y_in):
16421642
self.assertArraysEqual(x_out, x1 * x1)
16431643
self.assertArraysEqual(y_out, y1 + y1)
16441644

1645+
def test_indexing_on_host(self):
1646+
@jax.jit
1647+
@compute_on("device_host")
1648+
def fn2(x):
1649+
x = jax.device_put(x, jax.memory.Space.Host)
1650+
y = jnp.ones((2, 1, 4))
1651+
y = jax.device_put(y, jax.memory.Space.Host)
1652+
z = x.at[:, 1:2, :].set(y)
1653+
return z
1654+
1655+
x_host = jax.device_put(jnp.ones((2,3,4)), jax.memory.Space.Host)
1656+
fn2(x_host) # doesn't crash
1657+
16451658
def test_compute_on_cache_miss(self):
16461659
@jax.jit
16471660
def f(x):
@@ -1909,8 +1922,6 @@ def g(ys, _):
19091922
compiled_text = compiled_f.as_text()
19101923
if compiled_text is not None:
19111924
self.assertIn('S(5)', compiled_text)
1912-
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
1913-
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
19141925

19151926
compiled_stats = compiled_f.memory_analysis()
19161927
if compiled_stats is not None:
@@ -1948,8 +1959,6 @@ def g(ys, _):
19481959
compiled_text = compiled_f.as_text()
19491960
if compiled_text is not None:
19501961
self.assertIn('S(5)', compiled_text)
1951-
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
1952-
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
19531962
self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)")
19541963
self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)")
19551964
self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")

0 commit comments

Comments
 (0)