@@ -175,7 +175,8 @@ def test_save_load_restores_state(self, mock_load, mock_save, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -207,7 +208,8 @@ def test_save_and_purge_keeps_last_k_checkpoints(
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -247,7 +249,8 @@ def test_nonzero_rank_does_not_purge_or_save(self, mock_load, mock_save, mock_ra
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         manager.save(curr_step=1)
@@ -269,7 +272,8 @@ def test_load_returns_false_when_no_checkpoint_folder(self):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         self.assertFalse(manager.load(step=-1))
@@ -292,7 +296,8 @@ def test_load_finds_latest_and_calls_dcp_load(self, mock_load, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         res = manager.load(step=-1)
@@ -321,7 +326,8 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         manager.save(curr_step=1)
@@ -354,7 +360,8 @@ def test_last_save_model_only_and_initial_load_model_only(
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         manager1.save(curr_step=1, last_step=True)
@@ -373,7 +380,8 @@ def test_last_save_model_only_and_initial_load_model_only(
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         r1 = manager2.load(step=1)
@@ -404,7 +412,8 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group):
         """
         # Configure async mode
         job_config = DummyJobConfig(job=self.job_config.job)
-        job_config.checkpoint.async_mode = "async"
+        checkpoint_config = job_config.checkpoint
+        checkpoint_config.async_mode = "async"
         ft_manager = DummyFTManager()
         states = {"trainer": torch.tensor([0])}
         manager = CheckpointManager(
@@ -413,8 +422,9 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=states,
-            job_config=job_config,
-            ft_manager=ft_manager,
+            checkpoint_config=checkpoint_config,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
         )

         # First save schedules async
@@ -445,7 +455,8 @@ def test_ft_async_save_calls_async_wait(
         Test that with FT enabled, AsyncMode.ASYNC via FT triggers correct waits.
         """
         job_config = DummyJobConfig(job=self.job_config.job)
-        job_config.checkpoint.async_mode = "async"
+        checkpoint_config = job_config.checkpoint
+        checkpoint_config.async_mode = "async"
         ft_manager = mock.Mock()
         ft_manager.manager.return_value = mock.Mock()
         ft_manager.manager.participating_rank = mock.Mock(return_value=0)
@@ -456,8 +467,9 @@ def test_ft_async_save_calls_async_wait(
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=job_config,
-            ft_manager=ft_manager,
+            checkpoint_config=checkpoint_config,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
         )

         # Initially no future
@@ -491,7 +503,8 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -516,7 +529,8 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -561,7 +575,8 @@ def __init__(self):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -610,7 +625,8 @@ def fake_load(state_dict: dict, checkpoint_id=None):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
