Skip to content

Commit 858efa0

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add an impl rule to empty2. Fixes #32404
PiperOrigin-RevId: 816219967
1 parent e6bc889 commit 858efa0

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

jax/_src/lax/lax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9090,6 +9090,7 @@ def _empty_batcher(axis_data, vals_in, dims_in, *, shape, dtype, out_sharding):
90909090
def empty2(dtype, *, memory_space):
90919091
return empty2_p.bind(dtype=dtype, memory_space=memory_space)
90929092
empty2_p = core.Primitive('empty2')
9093+
dispatch.simple_impl(empty2_p)
90939094

90949095
def _empty2_abstract_eval(*, dtype, memory_space):
90959096
return core.ShapedArray((), dtype, memory_space=memory_space)

tests/python_callback_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,17 @@ def f(x):
604604
x = np.arange(8, dtype=dtype)
605605
np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype))
606606

607+
def test_pure_callback_sequential_vmap_method_eval_jaxpr(self):
608+
def f(x):
609+
return jax.pure_callback(
610+
lambda x: x, jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32),
611+
x, vmap_method="sequential")
612+
613+
jaxpr = jax.make_jaxpr(lambda: jax.vmap(f)(
614+
jnp.zeros(100, dtype=jnp.float32)))()
615+
with jax.ensure_compile_time_eval():
616+
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) # doesn't crash
617+
607618
@parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn")
608619
def test_subbyte_results(self, dtype: str):
609620
if "2" in dtype and jtu.test_device_matches(["tpu"]):

0 commit comments

Comments
 (0)