
Commit 171a883

Take job config out of checkpoint manager (#1433)
This PR takes job_config out of the CheckpointManager class.

Why? JobConfig is a monolith: it has knowledge of every part of a titan training job. As a result, it is hard to use CheckpointManager in a standalone fashion. In practice the job config is mostly used for its checkpoint config, plus two other usages as far as I can tell:

1) Getting the replica_id for the FTManager.
2) Taking the dump_folder from the job field and joining it with the checkpoint folder.

For (1) we can get the replica_id directly from the FTManager without accessing the JobConfig field. For (2) we can pass `job_config.job.dump_folder` explicitly as a base folder and join it with `checkpoint_config.folder`. Personally I would try to consolidate `job.dump_folder` and `checkpoint.folder` under Checkpoint (though I understand there are cases where only the former is needed), but I am not sure whether that is preferable from titan's point of view.
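As a rough sketch of those two replacements (not torchtitan code: the `Checkpoint` and `FTManager` classes below are simplified stand-ins for illustration only):

```python
import os
from dataclasses import dataclass


@dataclass
class Checkpoint:
    # Simplified stand-in for torchtitan.config_manager.Checkpoint.
    folder: str = "checkpoint"


@dataclass
class FTManager:
    # Simplified stand-in: per (1), the replica_id is read off the FT manager
    # itself instead of job_config.fault_tolerance.replica_id.
    replica_id: int = 0


def resolve_checkpoint_folder(base_folder: str, checkpoint_config: Checkpoint) -> str:
    # Per (2): the caller passes job_config.job.dump_folder as base_folder,
    # and the manager joins it with the checkpoint-specific folder.
    return os.path.join(base_folder, checkpoint_config.folder)


ft_manager = FTManager(replica_id=3)
print(ft_manager.replica_id)                                  # 3, no JobConfig needed
print(resolve_checkpoint_folder("./outputs", Checkpoint()))   # ./outputs/checkpoint
```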
1 parent 177b295 commit 171a883

File tree

tests/unit_tests/test_checkpoint.py
torchtitan/components/checkpoint.py
torchtitan/train.py

3 files changed, +62 -38 lines

tests/unit_tests/test_checkpoint.py

Lines changed: 34 additions & 18 deletions

@@ -175,7 +175,8 @@ def test_save_load_restores_state(self, mock_load, mock_save, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -207,7 +208,8 @@ def test_save_and_purge_keeps_last_k_checkpoints(
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -247,7 +249,8 @@ def test_nonzero_rank_does_not_purge_or_save(self, mock_load, mock_save, mock_ra
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         manager.save(curr_step=1)

@@ -269,7 +272,8 @@ def test_load_returns_false_when_no_checkpoint_folder(self):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         self.assertFalse(manager.load(step=-1))

@@ -292,7 +296,8 @@ def test_load_finds_latest_and_calls_dcp_load(self, mock_load, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         res = manager.load(step=-1)

@@ -321,7 +326,8 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         manager.save(curr_step=1)

@@ -354,7 +360,8 @@ def test_last_save_model_only_and_initial_load_model_only(
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         manager1.save(curr_step=1, last_step=True)

@@ -373,7 +380,8 @@ def test_last_save_model_only_and_initial_load_model_only(
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )
         r1 = manager2.load(step=1)

@@ -404,7 +412,8 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group):
         """
         # Configure async mode
         job_config = DummyJobConfig(job=self.job_config.job)
-        job_config.checkpoint.async_mode = "async"
+        checkpoint_config = job_config.checkpoint
+        checkpoint_config.async_mode = "async"
         ft_manager = DummyFTManager()
         states = {"trainer": torch.tensor([0])}
         manager = CheckpointManager(
@@ -413,8 +422,9 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=states,
-            job_config=job_config,
-            ft_manager=ft_manager,
+            checkpoint_config=checkpoint_config,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
         )

         # First save schedules async

@@ -445,7 +455,8 @@ def test_ft_async_save_calls_async_wait(
         Test that with FT enabled, AsyncMode.ASYNC via FT triggers correct waits.
         """
         job_config = DummyJobConfig(job=self.job_config.job)
-        job_config.checkpoint.async_mode = "async"
+        checkpoint_config = job_config.checkpoint
+        checkpoint_config.async_mode = "async"
         ft_manager = mock.Mock()
         ft_manager.manager.return_value = mock.Mock()
         ft_manager.manager.participating_rank = mock.Mock(return_value=0)
@@ -456,8 +467,9 @@ def test_ft_async_save_calls_async_wait(
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=job_config,
-            ft_manager=ft_manager,
+            checkpoint_config=checkpoint_config,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
         )

         # Initially no future

@@ -491,7 +503,8 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -516,7 +529,8 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -561,7 +575,8 @@ def __init__(self):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

@@ -610,7 +625,8 @@ def fake_load(state_dict: dict, checkpoint_id=None):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states=self.states,
-            job_config=self.job_config,
+            checkpoint_config=self.job_config.checkpoint,
+            base_folder=self.job_config.job.dump_folder,
             ft_manager=self.ft_manager,
         )

torchtitan/components/checkpoint.py

Lines changed: 26 additions & 19 deletions

@@ -36,7 +36,7 @@
 from torchtitan.components.ft import FTManager
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
-from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config_manager import Checkpoint, TORCH_DTYPE_MAP
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import GarbageCollection

@@ -174,10 +174,13 @@ class CheckpointManager:
         lr_schedulers (LRSchedulersContainer): The lr schedulers used to optimize the model.
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
-        job_config (JobConfig): The job config used to configure the checkpointing.
+        checkpoint_config (Checkpoint): The config used to configure the checkpointing.
+        base_folder (str): The base folder to save the checkpoint. Will be concatenated
+            with checkpoint_config.folder
         sd_adapter (Optional[type[StateDictAdapter]]): The adapter used to convert model state
             dicts between native format and other formats.
         ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
+
     """

     def __init__(
@@ -187,13 +190,13 @@ def __init__(
         optimizers: OptimizersContainer,
         lr_schedulers: LRSchedulersContainer,
         states: dict[str, Any],
-        job_config: JobConfig,
+        checkpoint_config: Checkpoint,
+        base_folder: str,
         sd_adapter: type[StateDictAdapter] | None = None,
         ft_manager: FTManager | None = None,
     ) -> None:
-        ckpt_config = job_config.checkpoint
-        self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.last_save_in_hf = ckpt_config.last_save_in_hf
+        self.enable_checkpoint = checkpoint_config.enable_checkpoint
+        self.last_save_in_hf = checkpoint_config.last_save_in_hf
         if self.last_save_in_hf:
             assert (
                 sd_adapter is not None

@@ -224,9 +227,9 @@ def load_state_dict(state_dict):
                     self.states[k].load_state_dict(v)

             self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
-            self.ft_replica_id = job_config.fault_tolerance.replica_id
+            self.ft_replica_id = ft_manager.replica_id

-        async_mode = ckpt_config.async_mode.lower()
+        async_mode = checkpoint_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
         ) or self.ft_manager

@@ -251,27 +254,29 @@ def load_state_dict(state_dict):
         self.cpu_offload_state_dict = None
         self.stager = None

-        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
+        self.folder = os.path.join(base_folder, checkpoint_config.folder)

         # Checkpoint policy related fields.
-        self.initial_load_path = ckpt_config.initial_load_path
-        self.initial_load_model_only = ckpt_config.initial_load_model_only
-        self.last_save_model_only = ckpt_config.last_save_model_only
-        self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
-        self.exclude_from_loading = ckpt_config.exclude_from_loading
-        self.interval = ckpt_config.interval
-        self.enable_first_step_checkpoint = ckpt_config.enable_first_step_checkpoint
+        self.initial_load_path = checkpoint_config.initial_load_path
+        self.initial_load_model_only = checkpoint_config.initial_load_model_only
+        self.last_save_model_only = checkpoint_config.last_save_model_only
+        self.export_dtype = TORCH_DTYPE_MAP[checkpoint_config.export_dtype]
+        self.exclude_from_loading = checkpoint_config.exclude_from_loading
+        self.interval = checkpoint_config.interval
+        self.enable_first_step_checkpoint = (
+            checkpoint_config.enable_first_step_checkpoint
+        )

         # Async checkpoint related fields.
-        async_mode = ckpt_config.async_mode.lower()
+        async_mode = checkpoint_config.async_mode.lower()
         if (
             async_mode == AsyncMode.ASYNC
             or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
             or self.ft_manager
         ):
             self.pg = dist.new_group(backend="gloo")

-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.keep_latest_k = checkpoint_config.keep_latest_k
         if self.keep_latest_k > 0:
             if self.keep_latest_k == 1:
                 raise ValueError(

@@ -296,7 +301,9 @@ def load_state_dict(state_dict):
         elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM
         else:
-            raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
+            raise ValueError(
+                f"Unkown checkpoint async_mode {checkpoint_config.async_mode}"
+            )

         logger.info(
             f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"

torchtitan/train.py

Lines changed: 2 additions & 1 deletion

@@ -294,7 +294,8 @@ def __init__(self, job_config: JobConfig):
             optimizers=self.optimizers,
             lr_schedulers=self.lr_schedulers,
             states={"train_state": self},
-            job_config=job_config,
+            checkpoint_config=job_config.checkpoint,
+            base_folder=job_config.job.dump_folder,
             sd_adapter=self.train_spec.state_dict_adapter,
             ft_manager=self.ft_manager,
         )
