diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index 715d831e4e9c..6c6f60c08c0a 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -44,6 +44,15 @@
     ),
 )
 
+_ENABLE_PREEMPTION_SERVICE = config.bool_state(
+    name='jax_enable_preemption_service',
+    default=True,
+    help=(
+        "Enables the preemption service. See"
+        " multihost_utils.reached_preemption_sync_point for details."
+    ),
+)
+
 class State:
   process_id: int = 0
   num_processes: int = 1
@@ -188,6 +197,8 @@ def shutdown(self):
       self.service = None
 
   def initialize_preemption_sync_manager(self):
+    if not _ENABLE_PREEMPTION_SERVICE.value:
+      return
     if self.preemption_sync_manager is not None:
       raise RuntimeError(
           'Preemption sync manager should only be initialized once.')
diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py
index ee7e9509ea3d..f3026502abc6 100644
--- a/jax/experimental/multihost_utils.py
+++ b/jax/experimental/multihost_utils.py
@@ -226,7 +226,10 @@ def should_save(step_id: int) -> bool:
     return False
   sync_manager = distributed.global_state.preemption_sync_manager
   if sync_manager is None:
-    raise RuntimeError("Preemption sync manager has not been initialized.")
+    raise RuntimeError(
+        "Preemption sync manager has not been initialized. Make sure the"
+        " 'jax_enable_preemption_service' config is enabled."
+    )
   return sync_manager.reached_sync_point(step_id)
 
 