@@ -2694,6 +2694,12 @@ def _scatter_spec_computation(
26942694 return None
26952695
26962696
2697+ def _scatter_memory_space_rule (
2698+ operand , indices , updates , * , update_jaxpr , update_consts ,
2699+ dimension_numbers , indices_are_sorted , unique_indices , mode ):
2700+ return operand .memory_space
2701+
2702+
26972703def _scatter_sharding_rule (
26982704 operand , indices , updates , * , update_jaxpr , update_consts ,
26992705 dimension_numbers , indices_are_sorted , unique_indices , mode ):
@@ -2905,7 +2911,8 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
29052911scatter_add_p = standard_primitive (
29062912 _scatter_shape_rule , _scatter_dtype_rule , 'scatter-add' ,
29072913 weak_type_rule = _argnum_weak_type (0 ), sharding_rule = _scatter_sharding_rule ,
2908- vma_rule = partial (core .standard_vma_rule , 'scatter_add' ))
2914+ vma_rule = partial (core .standard_vma_rule , 'scatter_add' ),
2915+ memory_space_rule = _scatter_memory_space_rule )
29092916ad .primitive_jvps [scatter_add_p ] = partial (_scatter_addsub_jvp , scatter_add_p )
29102917ad .primitive_transposes [scatter_add_p ] = partial (_scatter_addsub_transpose_rule , scatter_add_p )
29112918batching .fancy_primitive_batchers [scatter_add_p ] = partial (_scatter_batching_rule , scatter_add_p )
@@ -2914,7 +2921,8 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
29142921scatter_sub_p = standard_primitive (
29152922 _scatter_shape_rule , _scatter_dtype_rule , 'scatter-sub' ,
29162923 weak_type_rule = _argnum_weak_type (0 ), sharding_rule = _scatter_sharding_rule ,
2917- vma_rule = partial (core .standard_vma_rule , 'scatter_sub' )
2924+ vma_rule = partial (core .standard_vma_rule , 'scatter_sub' ),
2925+ memory_space_rule = _scatter_memory_space_rule
29182926)
29192927ad .primitive_jvps [scatter_sub_p ] = partial (_scatter_addsub_jvp , scatter_sub_p )
29202928ad .primitive_transposes [scatter_sub_p ] = partial (_scatter_addsub_transpose_rule , scatter_sub_p )
@@ -2925,7 +2933,8 @@ def _scatter_batching_rule(scatter_op, axis_data, batched_args, batch_dims, *,
29252933scatter_mul_p = standard_primitive (
29262934 _scatter_shape_rule , _scatter_dtype_rule , 'scatter-mul' ,
29272935 weak_type_rule = _argnum_weak_type (0 ), sharding_rule = _scatter_sharding_rule ,
2928- vma_rule = partial (core .standard_vma_rule , 'scatter_mul' ))
2936+ vma_rule = partial (core .standard_vma_rule , 'scatter_mul' ),
2937+ memory_space_rule = _scatter_memory_space_rule )
29292938
29302939def _scatter_mul_jvp_rhs (g , x , i , y , * , dimension_numbers ,
29312940 indices_are_sorted , unique_indices , mode , ** kw ):
@@ -3056,7 +3065,8 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
30563065scatter_min_p = standard_primitive (
30573066 _scatter_shape_rule , _scatter_dtype_rule , 'scatter-min' ,
30583067 weak_type_rule = _argnum_weak_type (0 ), sharding_rule = _scatter_sharding_rule ,
3059- vma_rule = partial (core .standard_vma_rule , 'scatter_min' ))
3068+ vma_rule = partial (core .standard_vma_rule , 'scatter_min' ),
3069+ memory_space_rule = _scatter_memory_space_rule )
30603070batching .fancy_primitive_batchers [scatter_min_p ] = (
30613071 partial (_scatter_batching_rule , scatter_min_p ))
30623072batching .skippable_batchers [scatter_min_p ] = lambda _ : ()
@@ -3065,7 +3075,8 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
30653075scatter_max_p = standard_primitive (
30663076 _scatter_shape_rule , _scatter_dtype_rule , 'scatter-max' ,
30673077 weak_type_rule = _argnum_weak_type (0 ), sharding_rule = _scatter_sharding_rule ,
3068- vma_rule = partial (core .standard_vma_rule , 'scatter_max' ))
3078+ vma_rule = partial (core .standard_vma_rule , 'scatter_max' ),
3079+ memory_space_rule = _scatter_memory_space_rule )
30693080batching .fancy_primitive_batchers [scatter_max_p ] = (
30703081 partial (_scatter_batching_rule , scatter_max_p ))
30713082batching .skippable_batchers [scatter_max_p ] = lambda _ : ()
@@ -3225,7 +3236,8 @@ def _scatter_transpose_rule(t, operand, indices, updates, *,
32253236scatter_p = standard_primitive (
32263237 _scatter_shape_rule , _scatter_dtype_rule , 'scatter' ,
32273238 weak_type_rule = _argnum_weak_type (0 ), sharding_rule = _scatter_sharding_rule ,
3228- vma_rule = partial (core .standard_vma_rule , 'scatter' ))
3239+ vma_rule = partial (core .standard_vma_rule , 'scatter' ),
3240+ memory_space_rule = _scatter_memory_space_rule )
32293241ad .primitive_jvps [scatter_p ] = _scatter_jvp
32303242ad .primitive_transposes [scatter_p ] = _scatter_transpose_rule
32313243batching .fancy_primitive_batchers [scatter_p ] = (
0 commit comments