File tree Expand file tree Collapse file tree 2 files changed +15
-1
lines changed
Expand file tree Collapse file tree 2 files changed +15
-1
lines changed Original file line number Diff line number Diff line change 4444 ),
4545)
4646
47+ _ENABLE_PREEMPTION_SERVICE = config .bool_state (
48+ name = 'jax_enable_preemption_service' ,
49+ default = True ,
50+ help = (
51+ "Enables the preemption service. See"
52+ " multihost_utils.reached_preemption_sync_point for details."
53+ ),
54+ )
55+
4756class State :
4857 process_id : int = 0
4958 num_processes : int = 1
@@ -188,6 +197,8 @@ def shutdown(self):
188197 self .service = None
189198
190199 def initialize_preemption_sync_manager (self ):
200+ if not _ENABLE_PREEMPTION_SERVICE .value :
201+ return
191202 if self .preemption_sync_manager is not None :
192203 raise RuntimeError (
193204 'Preemption sync manager should only be initialized once.' )
Original file line number Diff line number Diff line change @@ -226,7 +226,10 @@ def should_save(step_id: int) -> bool:
226226 return False
227227 sync_manager = distributed .global_state .preemption_sync_manager
228228 if sync_manager is None :
229- raise RuntimeError ("Preemption sync manager has not been initialized." )
229+ raise RuntimeError (
230+ "Preemption sync manager has not been initialized. Make sure the"
231+ " 'jax_enable_preemption_service' config is enabled."
232+ )
230233 return sync_manager .reached_sync_point (step_id )
231234
232235
You can’t perform that action at this time.
0 commit comments