@@ -329,13 +329,13 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank):
329
329
self .assertEqual (mock_save .call_count , 0 )
330
330
manager .save (curr_step = 2 )
331
331
self .assertEqual (mock_save .call_count , 0 )
332
- manager .save (curr_step = 2 , force = True )
332
+ manager .save (curr_step = 2 , last_step = True )
333
333
self .assertEqual (mock_save .call_count , 1 )
334
334
manager .save (curr_step = 3 )
335
335
self .assertEqual (mock_save .call_count , 2 )
336
336
manager .save (curr_step = 4 )
337
337
self .assertEqual (mock_save .call_count , 2 )
338
- manager .save (curr_step = 4 , force = True )
338
+ manager .save (curr_step = 4 , last_step = True )
339
339
self .assertEqual (mock_save .call_count , 3 )
340
340
manager .close ()
341
341
@@ -358,7 +358,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
358
358
job_config = self .job_config ,
359
359
ft_manager = self .ft_manager ,
360
360
)
361
- manager1 .save (curr_step = 1 , force = True )
361
+ manager1 .save (curr_step = 1 , last_step = True )
362
362
path1 = os .path .join (self .test_folder , "step-1" )
363
363
self .assertTrue (os .path .isdir (path1 ))
364
364
# Phase 2: initial load from step-1
@@ -383,7 +383,7 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only(
383
383
args1 , kwargs1 = mock_load .call_args
384
384
self .assertEqual (kwargs1 .get ("checkpoint_id" ), path1 )
385
385
# Phase 3: save new step under default folder, then load that
386
- manager2 .save (curr_step = 2 , force = True )
386
+ manager2 .save (curr_step = 2 , last_step = True )
387
387
# Default folder is test_folder, so step-2 under that
388
388
step2_dir = os .path .join (self .test_folder , "step-2" )
389
389
self .assertTrue (os .path .isdir (step2_dir ))
@@ -419,12 +419,12 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group):
419
419
)
420
420
421
421
# First save schedules async
422
- manager .save (curr_step = 10 , force = False )
422
+ manager .save (curr_step = 10 , last_step = False )
423
423
future = manager .async_future
424
424
future .result .assert_not_called ()
425
425
426
426
# Second save should wait
427
- manager .save (curr_step = 20 , force = False )
427
+ manager .save (curr_step = 20 , last_step = False )
428
428
future .result .assert_called_once ()
429
429
430
430
# New future created
@@ -462,12 +462,12 @@ def test_ft_async_save_calls_async_wait(
462
462
463
463
# Initially no future
464
464
self .assertIsNone (manager .async_future )
465
- manager .save (curr_step = 5 , force = False )
465
+ manager .save (curr_step = 5 , last_step = False )
466
466
self .assertIsNotNone (manager .async_future )
467
467
468
468
manager .async_future .result .assert_not_called ()
469
469
prev_future = manager .async_future
470
- manager .save (curr_step = 6 , force = False )
470
+ manager .save (curr_step = 6 , last_step = False )
471
471
prev_future .result .assert_called_once ()
472
472
self .assertIsNotNone (manager .async_future )
473
473
manager .async_future .result .assert_not_called ()
0 commit comments