63 changes: 57 additions & 6 deletions jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py
@@ -1184,7 +1184,55 @@ def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
dtype = transform.transform_dtype(dtype)
return shape, dtype

-def _device_coords_to_logical_id(device_coords, axis_sizes):
+# TODO(sharadmv): De-dup this w/ the impl in primitives.py.
+def _device_id_dict_to_mesh(device_id_dict, axis_sizes, axis_indices):
+  physical_axis_dict = {}
+  axis_names = axis_sizes.keys()
+  for axis, idx in device_id_dict.items():
+    if isinstance(axis, tuple) and any(a in axis_names for a in axis):
+      if not all(a in axis_names for a in axis):
+        raise NotImplementedError(
+            f"{axis} mixes JAX mesh and Pallas mesh grid axes"
+        )
+      # `idx` is a flat, row-major index into the subgrid formed by the
+      # named axes; peel off one coordinate per axis.
+      axes_dimensions = [axis_sizes[name] for name in axis]
+      for axis_index, axis_name in enumerate(axis):
+        axis_size = axis_sizes[axis_name]
+        inner_mesh_size = math.prod(axes_dimensions[axis_index + 1:])
+
+        # Fast path for powers of two: a right shift in place of the
+        # floor division, a bit mask in place of the modulo.
+        if (inner_mesh_size & (inner_mesh_size - 1)) == 0:
+          partial_device_idx = idx >> (inner_mesh_size.bit_length() - 1)
+        else:
+          partial_device_idx = idx // inner_mesh_size
+
+        if (axis_size & (axis_size - 1)) == 0:
+          device_idx = partial_device_idx & (axis_size - 1)
+        else:
+          device_idx = partial_device_idx % axis_size
+        physical_axis_dict[axis_name] = device_idx
+    else:
+      physical_axis_dict[axis] = idx
+  device_id = []
+  for axis in axis_names:
+    if axis in physical_axis_dict:
+      device_id.append(physical_axis_dict[axis])
+    else:
+      device_id.append(axis_indices[axis])
+  non_mesh_axes = {
+      k: v for k, v in physical_axis_dict.items() if k not in axis_names
+  }
+  return tuple(device_id), non_mesh_axes

+def _device_coords_to_logical_id(device_coords, axis_sizes, axis_indices):
+  if isinstance(device_coords, dict):
+    device_coords, non_mesh_axes = _device_id_dict_to_mesh(
+        device_coords, axis_sizes, axis_indices)
+    if non_mesh_axes:
+      raise NotImplementedError(non_mesh_axes)
+  if not isinstance(device_coords, tuple):
+    device_coords = (device_coords,)
assert len(device_coords) == len(axis_sizes)
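
The power-of-two fast path in `_device_id_dict_to_mesh` relies on two standard identities for a power-of-two divisor n: i // n == i >> log2(n), and i % n == i & (n - 1). A quick standalone check of both, with hypothetical values:

n = 8                            # hypothetical power-of-two axis size
assert (n & (n - 1)) == 0        # the power-of-two test used above
for i in range(100):
  assert (i // n) == (i >> (n.bit_length() - 1))
  assert (i % n) == (i & (n - 1))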
@@ -1194,11 +1242,12 @@ def _device_coords_to_logical_id(device_coords, axis_sizes):
ret += device_coords[i] * math.prod(sizes[i+1:])
return ret

-def _device_id_to_logical(device_id, device_id_type, axis_sizes):
+def _device_id_to_logical(device_id, device_id_type, axis_sizes,
+                          axis_indices):
if device_id is None:
return None
if device_id_type == primitives.DeviceIdType.MESH:
-    return _device_coords_to_logical_id(device_id, axis_sizes)
+    return _device_coords_to_logical_id(device_id, axis_sizes, axis_indices)
elif device_id_type == primitives.DeviceIdType.LOGICAL:
return device_id
else:
@@ -1515,7 +1564,8 @@ def f(*args, jaxpr):
target_device_id,
) = jax.tree.unflatten(eqn.params['tree'], deferred_invals())
target_device_id = _device_id_to_logical(
-            target_device_id, eqn.params['device_id_type'], axis_sizes)
+            target_device_id, eqn.params['device_id_type'], axis_sizes,
+            axis_indices)
(orig_src_ref, _, orig_dst_ref, *_
) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
src_memory_space = getattr(orig_src_ref.aval, 'memory_space', None)
@@ -1580,7 +1630,8 @@ def f(*args, jaxpr):
sem, sem_transforms, inc, target_device_id, core_index = (
jax.tree.unflatten(eqn.params['args_tree'], deferred_invals()))
target_device_id = _device_id_to_logical(
-            target_device_id, eqn.params['device_id_type'], axis_sizes)
+            target_device_id, eqn.params['device_id_type'], axis_sizes,
+            axis_indices)
callback.io_callback(
semaphore_signal,
(),
@@ -1984,7 +2035,7 @@ def interpret_pallas_call(
jnp.multiply, axis_sizes.values(), jnp.int32(1))
axis_indices = {k: lax.axis_index(k) for k in axis_sizes.keys()}
device_id = _device_coords_to_logical_id(
-      tuple(axis_indices.values()), axis_sizes)
+      tuple(axis_indices.values()), axis_sizes, axis_indices)
callback.io_callback(
functools.partial(
_initialize_shared_memory, interpret_params=interpret_params
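
For reference, a worked example of the two helpers this change touches, on a hypothetical 2x4 mesh with concrete integers standing in for the `lax.axis_index` values:

axis_sizes = {"x": 2, "y": 4}     # hypothetical 2x4 device mesh
axis_indices = {"x": 0, "y": 0}   # coordinates of the current device

# A device_id dict may pack several mesh axes into one flat, row-major
# index; _device_id_dict_to_mesh splits it back into per-axis coordinates.
coords, non_mesh = _device_id_dict_to_mesh(
    {("x", "y"): 6}, axis_sizes, axis_indices)
assert coords == (1, 2) and non_mesh == {}   # 6 == 1*4 + 2

# Axes left unmentioned default to the current device's own coordinate.
coords, _ = _device_id_dict_to_mesh({"x": 1}, axis_sizes, axis_indices)
assert coords == (1, 0)

# _device_coords_to_logical_id is the inverse, row-major linearization.
assert _device_coords_to_logical_id((1, 2), axis_sizes, axis_indices) == 6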