File tree Expand file tree Collapse file tree 2 files changed +5
-1
lines changed Expand file tree Collapse file tree 2 files changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -471,7 +471,8 @@ ici_expert_parallelism: 1
471471# Enable ZeRO-1 optimizer sharding over data axis
472472shard_optimizer_over_data : False
473473
474- # The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation,
474+ # Unless explicitly specified, the number of TPU slices is automatically determined. It should only be set for
475+ # disaggregated reinforcement learning workloads using multiple slices. For ahead of time compilation,
475476# you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
476477num_slices : -1
477478
Original file line number Diff line number Diff line change @@ -289,6 +289,9 @@ def _retrieve_jax_init_info(raw_keys):
289289
290290def get_num_slices (raw_keys ):
291291 """Calculate num_slices based on number of devices."""
292+ if raw_keys ["num_slices" ] != - 1 :
293+ max_logging .log (f"Using num_slices={ raw_keys ['num_slices' ]} per user request." )
294+ return raw_keys ["num_slices" ]
292295 if raw_keys ["hardware" ] == "cpu" :
293296 max_logging .log (" Setting num_slices=1 for CPU hardware type" )
294297 return 1
You can’t perform that action at this time.
0 commit comments