diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py
index e1094a43c19d..8921da01d971 100644
--- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py
+++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py
@@ -1184,7 +1184,55 @@ def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
     dtype = transform.transform_dtype(dtype)
   return shape, dtype
 
-def _device_coords_to_logical_id(device_coords, axis_sizes):
+# TODO(sharadmv): De-dup this w/ the impl in primitives.py.
+def _device_id_dict_to_mesh(device_id_dict, axis_sizes, axis_indices):
+  physical_axis_dict = {}
+  axis_names = axis_sizes.keys()
+  for axis, idx in device_id_dict.items():
+    if isinstance(axis, tuple) and any(a in axis_names for a in axis):
+      if not all(a in axis_names for a in axis):
+        raise NotImplementedError(
+            f"{axis} mixes JAX mesh and Pallas mesh grid axes"
+        )
+      axes_dimensions = [axis_sizes[name] for name in axis]
+      for axis_index, axis_name in enumerate(axis):
+        axis_size = axis_sizes[axis_name]
+        inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :])
+        minor_divisor = inner_mesh_size
+
+        # Fast path for power of 2s
+        if inner_mesh_size & (inner_mesh_size - 1) == 0:
+          shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1
+          partial_device_idx = idx >> shift_len
+        else:
+          partial_device_idx = idx // minor_divisor
+
+        if axis_size & (axis_size - 1) == 0:
+          device_idx = partial_device_idx & (axis_size - 1)
+        else:
+          device_idx = partial_device_idx % axis_size
+        physical_axis_dict[axis_name] = device_idx
+    else:
+      physical_axis_dict[axis] = idx
+  device_id = []
+  for axis in axis_names:
+    if axis in physical_axis_dict:
+      device_id.append(physical_axis_dict[axis])
+    else:
+      device_id.append(axis_indices[axis])
+  non_mesh_axes = {
+      k: v
+      for k, v in physical_axis_dict.items()
+      if k not in axis_names
+  }
+  return tuple(device_id), non_mesh_axes
+
+def _device_coords_to_logical_id(device_coords, axis_sizes, axis_indices):
+  if isinstance(device_coords, dict):
+    device_coords, non_mesh_axes = _device_id_dict_to_mesh(
+        device_coords, axis_sizes, axis_indices)
+    if non_mesh_axes:
+      raise NotImplementedError(non_mesh_axes)
   if not isinstance(device_coords, tuple):
     device_coords = (device_coords,)
   assert len(device_coords) == len(axis_sizes)
@@ -1194,11 +1242,12 @@ def _device_coords_to_logical_id(device_coords, axis_sizes):
     ret += device_coords[i] * math.prod(sizes[i+1:])
   return ret
 
-def _device_id_to_logical(device_id, device_id_type, axis_sizes):
+def _device_id_to_logical(device_id, device_id_type, axis_sizes,
+                          axis_indices):
   if device_id is None:
     return None
   if device_id_type == primitives.DeviceIdType.MESH:
-    return _device_coords_to_logical_id(device_id, axis_sizes)
+    return _device_coords_to_logical_id(device_id, axis_sizes, axis_indices)
   elif device_id_type == primitives.DeviceIdType.LOGICAL:
     return device_id
   else:
@@ -1515,7 +1564,8 @@ def f(*args, jaxpr):
             target_device_id,
         ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals())
         target_device_id = _device_id_to_logical(
-            target_device_id, eqn.params['device_id_type'], axis_sizes)
+            target_device_id, eqn.params['device_id_type'], axis_sizes,
+            axis_indices)
         (orig_src_ref, _, orig_dst_ref, *_
          ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
         src_memory_space = getattr(orig_src_ref.aval, 'memory_space', None)
@@ -1580,7 +1630,8 @@ def f(*args, jaxpr):
         sem, sem_transforms, inc, target_device_id, core_index = (
             jax.tree.unflatten(eqn.params['args_tree'], deferred_invals()))
         target_device_id = _device_id_to_logical(
-            target_device_id, eqn.params['device_id_type'], axis_sizes)
+            target_device_id, eqn.params['device_id_type'], axis_sizes,
+            axis_indices)
         callback.io_callback(
             semaphore_signal,
             (),
@@ -1984,7 +2035,7 @@ def interpret_pallas_call(
       jnp.multiply, axis_sizes.values(), jnp.int32(1))
   axis_indices = {k: lax.axis_index(k) for k in axis_sizes.keys()}
   device_id = _device_coords_to_logical_id(
-      tuple(axis_indices.values()), axis_sizes)
+      tuple(axis_indices.values()), axis_sizes, axis_indices)
   callback.io_callback(
       functools.partial(
           _initialize_shared_memory, interpret_params=interpret_params
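
For reference, here is a standalone sketch (plain Python, not the interpreter's actual helpers) of what the new `_device_id_dict_to_mesh` / `_device_coords_to_logical_id` path computes: a dict keyed by a tuple of mesh axes carries a flat index that is unflattened row-major into per-axis coordinates, axes not mentioned fall back to the current device's own indices, and the resulting coordinate tuple is folded back into a flat logical device id. The axis names, sizes, and indices below are made up for illustration; the diff additionally uses shift/mask fast paths when the sizes are powers of two, which is arithmetically equivalent to the `//` and `%` used here.

```python
import math

# Hypothetical 2 x 4 mesh and current-device coordinates, for illustration only.
axis_sizes = {"x": 2, "y": 4}
axis_indices = {"x": 1, "y": 3}

def device_id_dict_to_coords(device_id_dict, axis_sizes, axis_indices):
  # Unflatten a flat index given over a tuple of axes, e.g. {("x", "y"): 6},
  # into per-axis coordinates; axes not mentioned keep the caller's own index.
  resolved = {}
  names = list(axis_sizes)
  for axis, idx in device_id_dict.items():
    if isinstance(axis, tuple):
      dims = [axis_sizes[a] for a in axis]
      for i, name in enumerate(axis):
        inner = math.prod(dims[i + 1:])  # row-major: later axes are minor
        resolved[name] = (idx // inner) % axis_sizes[name]
    else:
      resolved[axis] = idx
  return tuple(resolved.get(a, axis_indices[a]) for a in names)

def coords_to_logical_id(coords, axis_sizes):
  # Fold per-axis coordinates back into a flat, row-major logical device id.
  sizes = list(axis_sizes.values())
  return sum(c * math.prod(sizes[i + 1:]) for i, c in enumerate(coords))

coords = device_id_dict_to_coords({("x", "y"): 6}, axis_sizes, axis_indices)
print(coords)                                    # (1, 2)
print(coords_to_logical_id(coords, axis_sizes))  # 6 -- round-trips the flat index
```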