Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions jax/_src/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,19 @@ def _7x_create_device_mesh(
}


def _reverse_combinations(iter, mesh_shape: Sequence[int]):
"""Reverses the indices in the given iterator."""
for elem in iter:
indices, _ = zip(*elem)
yield tuple((len(mesh_shape) - i - 1, mesh_shape[-i - 1]) for i in indices)


def _create_device_mesh_for_nd_torus(
physical_mesh: np.ndarray,
mesh_shape: Sequence[int],
*,
allow_split_physical_axes: bool = False,
device_kind: str | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Assigns logical parallelism axes to physical axes of an N-D torus network.

Expand Down Expand Up @@ -266,6 +274,12 @@ def _create_device_mesh_for_nd_torus(
indices_and_axes = itertools.combinations(
enumerate(assignable_physical_mesh), num_axes
)
if device_kind == _TPU_7X:
# For TPU7x, the innermost physical axis (cores on chip) has higher
# bandwidth. So, prioritize it if assignable.
indices_and_axes = _reverse_combinations(
indices_and_axes, assignable_physical_mesh
)
for elem in indices_and_axes:
c_indices, c_axes = zip(*elem)
# TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only
Expand Down Expand Up @@ -798,6 +812,7 @@ def create_device_mesh(
physical_mesh,
new_mesh_shape,
allow_split_physical_axes=allow_split_physical_axes,
device_kind=last_device.device_kind,
)
return device_mesh
else:
Expand Down
25 changes: 25 additions & 0 deletions tests/mesh_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,31 @@ def test_create_contiguous_submeshes_errors(self):
mesh_shape, devices=devices, contiguous_submeshes=True
)

@parameterized.named_parameters(
# <-logical-> <-physical->
('1x1x2', [1, 1, 2], [1, 1, 1], [[[0, 1]]]),
('2x1x4', [2, 1, 4], [2, 2, 1], [[[0, 1, 2, 3]], [[7, 6, 5, 4]]]),
('4x1x2', [4, 1, 2], [2, 2, 1], [[[0, 1]], [[2, 3]],
[[7, 6]], [[5, 4]]]),
('4x4x2', [4, 2, 2], [2, 2, 2], [[[0, 1], [8, 9]], [[2, 3], [10, 11]],
[[4, 5], [12, 13]], [[6, 7], [14, 15]]]),
)
def test_v7x_create_device_mesh(
self, logical_mesh_shape, physical_mesh_shape, expected_device_id_mesh
):
global_devices = mock_tpu_devices(
physical_mesh_shape[0],
physical_mesh_shape[1],
physical_mesh_shape[2],
mesh_utils._TPU_7X,
one_device_per_chip=False,
)
mesh = mesh_utils.create_device_mesh(
logical_mesh_shape, devices=global_devices, contiguous_submeshes=False
)
device_id_mesh = np.vectorize(lambda d: d.id)(mesh)
self.assertAllClose(np.array(expected_device_id_mesh), device_id_mesh)


def int64_array(x) -> np.ndarray:
return np.array(x, dtype=np.int64)
Expand Down
Loading