4 changes: 0 additions & 4 deletions jax/_src/lax/control_flow/loops.py
@@ -540,10 +540,6 @@ def _empty_array(prefix, length_spec, aval):
   # empty = lax.empty((*prefix, *aval.shape), aval.dtype, out_sharding=sharding)
   # return core.pvary(empty, tuple(aval.vma))
   empty = core.pvary(lax.empty2(aval.dtype), tuple(aval.vma))
-  # TODO(yashkatariya): Make this more general by passing aval.memory_space to
-  # lax.broadcast and then remove this hack?
-  if aval.memory_space != core.typeof(empty).memory_space:
-    empty = api.device_put(empty, aval.memory_space)
   return lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding)
 
 eval_jaxpr_p = core.Primitive('eval_jaxpr')
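The deleted branch was a stopgap: `lax.empty2` creates its scalar in the default memory space, so when the target aval lived elsewhere (e.g. host memory) `_empty_array` had to `device_put` it before broadcasting. A minimal sketch of that shape, using public stand-ins (`jnp.zeros`, `jax.typeof`, `jax.device_put`) for the internal calls in the hunk; `jax.memory.Space.Host` is taken from the test file in this PR:

```python
import jax
import jax.numpy as jnp

def empty_array_sketch(prefix, shape, dtype, target_space):
  empty = jnp.zeros((), dtype)  # lands in the default memory space
  if jax.typeof(empty).memory_space != target_space:
    # The removed hack: relocate the scalar before broadcasting, because
    # broadcast did not yet propagate memory spaces on its own.
    empty = jax.device_put(empty, target_space)
  return jnp.broadcast_to(empty, (*prefix, *shape))
```

With memory-space inference now centralized in `lax/utils.py` (see that file's hunk below), the per-call-site relocation is dropped.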
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
@@ -4966,7 +4966,7 @@ def _convert_element_type_bind_with_trace(trace, args, params):
             _convert_element_type_weak_type_rule,
             _convert_element_type_sharding_rule,
             partial(core.standard_vma_rule, convert_element_type_p.name),
-            None, None))
+            None))
 ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule)
 ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
 

@@ -6679,7 +6679,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
         sharding=sharding)
   new_vma = core.standard_vma_rule('broadcast_in_dim', x)
   return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding,
-                          vma=new_vma, memory_space=x.memory_space)
+                          vma=new_vma)
   # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
   # (even if x is a ShapedArray)
   # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
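`broadcast_in_dim`'s output aval no longer carries an explicit `memory_space=x.memory_space`; it is expected to pick the space up from the default rule instead. A sketch of how one might observe this, reusing the `jax.memory.Space.Host` placement from the tests in this PR (whether host-resident inputs and outputs actually execute is backend-dependent):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # jnp.broadcast_to lowers to broadcast_in_dim; its output aval should
  # inherit x's memory space via the default rule.
  return jnp.broadcast_to(x, (2, 3))

x = jax.device_put(jnp.ones((3,)), jax.memory.Space.Host)
print(jax.typeof(f(x)))  # expected to report the host memory space
```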
3 changes: 2 additions & 1 deletion jax/_src/lax/linalg.py
@@ -775,7 +775,8 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name,
   prim.def_abstract_eval(
       partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule,
               lax_utils._standard_weak_type_rule, sharding_rule,
-              partial(core.standard_vma_rule, name), None, None))
+              partial(core.standard_vma_rule, name),
+              None))
   if supports_batching:
     batching.primitive_batchers[prim] = partial(
         batching.expand_dims_batcher, prim)
24 changes: 6 additions & 18 deletions jax/_src/lax/slicing.py
@@ -2694,12 +2694,6 @@ def _scatter_spec_computation(
   return None
 
 
-def _scatter_memory_space_rule(
-    operand, indices, updates, *, update_jaxpr, update_consts,
-    dimension_numbers, indices_are_sorted, unique_indices, mode):
-  return operand.memory_space
-
-
 def _scatter_sharding_rule(
     operand, indices, updates, *, update_jaxpr, update_consts,
     dimension_numbers, indices_are_sorted, unique_indices, mode):

@@ -2911,8 +2905,7 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
 scatter_add_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
     weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
-    vma_rule=partial(core.standard_vma_rule, 'scatter_add'),
-    memory_space_rule=_scatter_memory_space_rule)
+    vma_rule=partial(core.standard_vma_rule, 'scatter_add'))
 ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p)
 ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p)
 batching.fancy_primitive_batchers[scatter_add_p] = partial(_scatter_batching_rule, scatter_add_p)

@@ -2921,8 +2914,7 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
 scatter_sub_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-sub',
     weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
-    vma_rule=partial(core.standard_vma_rule, 'scatter_sub'),
-    memory_space_rule=_scatter_memory_space_rule
+    vma_rule=partial(core.standard_vma_rule, 'scatter_sub')
 )
 ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p)
 ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p)

@@ -2933,8 +2925,7 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
 scatter_mul_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
     weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
-    vma_rule=partial(core.standard_vma_rule, 'scatter_mul'),
-    memory_space_rule=_scatter_memory_space_rule)
+    vma_rule=partial(core.standard_vma_rule, 'scatter_mul'))
 
 def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers,
                          indices_are_sorted, unique_indices, mode, **kw):

@@ -3065,8 +3056,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
 scatter_min_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
     weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
-    vma_rule=partial(core.standard_vma_rule, 'scatter_min'),
-    memory_space_rule=_scatter_memory_space_rule)
+    vma_rule=partial(core.standard_vma_rule, 'scatter_min'))
 batching.fancy_primitive_batchers[scatter_min_p] = (
     partial(_scatter_batching_rule, scatter_min_p))
 batching.skippable_batchers[scatter_min_p] = lambda _: ()

@@ -3075,8 +3065,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
 scatter_max_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
     weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
-    vma_rule=partial(core.standard_vma_rule, 'scatter_max'),
-    memory_space_rule=_scatter_memory_space_rule)
+    vma_rule=partial(core.standard_vma_rule, 'scatter_max'))
 batching.fancy_primitive_batchers[scatter_max_p] = (
     partial(_scatter_batching_rule, scatter_max_p))
 batching.skippable_batchers[scatter_max_p] = lambda _: ()

@@ -3236,8 +3225,7 @@ def _scatter_transpose_rule(t, operand, indices, updates, *,
 scatter_p = standard_primitive(
     _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
     weak_type_rule=_argnum_weak_type(0), sharding_rule=_scatter_sharding_rule,
-    vma_rule=partial(core.standard_vma_rule, 'scatter'),
-    memory_space_rule=_scatter_memory_space_rule)
+    vma_rule=partial(core.standard_vma_rule, 'scatter'))
 ad.primitive_jvps[scatter_p] = _scatter_jvp
 ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
 batching.fancy_primitive_batchers[scatter_p] = (
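All six scatter variants drop the same keyword. The bespoke rule deleted at the top of this file only ever forwarded the operand's memory space:

```python
# The deleted rule, for reference: scatter outputs were pinned to the
# operand's memory space regardless of the other arguments.
def _scatter_memory_space_rule(operand, indices, updates, *, update_jaxpr,
                               update_consts, dimension_numbers,
                               indices_are_sorted, unique_indices, mode):
  return operand.memory_space
```

Assuming the default rule resolves to the operands' common memory space (a sketch of that contract follows the `lax/utils.py` hunk below), it returns the same answer whenever operand, indices, and updates agree, so the special case is redundant.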
11 changes: 4 additions & 7 deletions jax/_src/lax/utils.py
@@ -38,14 +38,13 @@ def _argnum_weak_type(*argnums):
 
 def standard_primitive(shape_rule, dtype_rule, name,
                        weak_type_rule=None, sharding_rule=None, vma_rule=None,
-                       unreduced_rule=None, memory_space_rule=None):
+                       unreduced_rule=None):
   weak_type_rule = weak_type_rule or _standard_weak_type_rule
   prim = core.Primitive(name)
   prim.def_impl(partial(dispatch.apply_primitive, prim))
   prim.def_abstract_eval(
       partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
-              weak_type_rule, sharding_rule, vma_rule, unreduced_rule,
-              memory_space_rule))
+              weak_type_rule, sharding_rule, vma_rule, unreduced_rule))
   return prim
 
 def _get_array_abstraction_level(a): return a.array_abstraction_level

@@ -141,7 +140,7 @@ def multi_mem_space_rule(prim, num_out, *avals, **kwargs):
 
 def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                            sharding_rule, vma_rule, unreduced_rule,
-                           memory_space_rule, *avals, **kwargs):
+                           *avals, **kwargs):
   for a in avals:
     if isinstance(a, state.AbstractRef):
       raise ValueError(f'Attempting to pass a Ref {a} to a primitive: '

@@ -159,9 +158,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
       prim, shape_rule, dtype_rule, sharding_rule, unreduced_rule, False,
       *avals, **kwargs)
   out_vma = vma_rule(*avals, **kwargs)
-  out_mem_space = (_default_memory_space_rule(prim, *avals, **kwargs)
-                   if memory_space_rule is None else
-                   memory_space_rule(*avals, **kwargs))
+  out_mem_space = _default_memory_space_rule(prim, *avals, **kwargs)
   out_aval = core.ShapedArray(
       out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding,
       vma=out_vma, memory_space=out_mem_space)
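After this change every standard primitive funnels through `_default_memory_space_rule`, whose body is not shown in this diff. A hypothetical sketch of its contract, assuming it requires operands to agree on a memory space and forwards their common one:

```python
# Hypothetical sketch only; the real _default_memory_space_rule is not part
# of this diff. Operands must agree on a memory space, and the output
# inherits it (None meaning "unspecified/default").
def default_memory_space_rule_sketch(prim, *avals, **kwargs):
  spaces = {a.memory_space for a in avals if hasattr(a, 'memory_space')}
  if len(spaces) > 1:
    raise ValueError(f'{prim}: operands span memory spaces {spaces}')
  return spaces.pop() if spaces else None
```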
17 changes: 4 additions & 13 deletions tests/memories_test.py
@@ -1644,19 +1644,6 @@ def test_fn(x_in, y_in):
     self.assertArraysEqual(x_out, x1 * x1)
     self.assertArraysEqual(y_out, y1 + y1)
 
-  def test_indexing_on_host(self):
-    @jax.jit
-    @compute_on("device_host")
-    def fn2(x):
-      x = jax.device_put(x, jax.memory.Space.Host)
-      y = jnp.ones((2, 1, 4))
-      y = jax.device_put(y, jax.memory.Space.Host)
-      z = x.at[:, 1:2, :].set(y)
-      return z
-
-    x_host = jax.device_put(jnp.ones((2, 3, 4)), jax.memory.Space.Host)
-    fn2(x_host)  # doesn't crash
-
   def test_compute_on_cache_miss(self):
     @jax.jit
     def f(x):

@@ -1924,6 +1911,8 @@ def g(ys, _):
     compiled_text = compiled_f.as_text()
     if compiled_text is not None:
       self.assertIn('S(5)', compiled_text)
+      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
+      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
 
     compiled_stats = compiled_f.memory_analysis()
     if compiled_stats is not None:

@@ -1961,6 +1950,8 @@ def g(ys, _):
     compiled_text = compiled_f.as_text()
     if compiled_text is not None:
       self.assertIn('S(5)', compiled_text)
+      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
+      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
       self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)")
       self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)")
       self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")
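The new `assertNotRegex` checks pin down the point of these tests: buffers pinned to host memory (which XLA renders as `S(5)` in HLO text) should be sliced and updated in place, not round-tripped through `copy-start`/`copy-done` pairs. A standalone sketch of the same inspection using JAX's AOT API; `f` here is a placeholder function, not the test's:

```python
import re
import jax
import jax.numpy as jnp

def f(x):
  return x * 2

compiled = jax.jit(f).lower(jnp.ones((4,))).compile()
hlo = compiled.as_text()  # may be None on some backends
if hlo is not None:
  print('host buffers:', 'S(5)' in hlo)
  # copy-start/copy-done instructions touching S(5) would indicate the
  # unwanted device<->host copies the tests now rule out.
  print('host copies:', bool(re.search(r'copy-(start|done).*S\(5\)', hlo)))
```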