Skip to content

Commit 6304060

Browse files
Iterate over the physical-axis combinations in reverse order for TPU7x so that the innermost physical axis is mapped preferentially.
This is because the innermost axis (cores on a chip) has higher bandwidth than the other axes. PiperOrigin-RevId: 800032810
1 parent 8263fc4 commit 6304060

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

jax/_src/mesh_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,19 @@ def _7x_create_device_mesh(
201201
}
202202

203203

204+
def _reverse_combinations(iter, mesh_shape: Sequence[int]):
  """Yields each combination with its axis indices mirrored.

  Every element of `iter` is a combination of `(index, size)` pairs, such as
  those produced by `itertools.combinations(enumerate(mesh_shape), k)`. Each
  index `i` is remapped to `len(mesh_shape) - i - 1` and paired with the size
  stored at that mirrored position, so a combination drawn from the leading
  axes becomes one drawn from the trailing axes.

  Args:
    iter: Iterable of combinations; each combination is a sequence of
      `(index, size)` pairs. Only the index of each pair is read.
    mesh_shape: Axis sizes that the indices refer to.

  Yields:
    One tuple of `(mirrored_index, mesh_shape[mirrored_index])` pairs per
    input combination, in input order. An empty combination yields `()`.
  """
  num_axes = len(mesh_shape)
  for elem in iter:
    # Unpack the pairs directly rather than transposing with `zip(*elem)`:
    # the transpose of an empty combination cannot be unpacked into two names.
    yield tuple((num_axes - i - 1, mesh_shape[-i - 1]) for i, _ in elem)
209+
210+
204211
def _create_device_mesh_for_nd_torus(
205212
physical_mesh: np.ndarray,
206213
mesh_shape: Sequence[int],
207214
*,
208215
allow_split_physical_axes: bool = False,
216+
device_kind: str | None = None,
209217
) -> tuple[np.ndarray, np.ndarray]:
210218
"""Assigns logical parallelism axes to physical axes of an N-D torus network.
211219
@@ -266,6 +274,12 @@ def _create_device_mesh_for_nd_torus(
266274
indices_and_axes = itertools.combinations(
267275
enumerate(assignable_physical_mesh), num_axes
268276
)
277+
if device_kind == _TPU_7X:
278+
# For TPU7x, the innermost physical axis (cores on chip) has higher
279+
# bandwidth. So, prioritize it if assignable.
280+
indices_and_axes = _reverse_combinations(
281+
indices_and_axes, assignable_physical_mesh
282+
)
269283
for elem in indices_and_axes:
270284
c_indices, c_axes = zip(*elem)
271285
# TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only
@@ -798,6 +812,7 @@ def create_device_mesh(
798812
physical_mesh,
799813
new_mesh_shape,
800814
allow_split_physical_axes=allow_split_physical_axes,
815+
device_kind=last_device.device_kind,
801816
)
802817
return device_mesh
803818
else:

tests/mesh_utils_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,31 @@ def test_create_contiguous_submeshes_errors(self):
650650
mesh_shape, devices=devices, contiguous_submeshes=True
651651
)
652652

653+
@parameterized.named_parameters(
    # name,  <-logical->  <-physical->  expected device ids
    ('1x1x2', [1, 1, 2], [1, 1, 1], [[[0, 1]]]),
    ('2x1x4', [2, 1, 4], [2, 2, 1], [[[0, 1, 2, 3]], [[7, 6, 5, 4]]]),
    ('4x1x2', [4, 1, 2], [2, 2, 1], [[[0, 1]], [[2, 3]],
                                     [[7, 6]], [[5, 4]]]),
    # Named after the logical mesh shape, like the other cases.
    ('4x2x2', [4, 2, 2], [2, 2, 2], [[[0, 1], [8, 9]], [[2, 3], [10, 11]],
                                     [[4, 5], [12, 13]], [[6, 7], [14, 15]]]),
)
def test_v7x_create_device_mesh(
    self, logical_mesh_shape, physical_mesh_shape, expected_device_id_mesh
):
  """Checks the device-id layout `create_device_mesh` produces for TPU7x.

  Args:
    logical_mesh_shape: Mesh shape requested from `create_device_mesh`.
    physical_mesh_shape: Three sizes forwarded, in order, to
      `mock_tpu_devices` to build the mock devices.
    expected_device_id_mesh: Nested list of the device ids expected at each
      position of the returned mesh.
  """
  # Mock devices report the TPU7x device kind and are created with
  # `one_device_per_chip=False`.
  global_devices = mock_tpu_devices(
      *physical_mesh_shape,
      mesh_utils._TPU_7X,
      one_device_per_chip=False,
  )
  mesh = mesh_utils.create_device_mesh(
      logical_mesh_shape, devices=global_devices, contiguous_submeshes=False
  )
  # Compare ids rather than device objects so the expectation can be written
  # as plain integers.
  device_id_mesh = np.vectorize(lambda d: d.id)(mesh)
  self.assertAllClose(np.array(expected_device_id_mesh), device_id_mesh)
677+
653678

654679
def int64_array(x) -> np.ndarray:
  """Builds a new NumPy array from `x` with dtype `np.int64`."""
  as_int64 = np.array(x, dtype=np.int64)
  return as_int64

0 commit comments

Comments
 (0)