Skip to content

Commit d29bb51

Browse files
Merge pull request #2586 from AI-Hypercomputer:multislice-rl
PiperOrigin-RevId: 828047018
2 parents ff5be4a + 3a137c8 commit d29bb51

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/MaxText/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,8 @@ ici_expert_parallelism: 1
471471
# Enable ZeRO-1 optimizer sharding over data axis
472472
shard_optimizer_over_data: False
473473

474-
# The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation,
474+
# Unless explicitly specified, the number of TPU slices is automatically determined. It should only be set for
475+
# disaggregated reinforcement learning workloads using multiple slices. For ahead-of-time compilation,
475476
# you should set compile_topology_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
476477
num_slices: -1
477478

src/MaxText/max_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ def _retrieve_jax_init_info(raw_keys):
289289

290290
def get_num_slices(raw_keys):
291291
"""Calculate num_slices based on number of devices."""
292+
if raw_keys["num_slices"] != -1:
293+
max_logging.log(f"Using num_slices={raw_keys['num_slices']} per user request.")
294+
return raw_keys["num_slices"]
292295
if raw_keys["hardware"] == "cpu":
293296
max_logging.log(" Setting num_slices=1 for CPU hardware type")
294297
return 1

0 commit comments

Comments
 (0)