
Commit 358622f

tushar00jain authored and githubsgi committed
allow disabling ft checkpoints (pytorch#1810)
Summary: Allow disabling the storage of torchft-related checkpoints. With this option, users don't have to rely on any external storage, which reduces the setup time needed to get up and running, and model checkpoints aren't strictly needed once torchft is in use. If checkpoint storage has issues, the flag also works as a killswitch that disables the storage entirely so it doesn't impact training.

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1810).
* pytorch#1856
* pytorch#1811
* __->__ pytorch#1810

Co-authored-by: Tushar Jain <[email protected]>
1 parent 4253e4d commit 358622f

File tree

2 files changed, +34 -1 lines changed


torchtitan/components/checkpoint.py

Lines changed: 12 additions & 1 deletion
@@ -190,8 +190,19 @@ def __init__(
         self.load_only = checkpoint_config.load_only
 
         self.ft_manager = (
-            ft_manager.manager if ft_manager and ft_manager.enabled else None
+            ft_manager.manager
+            if ft_manager
+            and ft_manager.enabled
+            and checkpoint_config.enable_ft_dataloader_checkpoints
+            else None
         )
+
+        if ft_manager and ft_manager.enabled and not self.ft_manager:
+            logger.warn(
+                "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
+                "This means replicas can retrain over the same data multiple times, which can result in overfitting."
+            )
+
         if self.ft_manager:
             optimizers.init_cache_state_dict()
 
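For readers skimming the hunk above, here is a minimal, self-contained sketch of the gating logic it introduces. The `Checkpoint` and `FTManagerStub` dataclasses below are simplified stand-ins assumed for illustration, not torchtitan's real `CheckpointManager` wiring; only the conditional expression and the warning mirror the diff.

```python
import logging
from dataclasses import dataclass
from typing import Any, Optional

logger = logging.getLogger(__name__)


@dataclass
class Checkpoint:
    # Simplified stand-in for the Checkpoint job config (see job_config.py below).
    enable: bool = False
    enable_ft_dataloader_checkpoints: bool = True


@dataclass
class FTManagerStub:
    # Simplified stand-in for the torchft manager wrapper passed to CheckpointManager.
    enabled: bool = False
    manager: Any = None


def resolve_ft_manager(
    checkpoint_config: Checkpoint, ft_manager: Optional[FTManagerStub]
) -> Any:
    """Keep the torchft manager only when fault tolerance is on AND
    torchft dataloader checkpointing has not been disabled."""
    resolved = (
        ft_manager.manager
        if ft_manager
        and ft_manager.enabled
        and checkpoint_config.enable_ft_dataloader_checkpoints
        else None
    )
    if ft_manager and ft_manager.enabled and resolved is None:
        # Killswitch path: fault tolerance stays on, but no dataloader state is
        # stored, so replicas may revisit the same batches after a failure.
        logger.warning(
            "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is "
            "False; replicas can retrain over the same data."
        )
    return resolved


# The killswitch scenario from the commit summary: torchft on, storage off.
cfg = Checkpoint(enable_ft_dataloader_checkpoints=False)
assert resolve_ft_manager(cfg, FTManagerStub(enabled=True, manager=object())) is None
```

Note that the disabled path only logs a warning instead of raising, which is what lets the flag act as a killswitch when checkpoint storage itself is unhealthy.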
torchtitan/config/job_config.py

Lines changed: 22 additions & 0 deletions
@@ -421,6 +421,28 @@ class Checkpoint:
     enable: bool = False
     """Whether to enable checkpoint"""
 
+    enable_ft_dataloader_checkpoints: bool = True
+    """
+    Warning: Disabling this can cause fault tolerant replicas to train
+    over the same data multiple times. Only disable it if training over
+    the same data is acceptable.
+
+    Used to enable checkpointing of the dataloader index for fault tolerant training with torchft.
+
+    Fault tolerant training stores the data loader index in the checkpoints, so that training can resume
+    without going over the same batch twice.
+
+    If enabled, data loader state is checkpointed. Otherwise, replicas
+    will train over the same data multiple times, which can result in
+    overfitting.
+
+    A failed replica will still recover other state, e.g. model
+    parameters, from other replicas.
+
+    Note that if regular checkpointing is enabled, we also checkpoint the
+    data loader state. But when not using fault tolerance, the entire training starts from scratch.
+    """
+
     folder: str = "checkpoint"
     """
     The folder to store the checkpoints.

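Because the new field defaults to `True`, existing fault-tolerant jobs keep checkpointing dataloader state unless they explicitly opt out. Below is a hedged sketch of how the knob might be flipped from a parsed job config; the `[checkpoint]` section name and key-to-field mapping are assumptions based on the dataclass above, not a description of torchtitan's actual config loader.

```python
import tomllib  # stdlib TOML parser, Python 3.11+
from dataclasses import dataclass, fields


@dataclass
class Checkpoint:
    # Simplified stand-in for the dataclass shown in the diff above.
    enable: bool = False
    enable_ft_dataloader_checkpoints: bool = True


# Hypothetical job-config excerpt; key names mirror the dataclass fields.
RAW = """
[checkpoint]
enable = false
enable_ft_dataloader_checkpoints = false  # killswitch for torchft dataloader checkpoints
"""

section = tomllib.loads(RAW)["checkpoint"]
known = {f.name for f in fields(Checkpoint)}
cfg = Checkpoint(**{k: v for k, v in section.items() if k in known})
assert cfg.enable_ft_dataloader_checkpoints is False
```

With both regular checkpointing and torchft dataloader checkpointing disabled, a failed replica still recovers model parameters from its peers, as the docstring above notes; what is given up is the guarantee of not revisiting the same batches.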
0 commit comments
