Skip to content

Commit 6beec62

Browse files
mwhittakerGoogle-ML-Automation
authored andcommitted
Add a JAX config to disable the preemption service.
PiperOrigin-RevId: 843356947
1 parent f3d83cd commit 6beec62

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

jax/_src/distributed.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@
4444
),
4545
)
4646

47+
_ENABLE_PREEMPTION_SERVICE = config.bool_state(
48+
name='jax_enable_preemption_service',
49+
default=True,
50+
help=(
51+
"Enables the preemption service. See"
52+
" multihost_utils.reached_preemption_sync_point for details."
53+
),
54+
)
55+
4756
class State:
4857
process_id: int = 0
4958
num_processes: int = 1
@@ -188,6 +197,8 @@ def shutdown(self):
188197
self.service = None
189198

190199
def initialize_preemption_sync_manager(self):
200+
if not _ENABLE_PREEMPTION_SERVICE.value:
201+
return
191202
if self.preemption_sync_manager is not None:
192203
raise RuntimeError(
193204
'Preemption sync manager should only be initialized once.')

jax/experimental/multihost_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,10 @@ def should_save(step_id: int) -> bool:
226226
return False
227227
sync_manager = distributed.global_state.preemption_sync_manager
228228
if sync_manager is None:
229-
raise RuntimeError("Preemption sync manager has not been initialized.")
229+
raise RuntimeError(
230+
"Preemption sync manager has not been initialized. Make sure the"
231+
" 'jax_enable_preemption_service' config is enabled."
232+
)
230233
return sync_manager.reached_sync_point(step_id)
231234

232235

0 commit comments

Comments
 (0)