From 0f7316746cb2dc7751abc3489166823fa9c2744d Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 05:18:52 +0200 Subject: [PATCH 01/37] add saving of checkpoint if an exception is raised --- .../pytorch/callbacks/model_checkpoint.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 6b7b2831a2e04..02beaebfa6e34 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -97,6 +97,7 @@ class ModelCheckpoint(Checkpoint): collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid collisions. + save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``True``. mode: one of {min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. @@ -224,6 +225,7 @@ def __init__( verbose: bool = False, save_last: Optional[Union[bool, Literal["link"]]] = None, save_top_k: int = 1, + save_on_exception: bool = True, save_weights_only: bool = False, mode: str = "min", auto_insert_metric_name: bool = True, @@ -238,6 +240,7 @@ def __init__( self.verbose = verbose self.save_last = save_last self.save_top_k = save_top_k + self.save_on_exception = save_on_exception self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name self._save_on_train_epoch_end = save_on_train_epoch_end @@ -338,6 +341,17 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) + @override + def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: Exception) -> None: + if self.save_on_exception and not self._should_skip_saving_checkpoint(trainer): + monitor_candidates = self._monitor_candidates(trainer) + filepath = self.format_checkpoint_name(metrics=monitor_candidates) + print(type(exception)) + self._save_checkpoint(trainer, filepath) + self._save_last_checkpoint(trainer, monitor_candidates) + rank_zero_info(f"An exception was raised saved checkpoint to {filepath}") + + @override def state_dict(self) -> dict[str, Any]: return { From 136e59a9f1242d7aefb2301bc96cb7e45fb36fc8 Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 06:20:11 +0200 Subject: [PATCH 02/37] import callback to checkpoint test file --- tests/tests_pytorch/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 7b17498865889..bf0e8cb521b78 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -35,7 +35,7 @@ import lightning.pytorch as pl from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.pytorch import Trainer, seed_everything -from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset from lightning.pytorch.loggers import CSVLogger, 
TensorBoardLogger
from lightning.pytorch.utilities.exceptions import MisconfigurationException

From e0dae5350ffc978ba77f552d9bd4ccba9d2f19a2 Mon Sep 17 00:00:00 2001
From: vsey
Date: Thu, 19 Jun 2025 06:20:36 +0200
Subject: [PATCH 03/37] add test for exception in training callbacks

---
 .../checkpointing/test_model_checkpoint.py | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index bf0e8cb521b78..57c370dc2da26 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -764,6 +764,55 @@ def test_ckpt_every_n_train_steps(tmp_path):
     assert set(os.listdir(tmp_path)) == set(expected)


+def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
+    """Test that the checkpoint is saved when an exception is raised in a callback on different events."""
+    class TroublemakerOnTrainBatchStart(Callback):
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+            if batch_idx == 1:
+                raise RuntimeError("Trouble!")
+
+    class TroublemakerOnTrainBatchEnd(Callback):
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+            if batch_idx == 1:
+                raise RuntimeError("Trouble!")
+
+    class TroublemakerOnTrainEpochStart(Callback):
+        def on_train_epoch_start(self, trainer, pl_module):
+            if trainer.current_epoch == 1:
+                raise RuntimeError("Trouble!")
+
+    class TroublemakerOnTrainEpochEnd(Callback):
+        def on_train_epoch_end(self, trainer, pl_module):
+            if trainer.current_epoch == 1:
+                raise RuntimeError("Trouble!")
+
+    epoch_length = 64
+    model = BoringModel()
+    # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
+    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
+
+    troublemakers = [
+        TroublemakerOnTrainBatchStart(),
+        TroublemakerOnTrainBatchEnd(),
+        TroublemakerOnTrainEpochStart(),
+        TroublemakerOnTrainEpochEnd()
+    ]
+
+    expected_ckpts = ["step=1.ckpt",
+                      'step=2.ckpt',
+                      f'step={epoch_length}.ckpt',
+                      f'step={2*epoch_length}.ckpt',
+                      ]
+
+    for troublemaker in troublemakers:
+        trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False)
+
+        with pytest.raises(RuntimeError, match="Trouble!"):
+            trainer.fit(model)
+
+    assert set(os.listdir(tmp_path)) == set(expected_ckpts)
+
 @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")
 def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None:
     """Tests that the checkpoints are saved at the specified time interval."""

From 2113acc2d225769c60b8726ddaf1c2e124f3882e Mon Sep 17 00:00:00 2001
From: vsey
Date: Thu, 19 Jun 2025 06:58:56 +0200
Subject: [PATCH 04/37] split test for save checkpoint on exception for exceptions in training part of callbacks into individual tests for better overview

---
 .../checkpointing/test_model_checkpoint.py | 114 +++++++++++++++---
 1 file changed, 94 insertions(+), 20 deletions(-)

diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 57c370dc2da26..f024553c032f4 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -764,6 +764,7 @@ def test_ckpt_every_n_train_steps(tmp_path):
    assert 
set(os.listdir(tmp_path)) == set(expected) -def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on different events.""" +def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start.""" class TroublemakerOnTrainBatchStart(Callback): def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if batch_idx == 1: raise RuntimeError("Trouble!") + model = BoringModel() + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / "step=1.ckpt") + + +def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end.""" class TroublemakerOnTrainBatchEnd(Callback): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if batch_idx == 1: raise RuntimeError("Trouble!") + + model = BoringModel() + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / "step=2.ckpt") + + +def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start.""" class TroublemakerOnTrainEpochStart(Callback): def on_train_epoch_start(self, trainer, pl_module): if trainer.current_epoch == 1: raise RuntimeError("Trouble!") + + model = BoringModel() + epoch_length = 64 + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") + +def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end.""" class TroublemakerOnTrainEpochEnd(Callback): def on_train_epoch_end(self, trainer, pl_module): if trainer.current_epoch == 1: raise RuntimeError("Trouble!") - - epoch_length = 64 model = BoringModel() - # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints + epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt") + + +# def 
test_model_checkpoint_save_on_exception_in_train_callback(tmp_path): +# """Test that the checkpoint is saved when an exception is raised in a callback on different events.""" +# class TroublemakerOnTrainBatchStart(Callback): +# def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): +# if batch_idx == 1: +# raise RuntimeError("Trouble!") + +# class TroublemakerOnTrainBatchEnd(Callback): +# def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): +# if batch_idx == 1: +# raise RuntimeError("Trouble!") + +# class TroublemakerOnTrainEpochStart(Callback): +# def on_train_epoch_start(self, trainer, pl_module): +# if trainer.current_epoch == 1: +# raise RuntimeError("Trouble!") + +# class TroublemakerOnTrainEpochEnd(Callback): +# def on_train_epoch_end(self, trainer, pl_module): +# if trainer.current_epoch == 1: +# raise RuntimeError("Trouble!") + + +# epoch_length = 64 +# model = BoringModel() +# # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints +# checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + +# troublemakers = [ +# TroublemakerOnTrainBatchStart(), +# TroublemakerOnTrainBatchEnd(), +# TroublemakerOnTrainEpochStart(), +# TroublemakerOnTrainEpochEnd() +# ] + +# expected_ckpts = ["step=1.ckpt", +# 'step=2.ckpt', +# f'step={epoch_length}.ckpt', +# f'step={2*epoch_length}.ckpt', +# ] + +# for troublemaker in troublemakers: +# trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False) + +# with pytest.raises(RuntimeError, match="Trouble!"): +# trainer.fit(model) - troublemakers = [ - TroublemakerOnTrainBatchStart(), - TroublemakerOnTrainBatchEnd(), - TroublemakerOnTrainEpochStart(), - TroublemakerOnTrainEpochEnd() - ] +# assert set(os.listdir(tmp_path)) == set(expected_ckpts) - expected_ckpts = ["step=1.ckpt", - 'step=2.ckpt', - f'step={epoch_length}.ckpt', - f'step={2*epoch_length}.ckpt', - ] - for troublemaker in troublemakers: - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False) +# # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints +# checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) +# troublemakers = [ +# # TroublemakerOnValidationBatchStart(), +# TroublemakerOnValidationBatchEnd(), +# expected_ckpts = [f"step={2*epoch_length}.ckpt", assert set(os.listdir(tmp_path)) == set(expected_ckpts) From 7d750e6917ba47158754ef3a7b8d59ee8795d35f Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 07:08:10 +0200 Subject: [PATCH 05/37] add extra condition for checking if we should save on exception --- src/lightning/pytorch/callbacks/model_checkpoint.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 02beaebfa6e34..307ba7c5d559c 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -343,10 +343,9 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul @override def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: Exception) -> None: - if 
self.save_on_exception and not self._should_skip_saving_checkpoint(trainer): + if self._should_save_on_exception(trainer): monitor_candidates = self._monitor_candidates(trainer) filepath = self.format_checkpoint_name(metrics=monitor_candidates) - print(type(exception)) self._save_checkpoint(trainer, filepath) self._save_last_checkpoint(trainer, monitor_candidates) rank_zero_info(f"An exception was raised saved checkpoint to {filepath}") @@ -439,6 +438,14 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: or trainer.sanity_checking # don't save anything during sanity check or self._last_global_step_saved == trainer.global_step # already saved at the last step ) + + def _should_save_on_exception(self, trainer: "pl.Trainer") -> bool: + return ( + self.save_on_exception + and not bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run + and not trainer.sanity_checking # don't save anything during sanity check + and not self._last_global_step_saved == trainer.global_step # already saved at the last step) + ) def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool: if self._save_on_train_epoch_end is not None: From 34e598ac181fd862b234176af1a276b6e6b98bfe Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 07:09:26 +0200 Subject: [PATCH 06/37] add for saving checkpoint on exeption if the exception occurs in a validation callback --- .../checkpointing/test_model_checkpoint.py | 134 +++++++++++------- 1 file changed, 84 insertions(+), 50 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index f024553c032f4..537c026b50462 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -827,66 +827,100 @@ def on_train_epoch_end(self, trainer, pl_module): assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt") -# def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path): -# """Test that the checkpoint is saved when an exception is raised in a callback on different events.""" -# class TroublemakerOnTrainBatchStart(Callback): -# def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): -# if batch_idx == 1: -# raise RuntimeError("Trouble!") +def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_start.""" + class TroublemakerOnValidationBatchStart(Callback): + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") -# class TroublemakerOnTrainBatchEnd(Callback): -# def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): -# if batch_idx == 1: -# raise RuntimeError("Trouble!") - -# class TroublemakerOnTrainEpochStart(Callback): -# def on_train_epoch_start(self, trainer, pl_module): -# if trainer.current_epoch == 1: -# raise RuntimeError("Trouble!") - -# class TroublemakerOnTrainEpochEnd(Callback): -# def on_train_epoch_end(self, trainer, pl_module): -# if trainer.current_epoch == 1: -# raise RuntimeError("Trouble!") - - -# epoch_length = 64 -# model = BoringModel() -# # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints -# checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - -# 
troublemakers = [ -# TroublemakerOnTrainBatchStart(), -# TroublemakerOnTrainBatchEnd(), -# TroublemakerOnTrainEpochStart(), -# TroublemakerOnTrainEpochEnd() -# ] + model = BoringModel() + epoch_length = 64 + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") -# expected_ckpts = ["step=1.ckpt", -# 'step=2.ckpt', -# f'step={epoch_length}.ckpt', -# f'step={2*epoch_length}.ckpt', -# ] -# for troublemaker in troublemakers: -# trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False) +def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_end.""" + class TroublemakerOnValidationBatchEnd(Callback): + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + model = BoringModel() + epoch_length = 64 + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") -# with pytest.raises(RuntimeError, match="Trouble!"): -# trainer.fit(model) -# assert set(os.listdir(tmp_path)) == set(expected_ckpts) +def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start.""" + class TroublemakerOnValidationEpochStart(Callback): + def on_validation_epoch_start(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == 0: + raise RuntimeError("Trouble!") + model = BoringModel() + epoch_length = 64 + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") + -# # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints -# checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) +def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end.""" + class TroublemakerOnValidationEpochEnd(Callback): + def on_validation_epoch_end(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == 0: + raise RuntimeError("Trouble!") + + model = BoringModel() + epoch_length = 64 + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", 
save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") -# troublemakers = [ -# # TroublemakerOnValidationBatchStart(), -# TroublemakerOnValidationBatchEnd(), -# expected_ckpts = [f"step={2*epoch_length}.ckpt", +def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on validation_start.""" + class TroublemakerOnValidationStart(Callback): + def on_validation_start(self, trainer, pl_module): + if not trainer.sanity_checking: + raise RuntimeError("Trouble!") + + model = BoringModel() + epoch_length = 64 + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") - assert set(os.listdir(tmp_path)) == set(expected_ckpts) +def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a callback on validation_end.""" + class TroublemakerOnValidationEnd(Callback): + def on_validation_end(self, trainer, pl_module): + if not trainer.sanity_checking: + raise RuntimeError("Trouble!") + + model = BoringModel() + epoch_length = 64 + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") + @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time") def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None: """Tests that the checkpoints are saved at the specified time interval.""" From d4d933be0afd25c7ffb539bbbd535a1f278b8cf8 Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 07:36:23 +0200 Subject: [PATCH 07/37] add test for save model chekpoint on exception for exception in train and val step --- .../checkpointing/test_model_checkpoint.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 537c026b50462..cfc5d969799de 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -764,6 +764,37 @@ def test_ckpt_every_n_train_steps(tmp_path): assert set(os.listdir(tmp_path)) == set(expected) +def test_model_checkpoint_save_on_exception_in_training_step(tmp_path): + """Test that the checkpoint is saved when an exception is raised in training_step.""" + class TroubledModel(BoringModel): + def training_step(self, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + + model = TroubledModel() + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, 
every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + print(os.listdir(tmp_path)) + assert os.path.isfile(tmp_path / "step=1.ckpt") + +def test_model_checkpoint_save_on_exception_in_validation_step(tmp_path): + """Test that the checkpoint is saved when an exception is raised in validation_step.""" + class TroubledModel(BoringModel): + def validation_step(self, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 0: + raise RuntimeError("Trouble!") + + model = TroubledModel() + epoch_length = 64 + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") + + def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start.""" class TroublemakerOnTrainBatchStart(Callback): From 9f6063b83cc8f22323b7ca11245ac0597910d7ef Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 07:41:20 +0200 Subject: [PATCH 08/37] disable trainer prog bar for test of model checkpoint on exception --- .../checkpointing/test_model_checkpoint.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index cfc5d969799de..d7852cee30f60 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -773,7 +773,7 @@ def training_step(self, batch, batch_idx): model = TroubledModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) print(os.listdir(tmp_path)) @@ -789,7 +789,7 @@ def validation_step(self, batch, batch_idx): model = TroubledModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -804,7 +804,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], 
max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / "step=1.ckpt") @@ -819,7 +819,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) @@ -836,7 +836,7 @@ def on_train_epoch_start(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -852,7 +852,7 @@ def on_train_epoch_end(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt") @@ -868,7 +868,7 @@ def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -884,7 +884,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -900,7 +900,7 @@ 
def on_validation_epoch_start(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -916,7 +916,7 @@ def on_validation_epoch_end(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -931,7 +931,7 @@ def on_validation_start(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -946,7 +946,7 @@ def on_validation_end(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], max_epochs=5, logger=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") From 02477d5bd7da6ba5747e0283f12dcf7e5b4f1425 Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 07:49:06 +0200 Subject: [PATCH 09/37] model checkpoint on eception split trainer setup over two lines --- .../checkpointing/test_model_checkpoint.py | 56 +++++++++++-------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index d7852cee30f60..1010e50ec168c 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -773,7 +773,8 @@ def training_step(self, batch, batch_idx): model = TroubledModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = 
Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) print(os.listdir(tmp_path)) @@ -785,15 +786,16 @@ class TroubledModel(BoringModel): def validation_step(self, batch, batch_idx): if not trainer.sanity_checking and batch_idx == 0: raise RuntimeError("Trouble!") - + model = TroubledModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") - + def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start.""" @@ -804,7 +806,8 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / "step=1.ckpt") @@ -816,10 +819,11 @@ class TroublemakerOnTrainBatchEnd(Callback): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if batch_idx == 1: raise RuntimeError("Trouble!") - + model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) @@ -832,11 +836,12 @@ class TroublemakerOnTrainEpochStart(Callback): def on_train_epoch_start(self, trainer, pl_module): if trainer.current_epoch == 1: raise RuntimeError("Trouble!") - + model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") 
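
For readers following the patch series, the behaviour these tests exercise can be summarised in a short usage sketch. This is illustrative only and not part of any patch: it assumes the ``save_on_exception`` flag lands as introduced in PATCH 01/37, and the checkpoint directory and epoch count below are made-up values rather than anything taken from the PR.

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.demos.boring_classes import BoringModel

    # Hypothetical setup: any exception escaping fit() should leave a checkpoint behind.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",  # assumed output directory, not taken from the PR
        filename="{step}",
        save_on_exception=True,  # the new flag added by this patch series
    )
    trainer = Trainer(max_epochs=5, callbacks=[checkpoint_callback], logger=False)
    trainer.fit(BoringModel())  # if training raises, a step=<n>.ckpt file is written
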
@@ -852,7 +857,8 @@ def on_train_epoch_end(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt") @@ -868,7 +874,8 @@ def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -880,11 +887,12 @@ class TroublemakerOnValidationBatchEnd(Callback): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if not trainer.sanity_checking and batch_idx == 1: raise RuntimeError("Trouble!") - + model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -900,11 +908,12 @@ def on_validation_epoch_start(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") - + def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end.""" @@ -912,11 +921,12 @@ class TroublemakerOnValidationEpochEnd(Callback): def on_validation_epoch_end(self, trainer, pl_module): if not trainer.sanity_checking and trainer.current_epoch == 0: raise RuntimeError("Trouble!") - + model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, 
filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -927,11 +937,12 @@ class TroublemakerOnValidationStart(Callback): def on_validation_start(self, trainer, pl_module): if not trainer.sanity_checking: raise RuntimeError("Trouble!") - + model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -942,16 +953,17 @@ class TroublemakerOnValidationEnd(Callback): def on_validation_end(self, trainer, pl_module): if not trainer.sanity_checking: raise RuntimeError("Trouble!") - + model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], + max_epochs=5, logger=False, enable_progress_bar=False) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") - + @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time") def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None: """Tests that the checkpoints are saved at the specified time interval.""" From e5b0498e18ab7f82e74acf14fc7e323088487755 Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 07:54:54 +0200 Subject: [PATCH 10/37] remove trainling braket from shoukd_save_on_eception condition --- src/lightning/pytorch/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 307ba7c5d559c..96a22bfe33870 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -438,13 +438,13 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: or trainer.sanity_checking # don't save anything during sanity check or self._last_global_step_saved == trainer.global_step # already saved at the last step ) - + def _should_save_on_exception(self, trainer: "pl.Trainer") -> bool: return ( self.save_on_exception and not bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run and not trainer.sanity_checking # don't save anything during sanity check - and not self._last_global_step_saved == trainer.global_step # 
already saved at the last step) + and not self._last_global_step_saved == trainer.global_step # already saved at the last step ) def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool: From 985c1e1e051448862baef28688cb4190153aeac4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Jun 2025 06:33:51 +0000 Subject: [PATCH 11/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/callbacks/model_checkpoint.py | 9 +- .../checkpointing/test_model_checkpoint.py | 133 ++++++++++++++---- 2 files changed, 108 insertions(+), 34 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 96a22bfe33870..60cfd195dc05d 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -350,7 +350,6 @@ def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", e self._save_last_checkpoint(trainer, monitor_candidates) rank_zero_info(f"An exception was raised saved checkpoint to {filepath}") - @override def state_dict(self) -> dict[str, Any]: return { @@ -441,10 +440,10 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: def _should_save_on_exception(self, trainer: "pl.Trainer") -> bool: return ( - self.save_on_exception - and not bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run - and not trainer.sanity_checking # don't save anything during sanity check - and not self._last_global_step_saved == trainer.global_step # already saved at the last step + self.save_on_exception + and not bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run + and not trainer.sanity_checking # don't save anything during sanity check + and self._last_global_step_saved != trainer.global_step # already saved at the last step ) def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool: diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 1010e50ec168c..f44076b102a75 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -766,6 +766,7 @@ def test_ckpt_every_n_train_steps(tmp_path): def test_model_checkpoint_save_on_exception_in_training_step(tmp_path): """Test that the checkpoint is saved when an exception is raised in training_step.""" + class TroubledModel(BoringModel): def training_step(self, batch, batch_idx): if batch_idx == 1: @@ -773,15 +774,22 @@ def training_step(self, batch, batch_idx): model = TroubledModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) print(os.listdir(tmp_path)) assert os.path.isfile(tmp_path / "step=1.ckpt") + def test_model_checkpoint_save_on_exception_in_validation_step(tmp_path): """Test that the checkpoint is saved when an exception is raised in validation_step.""" + class TroubledModel(BoringModel): def validation_step(self, batch, 
batch_idx): if not trainer.sanity_checking and batch_idx == 0: @@ -790,8 +798,13 @@ def validation_step(self, batch, batch_idx): model = TroubledModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -799,6 +812,7 @@ def validation_step(self, batch, batch_idx): def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start.""" + class TroublemakerOnTrainBatchStart(Callback): def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if batch_idx == 1: @@ -806,8 +820,13 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / "step=1.ckpt") @@ -815,6 +834,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end.""" + class TroublemakerOnTrainBatchEnd(Callback): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if batch_idx == 1: @@ -822,8 +842,13 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) @@ -832,6 +857,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start.""" + class TroublemakerOnTrainEpochStart(Callback): def on_train_epoch_start(self, trainer, pl_module): if trainer.current_epoch == 1: @@ -840,8 +866,13 @@ def on_train_epoch_start(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - 
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -849,6 +880,7 @@ def on_train_epoch_start(self, trainer, pl_module): def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end.""" + class TroublemakerOnTrainEpochEnd(Callback): def on_train_epoch_end(self, trainer, pl_module): if trainer.current_epoch == 1: @@ -857,25 +889,36 @@ def on_train_epoch_end(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt") + assert os.path.isfile(tmp_path / f"step={2 * epoch_length}.ckpt") def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_start.""" + class TroublemakerOnValidationBatchStart(Callback): def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -883,16 +926,22 @@ def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_end.""" + class TroublemakerOnValidationBatchEnd(Callback): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, 
every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -900,6 +949,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start.""" + class TroublemakerOnValidationEpochStart(Callback): def on_validation_epoch_start(self, trainer, pl_module): if not trainer.sanity_checking and trainer.current_epoch == 0: @@ -908,8 +958,13 @@ def on_validation_epoch_start(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") @@ -917,6 +972,7 @@ def on_validation_epoch_start(self, trainer, pl_module): def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end.""" + class TroublemakerOnValidationEpochEnd(Callback): def on_validation_epoch_end(self, trainer, pl_module): if not trainer.sanity_checking and trainer.current_epoch == 0: @@ -925,14 +981,21 @@ def on_validation_epoch_end(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") + def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on validation_start.""" + class TroublemakerOnValidationStart(Callback): def on_validation_start(self, trainer, pl_module): if not trainer.sanity_checking: @@ -941,14 +1004,21 @@ def on_validation_start(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, 
TroublemakerOnValidationStart()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") + def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end(tmp_path): """Test that the checkpoint is saved when an exception is raised in a callback on validation_end.""" + class TroublemakerOnValidationEnd(Callback): def on_validation_end(self, trainer, pl_module): if not trainer.sanity_checking: @@ -957,8 +1027,13 @@ def on_validation_end(self, trainer, pl_module): model = BoringModel() epoch_length = 64 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], - max_epochs=5, logger=False, enable_progress_bar=False) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") From f0502ecb67e4e174a1a86b841a14b06431ee61b0 Mon Sep 17 00:00:00 2001 From: vsey Date: Thu, 19 Jun 2025 22:35:51 +0200 Subject: [PATCH 12/37] Update save checkpoint on exception tests to use a shorter more precisly defined epoch lenght --- .../checkpointing/test_model_checkpoint.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index f44076b102a75..a64e8dc397321 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -796,12 +796,13 @@ def validation_step(self, batch, batch_idx): raise RuntimeError("Trouble!") model = TroubledModel() - epoch_length = 64 + epoch_length = 2 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) trainer = Trainer( default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, + limit_train_batches=epoch_length, logger=False, enable_progress_bar=False, ) @@ -864,12 +865,13 @@ def on_train_epoch_start(self, trainer, pl_module): raise RuntimeError("Trouble!") model = BoringModel() - epoch_length = 64 + epoch_length = 2 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) trainer = Trainer( default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], max_epochs=5, + limit_train_batches=epoch_length, logger=False, enable_progress_bar=False, ) @@ -887,12 +889,13 @@ def on_train_epoch_end(self, trainer, pl_module): raise RuntimeError("Trouble!") model = BoringModel() - epoch_length = 64 + epoch_length = 2 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) trainer = Trainer( default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], max_epochs=5, + limit_train_batches=epoch_length, logger=False, enable_progress_bar=False, ) @@ -956,12 +959,13 @@ def 
on_validation_epoch_start(self, trainer, pl_module): raise RuntimeError("Trouble!") model = BoringModel() - epoch_length = 64 + epoch_length = 2 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) trainer = Trainer( default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], max_epochs=5, + limit_train_batches=epoch_length, logger=False, enable_progress_bar=False, ) @@ -979,12 +983,13 @@ def on_validation_epoch_end(self, trainer, pl_module): raise RuntimeError("Trouble!") model = BoringModel() - epoch_length = 64 + epoch_length = 2 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) trainer = Trainer( default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], max_epochs=5, + limit_train_batches=epoch_length, logger=False, enable_progress_bar=False, ) @@ -1002,12 +1007,13 @@ def on_validation_start(self, trainer, pl_module): raise RuntimeError("Trouble!") model = BoringModel() - epoch_length = 64 + epoch_length = 2 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) trainer = Trainer( default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], max_epochs=5, + limit_train_batches=epoch_length, logger=False, enable_progress_bar=False, ) @@ -1025,12 +1031,13 @@ def on_validation_end(self, trainer, pl_module): raise RuntimeError("Trouble!") model = BoringModel() - epoch_length = 64 + epoch_length = 2 checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) trainer = Trainer( default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], max_epochs=5, + limit_train_batches=epoch_length, logger=False, enable_progress_bar=False, ) From 99af7ed1e527caa636f9446446b294c0b5c880d9 Mon Sep 17 00:00:00 2001 From: vsey Date: Fri, 20 Jun 2025 04:30:22 +0200 Subject: [PATCH 13/37] switch default on save on checkpoint on exception to false to don't intefere with current checkpoint behavoir --- src/lightning/pytorch/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 60cfd195dc05d..c983c5276a80a 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -225,7 +225,7 @@ def __init__( verbose: bool = False, save_last: Optional[Union[bool, Literal["link"]]] = None, save_top_k: int = 1, - save_on_exception: bool = True, + save_on_exception: bool = False, save_weights_only: bool = False, mode: str = "min", auto_insert_metric_name: bool = True, From c092385df47b3648ff818b889c346961a12b4783 Mon Sep 17 00:00:00 2001 From: vsey Date: Fri, 20 Jun 2025 06:09:53 +0200 Subject: [PATCH 14/37] checkpoint on exception put callback tests into a pytest prametrization --- .../checkpointing/test_model_checkpoint.py | 304 ++++++------------ 1 file changed, 93 insertions(+), 211 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index a64e8dc397321..d23e77c2c4732 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -811,241 +811,123 @@ def 
validation_step(self, batch, batch_idx): assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") -def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start.""" +CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX = 2 +CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH = 21 +CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS = 25 +CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES = 4 +assert CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX < CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES +assert CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH < CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS - class TroublemakerOnTrainBatchStart(Callback): - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - if batch_idx == 1: - raise RuntimeError("Trouble!") - - model = BoringModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], - max_epochs=5, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / "step=1.ckpt") - - -def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end.""" - - class TroublemakerOnTrainBatchEnd(Callback): - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if batch_idx == 1: - raise RuntimeError("Trouble!") - - model = BoringModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], - max_epochs=5, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - - assert os.path.isfile(tmp_path / "step=2.ckpt") - - -def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start.""" - - class TroublemakerOnTrainEpochStart(Callback): - def on_train_epoch_start(self, trainer, pl_module): - if trainer.current_epoch == 1: - raise RuntimeError("Trouble!") - - model = BoringModel() - epoch_length = 2 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], - max_epochs=5, - limit_train_batches=epoch_length, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") - -def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end.""" - - class TroublemakerOnTrainEpochEnd(Callback): - def on_train_epoch_end(self, trainer, pl_module): - if trainer.current_epoch == 1: - raise RuntimeError("Trouble!") - - model = BoringModel() - epoch_length = 2 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - 
trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], - max_epochs=5, - limit_train_batches=epoch_length, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={2 * epoch_length}.ckpt") - - -def test_model_checkpoint_save_on_exception_in_val_callback(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_start.""" - - class TroublemakerOnValidationBatchStart(Callback): - def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") - - model = BoringModel() - epoch_length = 64 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], - max_epochs=5, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") +class TroublemakerOnTrainBatchStart(Callback): + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + if batch_idx == CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX: + raise RuntimeError("Trouble!") -def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_batch_end(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on validation_batch_end.""" +class TroublemakerOnTrainBatchEnd(Callback): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if batch_idx == CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX: + raise RuntimeError("Trouble!") - class TroublemakerOnValidationBatchEnd(Callback): - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") - model = BoringModel() - epoch_length = 64 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], - max_epochs=5, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") +class TroublemakerOnTrainEpochStart(Callback): + def on_train_epoch_start(self, trainer, pl_module): + if trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH: + raise RuntimeError("Trouble!") -def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_start(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_start.""" +class TroublemakerOnTrainEpochEnd(Callback): + def on_train_epoch_end(self, trainer, pl_module): + if trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH: + raise RuntimeError("Trouble!") - class TroublemakerOnValidationEpochStart(Callback): - def on_validation_epoch_start(self, trainer, pl_module): - if not trainer.sanity_checking and trainer.current_epoch == 0: - raise RuntimeError("Trouble!") - model = BoringModel() - epoch_length = 2 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", 
save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], - max_epochs=5, - limit_train_batches=epoch_length, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") +class TroublemakerOnTrainEnd(Callback): + def on_train_end(self, trainer, pl_module): + raise RuntimeError("Trouble!") -def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end.""" +class TroublemakerOnValidationBatchStart(Callback): + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") - class TroublemakerOnValidationEpochEnd(Callback): - def on_validation_epoch_end(self, trainer, pl_module): - if not trainer.sanity_checking and trainer.current_epoch == 0: - raise RuntimeError("Trouble!") - model = BoringModel() - epoch_length = 2 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], - max_epochs=5, - limit_train_batches=epoch_length, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") +class TroublemakerOnValidationBatchEnd(Callback): + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") -def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_start(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on validation_start.""" +class TroublemakerOnValidationEpochStart(Callback): + def on_validation_epoch_start(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH: + raise RuntimeError("Trouble!") - class TroublemakerOnValidationStart(Callback): - def on_validation_start(self, trainer, pl_module): - if not trainer.sanity_checking: - raise RuntimeError("Trouble!") - model = BoringModel() - epoch_length = 2 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], - max_epochs=5, - limit_train_batches=epoch_length, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") +class TroublemakerOnValidationEpochEnd(Callback): + def on_validation_epoch_end(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH: + raise RuntimeError("Trouble!") -def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_end(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a callback on validation_end.""" +class TroublemakerOnValidationStart(Callback): + def on_validation_start(self, trainer, 
pl_module): + if not trainer.sanity_checking: + raise RuntimeError("Trouble!") - class TroublemakerOnValidationEnd(Callback): - def on_validation_end(self, trainer, pl_module): - if not trainer.sanity_checking: - raise RuntimeError("Trouble!") - model = BoringModel() - epoch_length = 2 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], - max_epochs=5, - limit_train_batches=epoch_length, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") +class TroublemakerOnValidationEnd(Callback): + def on_validation_end(self, trainer, pl_module): + if not trainer.sanity_checking: + raise RuntimeError("Trouble!") +@pytest.mark.parametrize( + ("TroubledCallback", "expected_checkpoint_global_step"), + [ + pytest.param( + TroublemakerOnTrainBatchStart, CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX, id="on_train_batch_start" + ), + pytest.param( + TroublemakerOnTrainBatchEnd, CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX + 1, id="on_train_batch_end" + ), + pytest.param( + TroublemakerOnTrainEpochStart, + CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, + id="on_train_epoch_start", + ), + pytest.param( + TroublemakerOnTrainEpochEnd, + (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, + id="on_train_epoch_end", + ), + pytest.param( + TroublemakerOnTrainEnd, + CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, + id="on_train_end", + ), + pytest.param( + TroublemakerOnValidationBatchStart, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_batch_start" + ), + pytest.param( + TroublemakerOnValidationBatchEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_batch_end" + ), + pytest.param( + TroublemakerOnValidationEpochStart, + (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, + id="on_validation_epoch_start", + ), + pytest.param( + TroublemakerOnValidationEpochEnd, + (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, + id="on_validation_epoch_end", + ), + pytest.param(TroublemakerOnValidationStart, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_start"), + pytest.param(TroublemakerOnValidationEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_end"), + ], +) @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time") def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None: """Tests that the checkpoints are saved at the specified time interval.""" From 904bd743355571ae555e3e13c1ea0f6bc2fbb0b5 Mon Sep 17 00:00:00 2001 From: vsey Date: Fri, 20 Jun 2025 06:12:05 +0200 Subject: [PATCH 15/37] change doc string to reflect new default value for save on exception model checkpoint --- src/lightning/pytorch/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index c983c5276a80a..28111707c43f2 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -97,7 +97,7 @@ class ModelCheckpoint(Checkpoint): collisions unless ``enable_version_counter`` is set to False. 
The version counter is unrelated to the top-k ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid collisions. - save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``True``. + save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``. mode: one of {min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. From 3a3204e3c1713d517c03ac161dcd1d03fe27e33e Mon Sep 17 00:00:00 2001 From: vsey Date: Fri, 20 Jun 2025 07:21:39 +0200 Subject: [PATCH 16/37] checkpoint on exception add test function for exception in callback --- .../checkpointing/test_model_checkpoint.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index d23e77c2c4732..09b2daacb9975 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -928,6 +928,29 @@ def on_validation_end(self, trainer, pl_module): pytest.param(TroublemakerOnValidationEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_end"), ], ) +def test_model_checkpoint_save_on_exception_in_other_callbacks( + tmp_path, TroubledCallback, expected_checkpoint_global_step +): + """Test that an checkpoint is saved when an exception is raised in an other callback.""" + + model = BoringModel() + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, TroubledCallback()], + max_epochs=CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS, + limit_train_batches=CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, + logger=False, + enable_progress_bar=False, + ) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + + assert os.path.isfile(tmp_path / f"step={expected_checkpoint_global_step}.ckpt") + checkpoint = torch.load(tmp_path / f"step={expected_checkpoint_global_step}.ckpt", weights_only=True) + assert checkpoint["global_step"] == expected_checkpoint_global_step + + @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time") def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None: """Tests that the checkpoints are saved at the specified time interval.""" From 467c57bb1401ba661461bf86e395d2f361822393 Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 01:01:31 +0200 Subject: [PATCH 17/37] add prefix option to generate checkpoint file name --- src/lightning/pytorch/callbacks/model_checkpoint.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 28111707c43f2..ee4d880165fb7 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -591,7 +591,11 @@ def _format_checkpoint_name( return filename def format_checkpoint_name( - self, metrics: dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None + self, + metrics: dict[str, Tensor], + filename: Optional[str] = None, + prefix: Optional[str] = None, + ver: Optional[int] = None, ) -> str: """Generate a filename according to the defined template. 
@@ -623,7 +627,9 @@ def format_checkpoint_name( """ filename = filename or self.filename - filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name) + filename = self._format_checkpoint_name( + filename, metrics, prefix, auto_insert_metric_name=self.auto_insert_metric_name + ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) From 8ba638192c1458110c039d01265d518d7988b38d Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 01:09:29 +0200 Subject: [PATCH 18/37] add exception prefix to checkpoints saved on exception --- src/lightning/pytorch/callbacks/model_checkpoint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index ee4d880165fb7..26bc39d25e860 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -214,6 +214,7 @@ class ModelCheckpoint(Checkpoint): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_EQUALS_CHAR = "=" CHECKPOINT_NAME_LAST = "last" + CHECKPOINT_EXCEPTION_PREFIX = "exception" FILE_EXTENSION = ".ckpt" STARTING_VERSION = 1 @@ -345,7 +346,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: Exception) -> None: if self._should_save_on_exception(trainer): monitor_candidates = self._monitor_candidates(trainer) - filepath = self.format_checkpoint_name(metrics=monitor_candidates) + filepath = self.format_checkpoint_name(metrics=monitor_candidates, prefix=self.CHECKPOINT_EXCEPTION_PREFIX) self._save_checkpoint(trainer, filepath) self._save_last_checkpoint(trainer, monitor_candidates) rank_zero_info(f"An exception was raised saved checkpoint to {filepath}") From 3076ea1aac8cd8080c30b5728b316715e826f20f Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 03:44:02 +0200 Subject: [PATCH 19/37] add test to test prefix for checkpoint name --- tests/tests_pytorch/checkpointing/test_model_checkpoint.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 09b2daacb9975..c2384f97f2c50 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -453,6 +453,12 @@ def test_model_checkpoint_format_checkpoint_name(tmp_path, monkeypatch): ckpt_name = ckpt.format_checkpoint_name({}, ver=3) assert ckpt_name == str(tmp_path / "name-v3.ckpt") + # with prefix + ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=tmp_path, filename="name").format_checkpoint_name( + {}, prefix="test" + ) + assert ckpt_name == str(tmp_path / "test-name.ckpt") + # using slashes ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=None, filename="{epoch}_{val/loss:.5f}") ckpt_name = ckpt.format_checkpoint_name({"epoch": 4, "val/loss": 0.03}) From d78ea3e25efe108e872c79662423415f07d20fdf Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 04:32:10 +0200 Subject: [PATCH 20/37] add test for exceptions at diffrent position in a model --- .../checkpointing/test_model_checkpoint.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 
c2384f97f2c50..e8daee4606aff 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -770,6 +770,49 @@ def test_ckpt_every_n_train_steps(tmp_path): assert set(os.listdir(tmp_path)) == set(expected) +################################################################################################# + + +def test_model_checkpoint_on_exception(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a lightning module.""" + + class TroubledModelInTrainingStep(BoringModel): + def training_step(self, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + + class TroubledModelInValidationStep(BoringModel): + def validation_step(self, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + models = [TroubledModelInTrainingStep(), TroubledModelInValidationStep()] + + for model in models: + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, filename=model.__class__.__name__, save_on_exception=True, every_n_epochs=4 + ) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback], + limit_train_batches=2, + max_epochs=5, + logger=False, + enable_progress_bar=False, + ) + + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + + checkpoint_path = tmp_path / f"exception-{model.__class__.__name__}.ckpt" + + assert os.path.isfile(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + assert checkpoint["state_dict"] is not None + assert checkpoint["state_dict"] != {} + + +################################################################################################# def test_model_checkpoint_save_on_exception_in_training_step(tmp_path): """Test that the checkpoint is saved when an exception is raised in training_step.""" @@ -817,6 +860,8 @@ def validation_step(self, batch, batch_idx): assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") +################################################################################################# + CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX = 2 CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH = 21 CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS = 25 @@ -957,6 +1002,9 @@ def test_model_checkpoint_save_on_exception_in_other_callbacks( assert checkpoint["global_step"] == expected_checkpoint_global_step +################################################################################################# + + @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time") def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None: """Tests that the checkpoints are saved at the specified time interval.""" From 42bbac123389a8520c0bbd498c109e18d827ffba Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 04:55:47 +0200 Subject: [PATCH 21/37] add description to on exception hook in model checkpoint --- src/lightning/pytorch/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 26bc39d25e860..3024875cd1471 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -344,6 +344,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul @override def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: Exception) -> None: + """Save a checkpoint when an 
exception is raised.""" if self._should_save_on_exception(trainer): monitor_candidates = self._monitor_candidates(trainer) filepath = self.format_checkpoint_name(metrics=monitor_candidates, prefix=self.CHECKPOINT_EXCEPTION_PREFIX) From c4b806311b7633b9b289a47158162430fb111c77 Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 05:00:58 +0200 Subject: [PATCH 22/37] add test to check saving on exception in all relevalnt callback positions --- .../checkpointing/test_model_checkpoint.py | 247 ++++++------------ 1 file changed, 76 insertions(+), 171 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index e8daee4606aff..321a1ecac98f1 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -813,196 +813,101 @@ def validation_step(self, batch, batch_idx): ################################################################################################# -def test_model_checkpoint_save_on_exception_in_training_step(tmp_path): - """Test that the checkpoint is saved when an exception is raised in training_step.""" +def test_model_checkpoint_on_exception_in_other_callbacks(tmp_path): + """Test that an checkpoint is saved when an exception is raised in an other callback.""" - class TroubledModel(BoringModel): - def training_step(self, batch, batch_idx): + class TroubleMakerOnTrainBatchStart(Callback): + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if batch_idx == 1: raise RuntimeError("Trouble!") - model = TroubledModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback], - max_epochs=5, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - print(os.listdir(tmp_path)) - assert os.path.isfile(tmp_path / "step=1.ckpt") - - -def test_model_checkpoint_save_on_exception_in_validation_step(tmp_path): - """Test that the checkpoint is saved when an exception is raised in validation_step.""" - - class TroubledModel(BoringModel): - def validation_step(self, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 0: + class TroubleMakerOnTrainBatchEnd(Callback): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if batch_idx == 1: raise RuntimeError("Trouble!") - model = TroubledModel() - epoch_length = 2 - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback], - max_epochs=5, - limit_train_batches=epoch_length, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt") - - -################################################################################################# - -CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX = 2 -CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH = 21 -CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS = 25 -CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES = 4 -assert CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX < CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES -assert CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH < CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS - - -class TroublemakerOnTrainBatchStart(Callback): - 
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - if batch_idx == CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX: - raise RuntimeError("Trouble!") - - -class TroublemakerOnTrainBatchEnd(Callback): - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if batch_idx == CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX: - raise RuntimeError("Trouble!") - - -class TroublemakerOnTrainEpochStart(Callback): - def on_train_epoch_start(self, trainer, pl_module): - if trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH: - raise RuntimeError("Trouble!") - - -class TroublemakerOnTrainEpochEnd(Callback): - def on_train_epoch_end(self, trainer, pl_module): - if trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH: - raise RuntimeError("Trouble!") - - -class TroublemakerOnTrainEnd(Callback): - def on_train_end(self, trainer, pl_module): - raise RuntimeError("Trouble!") - - -class TroublemakerOnValidationBatchStart(Callback): - def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") + class TroubleMakerOnTrainEpochStart(Callback): + def on_train_epoch_start(self, trainer, pl_module): + if trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + class TroubleMakerOnTrainEpochEnd(Callback): + def on_train_epoch_end(self, trainer, pl_module): + if trainer.current_epoch == 1: + raise RuntimeError("Trouble!") -class TroublemakerOnValidationBatchEnd(Callback): - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: + class TroubleMakerOnTrainEnd(Callback): + def on_train_end(self, trainer, pl_module): raise RuntimeError("Trouble!") + class TroubleMakerOnValidationBatchStart(Callback): + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") -class TroublemakerOnValidationEpochStart(Callback): - def on_validation_epoch_start(self, trainer, pl_module): - if not trainer.sanity_checking and trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH: - raise RuntimeError("Trouble!") - + class TroubleMakerOnValidationBatchEnd(Callback): + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") -class TroublemakerOnValidationEpochEnd(Callback): - def on_validation_epoch_end(self, trainer, pl_module): - if not trainer.sanity_checking and trainer.current_epoch == CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH: - raise RuntimeError("Trouble!") + class TroubleMakerOnValidationEpochStart(Callback): + def on_validation_epoch_start(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + class TroubleMakerOnValidationEpochEnd(Callback): + def on_validation_epoch_end(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == 1: + raise RuntimeError("Trouble!") -class TroublemakerOnValidationStart(Callback): - def on_validation_start(self, trainer, pl_module): - if not trainer.sanity_checking: - raise RuntimeError("Trouble!") + class TroubleMakerOnValidationStart(Callback): + def on_validation_start(self, trainer, pl_module): + if not trainer.sanity_checking: + raise RuntimeError("Trouble!") + class TroubleMakerOnValidationEnd(Callback): + def on_validation_end(self, 
trainer, pl_module): + if not trainer.sanity_checking: + raise RuntimeError("Trouble!") -class TroublemakerOnValidationEnd(Callback): - def on_validation_end(self, trainer, pl_module): - if not trainer.sanity_checking: + class TroubleMakerOnFitEnd(Callback): + def on_fit_end(self, trainer, pl_module): raise RuntimeError("Trouble!") + troubled_callbacks = [ + TroubleMakerOnTrainBatchStart(), + TroubleMakerOnTrainBatchEnd(), + TroubleMakerOnTrainEpochStart(), + TroubleMakerOnTrainEpochEnd(), + TroubleMakerOnTrainEnd(), + TroubleMakerOnValidationBatchStart(), + TroubleMakerOnValidationBatchEnd(), + TroubleMakerOnValidationEpochStart(), + TroubleMakerOnValidationEpochEnd(), + TroubleMakerOnValidationStart(), + TroubleMakerOnValidationEnd(), + TroubleMakerOnFitEnd(), + ] -@pytest.mark.parametrize( - ("TroubledCallback", "expected_checkpoint_global_step"), - [ - pytest.param( - TroublemakerOnTrainBatchStart, CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX, id="on_train_batch_start" - ), - pytest.param( - TroublemakerOnTrainBatchEnd, CHECKPOINT_ON_EXCEPTION_RAISE_AT_BATCH_IDX + 1, id="on_train_batch_end" - ), - pytest.param( - TroublemakerOnTrainEpochStart, - CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, - id="on_train_epoch_start", - ), - pytest.param( - TroublemakerOnTrainEpochEnd, - (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, - id="on_train_epoch_end", - ), - pytest.param( - TroublemakerOnTrainEnd, - CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, - id="on_train_end", - ), - pytest.param( - TroublemakerOnValidationBatchStart, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_batch_start" - ), - pytest.param( - TroublemakerOnValidationBatchEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_batch_end" - ), - pytest.param( - TroublemakerOnValidationEpochStart, - (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, - id="on_validation_epoch_start", - ), - pytest.param( - TroublemakerOnValidationEpochEnd, - (CHECKPOINT_ON_EXCEPTION_RAISE_AT_EPOCH + 1) * CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, - id="on_validation_epoch_end", - ), - pytest.param(TroublemakerOnValidationStart, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_start"), - pytest.param(TroublemakerOnValidationEnd, CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, id="on_validation_end"), - ], -) -def test_model_checkpoint_save_on_exception_in_other_callbacks( - tmp_path, TroubledCallback, expected_checkpoint_global_step -): - """Test that an checkpoint is saved when an exception is raised in an other callback.""" - - model = BoringModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, TroubledCallback()], - max_epochs=CHECKPOINT_ON_EXCEPTION_MAX_EPOCHS, - limit_train_batches=CHECKPOINT_ON_EXCEPTION_TRAIN_BATCHES, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - - assert os.path.isfile(tmp_path / f"step={expected_checkpoint_global_step}.ckpt") - checkpoint = torch.load(tmp_path / f"step={expected_checkpoint_global_step}.ckpt", weights_only=True) - assert checkpoint["global_step"] == expected_checkpoint_global_step - - -################################################################################################# + for troubled_callback in troubled_callbacks: + 
model = BoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, filename=troubled_callback.__class__.__name__, save_on_exception=True, every_n_epochs=5 + ) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, troubled_callback], + max_epochs=4, + limit_train_batches=2, + logger=False, + enable_progress_bar=False, + ) + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert os.path.isfile(tmp_path / f"exception-{troubled_callback.__class__.__name__}.ckpt") + checkpoint = torch.load(tmp_path / f"exception-{troubled_callback.__class__.__name__}.ckpt", map_location="cpu") + assert checkpoint["state_dict"] is not None + assert checkpoint["state_dict"] != {} @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time") From 2ca6dab58f9bfafd0a8872a8c8a7574232ce4900 Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 06:29:49 +0200 Subject: [PATCH 23/37] also print exception when saving checkpoint --- src/lightning/pytorch/callbacks/model_checkpoint.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 3024875cd1471..207d938e09a83 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -350,7 +350,10 @@ def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", e filepath = self.format_checkpoint_name(metrics=monitor_candidates, prefix=self.CHECKPOINT_EXCEPTION_PREFIX) self._save_checkpoint(trainer, filepath) self._save_last_checkpoint(trainer, monitor_candidates) - rank_zero_info(f"An exception was raised saved checkpoint to {filepath}") + rank_zero_info( + f"An {type(exception).__name__} was raised with message: \ + {str(exception)}, saved checkpoint to {filepath}" + ) @override def state_dict(self) -> dict[str, Any]: From 9e9e580e152460a9f8c6675805738b377566464e Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 06:30:34 +0200 Subject: [PATCH 24/37] test checkpointing on exception in varoius model steps --- .../checkpointing/test_model_checkpoint.py | 105 +++++++++++++++++- 1 file changed, 101 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 321a1ecac98f1..f94487ea88fdf 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -776,27 +776,124 @@ def test_ckpt_every_n_train_steps(tmp_path): def test_model_checkpoint_on_exception(tmp_path): """Test that the checkpoint is saved when an exception is raised in a lightning module.""" + class TroubledModelOnTrainEpochStart(BoringModel): + def on_train_epoch_start(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnTrainBatchStart(BoringModel): + def on_train_batch_start(self, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + class TroubledModelInTrainingStep(BoringModel): def training_step(self, batch, batch_idx): if batch_idx == 1: raise RuntimeError("Trouble!") + class TroubledModelOnBeforeZeroGrad(BoringModel): + def on_before_zero_grad(self, optimizer): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnBeforeBackward(BoringModel): + def on_before_backward(self, loss): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + 
class TroubledModelOnAfterBackward(BoringModel): + def on_after_backward(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnBeforeOptimizerStep(BoringModel): + def on_before_optimizer_step(self, optimizer): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnTrainBatchEnd(BoringModel): + def on_train_batch_end(self, outputs, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnTrainEpochEnd(BoringModel): + def on_train_epoch_end(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnTrainEnd(BoringModel): + def on_train_end(self): + raise RuntimeError("Trouble!") + + class TroubledModelOnValidationStart(BoringModel): + def on_validation_start(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnValidationEpochStart(BoringModel): + def on_validation_epoch_start(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnValidationBatchStart(BoringModel): + def on_validation_batch_start(self, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + class TroubledModelInValidationStep(BoringModel): def validation_step(self, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnValidationBatchEnd(BoringModel): + def on_validation_batch_end(self, outputs, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnValidationEpochEnd(BoringModel): + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnValidationEnd(BoringModel): + def on_validation_end(self): + if not self.trainer.sanity_checking: raise RuntimeError("Trouble!") - models = [TroubledModelInTrainingStep(), TroubledModelInValidationStep()] + class TroubledModelOnFitEnd(BoringModel): + def on_fit_end(self): + raise RuntimeError("Trouble!") + + models = [ + TroubledModelOnTrainEpochStart(), + TroubledModelOnTrainBatchStart(), + TroubledModelInTrainingStep(), + TroubledModelOnBeforeZeroGrad(), + TroubledModelOnBeforeBackward(), + TroubledModelOnAfterBackward(), + TroubledModelOnBeforeOptimizerStep(), + TroubledModelOnTrainBatchEnd(), + TroubledModelOnTrainEpochEnd(), + TroubledModelOnTrainEnd(), + TroubledModelOnValidationStart(), + TroubledModelOnValidationEpochStart(), + TroubledModelOnValidationBatchStart(), + TroubledModelInValidationStep(), + TroubledModelOnValidationBatchEnd(), + TroubledModelOnValidationEpochEnd(), + TroubledModelOnValidationEnd(), + TroubledModelOnFitEnd(), + ] for model in models: checkpoint_callback = ModelCheckpoint( - dirpath=tmp_path, filename=model.__class__.__name__, save_on_exception=True, every_n_epochs=4 + dirpath=tmp_path, filename=model.__class__.__name__, save_on_exception=True, every_n_epochs=5 ) trainer = Trainer( default_root_dir=tmp_path, callbacks=[checkpoint_callback], limit_train_batches=2, - max_epochs=5, + max_epochs=4, logger=False, enable_progress_bar=False, ) From d2f74e984eba6ee30d2d2ce680f160771f76238f Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 06:31:51 +0200 Subject: [PATCH 25/37] remove 
deviders in test_model_checkpoint --- tests/tests_pytorch/checkpointing/test_model_checkpoint.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index f94487ea88fdf..810d854b644ec 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -770,9 +770,6 @@ def test_ckpt_every_n_train_steps(tmp_path): assert set(os.listdir(tmp_path)) == set(expected) -################################################################################################# - - def test_model_checkpoint_on_exception(tmp_path): """Test that the checkpoint is saved when an exception is raised in a lightning module.""" @@ -909,7 +906,6 @@ def on_fit_end(self): assert checkpoint["state_dict"] != {} -################################################################################################# def test_model_checkpoint_on_exception_in_other_callbacks(tmp_path): """Test that an checkpoint is saved when an exception is raised in an other callback.""" From ac33670c3af13988f21c17bb70c020bfd1893682 Mon Sep 17 00:00:00 2001 From: vsey Date: Sat, 21 Jun 2025 23:28:16 +0200 Subject: [PATCH 26/37] add test for run conditions for save checkpoint on exception --- .../checkpointing/test_model_checkpoint.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 810d854b644ec..248bcc739b2c9 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -770,6 +770,73 @@ def test_ckpt_every_n_train_steps(tmp_path): assert set(os.listdir(tmp_path)) == set(expected) +def test_model_checkpoint_on_exception_run_condition(tmp_path): + """Test that the checkpoint is saved when an exception is raised in a lightning module.""" + + # Don't save checkpoint if sanity check fails + class TroubledModelSanityCheck(BoringModel): + def on_validation_start(self) -> None: + if self.trainer.sanity_checking: + print("Trouble!") + raise RuntimeError("Trouble!") + + model = TroubledModelSanityCheck() + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="sanity_check", save_on_exception=True) + trainer = Trainer( + default_root_dir=tmp_path, + num_sanity_val_steps=4, + limit_train_batches=2, + callbacks=[checkpoint_callback], + max_epochs=2, + logger=False, + ) + + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert not os.path.isfile(tmp_path / "exception-sanity_check.ckpt") + + # Don't save checkpoint if fast dev run fails + class TroubledModelFastDevRun(BoringModel): + def on_train_batch_start(self, batch, batch_idx) -> None: + if self.trainer.fast_dev_run and batch_idx == 1: + raise RuntimeError("Trouble!") + + model = TroubledModelFastDevRun() + checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="fast_dev_run", save_on_exception=True) + trainer = Trainer( + default_root_dir=tmp_path, + fast_dev_run=2, + limit_train_batches=2, + callbacks=[checkpoint_callback], + max_epochs=2, + logger=False, + ) + + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + assert not os.path.isfile(tmp_path / "exception-fast_dev_run.ckpt") + + # Don't save checkpoint if already saved a checkpoint + class TroubledModelAlreadySavedCheckpoint(BoringModel): + def 
on_train_batch_start(self, batch, batch_idx) -> None: + if self.trainer.global_step == 1: + raise RuntimeError("Trouble!") + + model = TroubledModelAlreadySavedCheckpoint() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, filename="already_saved", save_on_exception=True, every_n_train_steps=1 + ) + trainer = Trainer( + default_root_dir=tmp_path, limit_train_batches=2, callbacks=[checkpoint_callback], max_epochs=2, logger=False + ) + + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + + assert not os.path.isfile(tmp_path / "exception-already_saved.ckpt") + assert os.path.isfile(tmp_path / "already_saved.ckpt") + + def test_model_checkpoint_on_exception(tmp_path): """Test that the checkpoint is saved when an exception is raised in a lightning module.""" From 0dc9f244f153f0855529aeeceeb6a4d2c68ed128 Mon Sep 17 00:00:00 2001 From: vsey <54716634+vsey@users.noreply.github.com> Date: Mon, 23 Jun 2025 20:30:48 +0200 Subject: [PATCH 27/37] Update model_checkpoint on_exception run condition to follow common convention Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- .../pytorch/callbacks/model_checkpoint.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 207d938e09a83..668e91abd0dfc 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -345,15 +345,16 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul @override def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: Exception) -> None: """Save a checkpoint when an exception is raised.""" - if self._should_save_on_exception(trainer): - monitor_candidates = self._monitor_candidates(trainer) - filepath = self.format_checkpoint_name(metrics=monitor_candidates, prefix=self.CHECKPOINT_EXCEPTION_PREFIX) - self._save_checkpoint(trainer, filepath) - self._save_last_checkpoint(trainer, monitor_candidates) - rank_zero_info( - f"An {type(exception).__name__} was raised with message: \ - {str(exception)}, saved checkpoint to {filepath}" - ) + if not self._should_save_on_exception(trainer): + return + monitor_candidates = self._monitor_candidates(trainer) + filepath = self.format_checkpoint_name(metrics=monitor_candidates, prefix=self.CHECKPOINT_EXCEPTION_PREFIX) + self._save_checkpoint(trainer, filepath) + self._save_last_checkpoint(trainer, monitor_candidates) + rank_zero_info( + f"An {type(exception).__name__} was raised with message: \ + {str(exception)}, saved checkpoint to {filepath}" + ) @override def state_dict(self) -> dict[str, Any]: From c62f0e130149ced3063b435dd4bd47f7a2cdbefe Mon Sep 17 00:00:00 2001 From: vsey Date: Mon, 23 Jun 2025 20:55:27 +0200 Subject: [PATCH 28/37] add checkpoint path variable in test_on_exceptio_incallback and set weights_only to False when loading checkpoint to because default will change in future --- .../tests_pytorch/checkpointing/test_model_checkpoint.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 248bcc739b2c9..0f3316582d0e2 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -968,7 +968,7 @@ def on_fit_end(self): checkpoint_path = 
tmp_path / f"exception-{model.__class__.__name__}.ckpt" assert os.path.isfile(checkpoint_path) - checkpoint = torch.load(checkpoint_path, map_location="cpu") + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) assert checkpoint["state_dict"] is not None assert checkpoint["state_dict"] != {} @@ -1064,8 +1064,11 @@ def on_fit_end(self, trainer, pl_module): ) with pytest.raises(RuntimeError, match="Trouble!"): trainer.fit(model) - assert os.path.isfile(tmp_path / f"exception-{troubled_callback.__class__.__name__}.ckpt") - checkpoint = torch.load(tmp_path / f"exception-{troubled_callback.__class__.__name__}.ckpt", map_location="cpu") + + checkpoint_path = tmp_path / f"exception-{troubled_callback.__class__.__name__}.ckpt" + + assert os.path.isfile(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) assert checkpoint["state_dict"] is not None assert checkpoint["state_dict"] != {} From f16b771b24be5ba4e2fa21612405cd470ccd6dd1 Mon Sep 17 00:00:00 2001 From: vsey Date: Mon, 23 Jun 2025 21:09:57 +0200 Subject: [PATCH 29/37] remove exception prefix as this is already covert by last flag --- src/lightning/pytorch/callbacks/model_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 668e91abd0dfc..98ec77aa3bad3 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -214,7 +214,6 @@ class ModelCheckpoint(Checkpoint): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_EQUALS_CHAR = "=" CHECKPOINT_NAME_LAST = "last" - CHECKPOINT_EXCEPTION_PREFIX = "exception" FILE_EXTENSION = ".ckpt" STARTING_VERSION = 1 @@ -348,7 +347,7 @@ def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", e if not self._should_save_on_exception(trainer): return monitor_candidates = self._monitor_candidates(trainer) - filepath = self.format_checkpoint_name(metrics=monitor_candidates, prefix=self.CHECKPOINT_EXCEPTION_PREFIX) + filepath = self.format_checkpoint_name(metrics=monitor_candidates) self._save_checkpoint(trainer, filepath) self._save_last_checkpoint(trainer, monitor_candidates) rank_zero_info( From 09ba24fba0e727ed1c7bddac5873d2ced4f82f72 Mon Sep 17 00:00:00 2001 From: vsey Date: Mon, 23 Jun 2025 21:31:07 +0200 Subject: [PATCH 30/37] use pytest parametrization for testing callback on exception --- .../checkpointing/test_model_checkpoint.py | 415 ++++++++++-------- 1 file changed, 226 insertions(+), 189 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 0f3316582d0e2..f154b8a7e6b29 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -837,240 +837,277 @@ def on_train_batch_start(self, batch, batch_idx) -> None: assert os.path.isfile(tmp_path / "already_saved.ckpt") -def test_model_checkpoint_on_exception(tmp_path): - """Test that the checkpoint is saved when an exception is raised in a lightning module.""" +class TroubledModelOnTrainEpochStart(BoringModel): + def on_train_epoch_start(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") - class TroubledModelOnTrainEpochStart(BoringModel): - def on_train_epoch_start(self): - if self.current_epoch == 1: - raise RuntimeError("Trouble!") - class 
TroubledModelOnTrainBatchStart(BoringModel): - def on_train_batch_start(self, batch, batch_idx): - if batch_idx == 1: - raise RuntimeError("Trouble!") +class TroubledModelOnTrainBatchStart(BoringModel): + def on_train_batch_start(self, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") - class TroubledModelInTrainingStep(BoringModel): - def training_step(self, batch, batch_idx): - if batch_idx == 1: - raise RuntimeError("Trouble!") - class TroubledModelOnBeforeZeroGrad(BoringModel): - def on_before_zero_grad(self, optimizer): - if self.current_epoch == 1: - raise RuntimeError("Trouble!") +class TroubledModelInTrainingStep(BoringModel): + def training_step(self, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") - class TroubledModelOnBeforeBackward(BoringModel): - def on_before_backward(self, loss): - if self.current_epoch == 1: - raise RuntimeError("Trouble!") - class TroubledModelOnAfterBackward(BoringModel): - def on_after_backward(self): - if self.current_epoch == 1: - raise RuntimeError("Trouble!") +class TroubledModelOnBeforeZeroGrad(BoringModel): + def on_before_zero_grad(self, optimizer): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") - class TroubledModelOnBeforeOptimizerStep(BoringModel): - def on_before_optimizer_step(self, optimizer): - if self.current_epoch == 1: - raise RuntimeError("Trouble!") - class TroubledModelOnTrainBatchEnd(BoringModel): - def on_train_batch_end(self, outputs, batch, batch_idx): - if batch_idx == 1: - raise RuntimeError("Trouble!") +class TroubledModelOnBeforeBackward(BoringModel): + def on_before_backward(self, loss): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") - class TroubledModelOnTrainEpochEnd(BoringModel): - def on_train_epoch_end(self): - if self.current_epoch == 1: - raise RuntimeError("Trouble!") - class TroubledModelOnTrainEnd(BoringModel): - def on_train_end(self): +class TroubledModelOnAfterBackward(BoringModel): + def on_after_backward(self): + if self.current_epoch == 1: raise RuntimeError("Trouble!") - class TroubledModelOnValidationStart(BoringModel): - def on_validation_start(self): - if not self.trainer.sanity_checking and self.current_epoch == 1: - raise RuntimeError("Trouble!") - class TroubledModelOnValidationEpochStart(BoringModel): - def on_validation_epoch_start(self): - if not self.trainer.sanity_checking and self.current_epoch == 1: - raise RuntimeError("Trouble!") +class TroubledModelOnBeforeOptimizerStep(BoringModel): + def on_before_optimizer_step(self, optimizer): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") - class TroubledModelOnValidationBatchStart(BoringModel): - def on_validation_batch_start(self, batch, batch_idx): - if not self.trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") - class TroubledModelInValidationStep(BoringModel): - def validation_step(self, batch, batch_idx): - if not self.trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") +class TroubledModelOnTrainBatchEnd(BoringModel): + def on_train_batch_end(self, outputs, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") - class TroubledModelOnValidationBatchEnd(BoringModel): - def on_validation_batch_end(self, outputs, batch, batch_idx): - if not self.trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") - class TroubledModelOnValidationEpochEnd(BoringModel): - def on_validation_epoch_end(self): - if not self.trainer.sanity_checking and self.current_epoch == 1: - raise 
RuntimeError("Trouble!") +class TroubledModelOnTrainEpochEnd(BoringModel): + def on_train_epoch_end(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") - class TroubledModelOnValidationEnd(BoringModel): - def on_validation_end(self): - if not self.trainer.sanity_checking: - raise RuntimeError("Trouble!") - class TroubledModelOnFitEnd(BoringModel): - def on_fit_end(self): +class TroubledModelOnTrainEnd(BoringModel): + def on_train_end(self): + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationStart(BoringModel): + def on_validation_start(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: raise RuntimeError("Trouble!") - models = [ - TroubledModelOnTrainEpochStart(), - TroubledModelOnTrainBatchStart(), - TroubledModelInTrainingStep(), - TroubledModelOnBeforeZeroGrad(), - TroubledModelOnBeforeBackward(), - TroubledModelOnAfterBackward(), - TroubledModelOnBeforeOptimizerStep(), - TroubledModelOnTrainBatchEnd(), - TroubledModelOnTrainEpochEnd(), - TroubledModelOnTrainEnd(), - TroubledModelOnValidationStart(), - TroubledModelOnValidationEpochStart(), - TroubledModelOnValidationBatchStart(), - TroubledModelInValidationStep(), - TroubledModelOnValidationBatchEnd(), - TroubledModelOnValidationEpochEnd(), - TroubledModelOnValidationEnd(), - TroubledModelOnFitEnd(), - ] - for model in models: - checkpoint_callback = ModelCheckpoint( - dirpath=tmp_path, filename=model.__class__.__name__, save_on_exception=True, every_n_epochs=5 - ) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback], - limit_train_batches=2, - max_epochs=4, - logger=False, - enable_progress_bar=False, - ) +class TroubledModelOnValidationEpochStart(BoringModel): + def on_validation_epoch_start(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) - checkpoint_path = tmp_path / f"exception-{model.__class__.__name__}.ckpt" +class TroubledModelOnValidationBatchStart(BoringModel): + def on_validation_batch_start(self, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") - assert os.path.isfile(checkpoint_path) - checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) - assert checkpoint["state_dict"] is not None - assert checkpoint["state_dict"] != {} +class TroubledModelInValidationStep(BoringModel): + def validation_step(self, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") -def test_model_checkpoint_on_exception_in_other_callbacks(tmp_path): - """Test that an checkpoint is saved when an exception is raised in an other callback.""" - class TroubleMakerOnTrainBatchStart(Callback): - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - if batch_idx == 1: - raise RuntimeError("Trouble!") +class TroubledModelOnValidationBatchEnd(BoringModel): + def on_validation_batch_end(self, outputs, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") - class TroubleMakerOnTrainBatchEnd(Callback): - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if batch_idx == 1: - raise RuntimeError("Trouble!") - class TroubleMakerOnTrainEpochStart(Callback): - def on_train_epoch_start(self, trainer, pl_module): - if trainer.current_epoch == 1: - raise RuntimeError("Trouble!") +class 
TroubledModelOnValidationEpochEnd(BoringModel): + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") - class TroubleMakerOnTrainEpochEnd(Callback): - def on_train_epoch_end(self, trainer, pl_module): - if trainer.current_epoch == 1: - raise RuntimeError("Trouble!") - class TroubleMakerOnTrainEnd(Callback): - def on_train_end(self, trainer, pl_module): +class TroubledModelOnValidationEnd(BoringModel): + def on_validation_end(self): + if not self.trainer.sanity_checking: raise RuntimeError("Trouble!") - class TroubleMakerOnValidationBatchStart(Callback): - def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") - class TroubleMakerOnValidationBatchEnd(Callback): - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - if not trainer.sanity_checking and batch_idx == 1: - raise RuntimeError("Trouble!") +class TroubledModelOnFitEnd(BoringModel): + def on_fit_end(self): + raise RuntimeError("Trouble!") - class TroubleMakerOnValidationEpochStart(Callback): - def on_validation_epoch_start(self, trainer, pl_module): - if not trainer.sanity_checking and trainer.current_epoch == 1: - raise RuntimeError("Trouble!") - class TroubleMakerOnValidationEpochEnd(Callback): - def on_validation_epoch_end(self, trainer, pl_module): - if not trainer.sanity_checking and trainer.current_epoch == 1: - raise RuntimeError("Trouble!") +@pytest.mark.parametrize( + "TroubledModel", + [ + TroubledModelOnTrainEpochStart, + TroubledModelOnTrainBatchStart, + TroubledModelInTrainingStep, + TroubledModelOnBeforeZeroGrad, + TroubledModelOnBeforeBackward, + TroubledModelOnAfterBackward, + TroubledModelOnBeforeOptimizerStep, + TroubledModelOnTrainBatchEnd, + TroubledModelOnTrainEpochEnd, + TroubledModelOnTrainEnd, + TroubledModelOnValidationStart, + TroubledModelOnValidationEpochStart, + TroubledModelOnValidationBatchStart, + TroubledModelInValidationStep, + TroubledModelOnValidationBatchEnd, + TroubledModelOnValidationEpochEnd, + TroubledModelOnValidationEnd, + TroubledModelOnFitEnd, + ], +) +def test_model_checkpoint_on_exception_parametrized(tmp_path, TroubledModel): + """Test that the checkpoint is saved when an exception is raised in a lightning module.""" + model = TroubledModel() - class TroubleMakerOnValidationStart(Callback): - def on_validation_start(self, trainer, pl_module): - if not trainer.sanity_checking: - raise RuntimeError("Trouble!") + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, filename="exception", save_on_exception=True, every_n_epochs=7 + ) + + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback], + limit_train_batches=2, + max_epochs=4, + logger=False, + enable_progress_bar=False, + ) + + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + + checkpoint_path = tmp_path / "exception.ckpt" + + assert os.path.isfile(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + assert checkpoint["state_dict"] is not None + assert checkpoint["state_dict"] != {} - class TroubleMakerOnValidationEnd(Callback): - def on_validation_end(self, trainer, pl_module): - if not trainer.sanity_checking: - raise RuntimeError("Trouble!") - class TroubleMakerOnFitEnd(Callback): - def on_fit_end(self, trainer, pl_module): +class TroubledCallbackOnTrainBatchStart(Callback): + def on_train_batch_start(self, 
trainer, pl_module, batch, batch_idx): + if batch_idx == 1: raise RuntimeError("Trouble!") - troubled_callbacks = [ - TroubleMakerOnTrainBatchStart(), - TroubleMakerOnTrainBatchEnd(), - TroubleMakerOnTrainEpochStart(), - TroubleMakerOnTrainEpochEnd(), - TroubleMakerOnTrainEnd(), - TroubleMakerOnValidationBatchStart(), - TroubleMakerOnValidationBatchEnd(), - TroubleMakerOnValidationEpochStart(), - TroubleMakerOnValidationEpochEnd(), - TroubleMakerOnValidationStart(), - TroubleMakerOnValidationEnd(), - TroubleMakerOnFitEnd(), - ] - for troubled_callback in troubled_callbacks: - model = BoringModel() - checkpoint_callback = ModelCheckpoint( - dirpath=tmp_path, filename=troubled_callback.__class__.__name__, save_on_exception=True, every_n_epochs=5 - ) - trainer = Trainer( - default_root_dir=tmp_path, - callbacks=[checkpoint_callback, troubled_callback], - max_epochs=4, - limit_train_batches=2, - logger=False, - enable_progress_bar=False, - ) - with pytest.raises(RuntimeError, match="Trouble!"): - trainer.fit(model) +class TroubledCallbackOnTrainBatchEnd(Callback): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnTrainEpochStart(Callback): + def on_train_epoch_start(self, trainer, pl_module): + if trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnTrainEpochEnd(Callback): + def on_train_epoch_end(self, trainer, pl_module): + if trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnTrainEnd(Callback): + def on_train_end(self, trainer, pl_module): + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnValidationBatchStart(Callback): + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnValidationBatchEnd(Callback): + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnValidationEpochStart(Callback): + def on_validation_epoch_start(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnValidationEpochEnd(Callback): + def on_validation_epoch_end(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnValidationStart(Callback): + def on_validation_start(self, trainer, pl_module): + if not trainer.sanity_checking: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnValidationEnd(Callback): + def on_validation_end(self, trainer, pl_module): + if not trainer.sanity_checking: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnFitEnd(Callback): + def on_fit_end(self, trainer, pl_module): + raise RuntimeError("Trouble!") + + +@pytest.mark.parametrize( + "TroubledCallback", + [ + TroubledCallbackOnTrainBatchStart, + TroubledCallbackOnTrainBatchEnd, + TroubledCallbackOnTrainEpochStart, + TroubledCallbackOnTrainEpochEnd, + TroubledCallbackOnTrainEnd, + TroubledCallbackOnValidationBatchStart, + TroubledCallbackOnValidationBatchEnd, + TroubledCallbackOnValidationEpochStart, + TroubledCallbackOnValidationEpochEnd, + TroubledCallbackOnValidationStart, + TroubledCallbackOnValidationEnd, + TroubledCallbackOnFitEnd, + ], +) +def 
test_model_checkpoint_on_exception_in_other_callbacks(tmp_path, TroubledCallback): + """Test that an checkpoint is saved when an exception is raised in an other callback.""" + + model = BoringModel() + troubled_callback = TroubledCallback() + + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, filename="exception", save_on_exception=True, every_n_epochs=7 + ) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint_callback, troubled_callback], + max_epochs=4, + limit_train_batches=2, + logger=False, + enable_progress_bar=False, + ) + + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) - checkpoint_path = tmp_path / f"exception-{troubled_callback.__class__.__name__}.ckpt" + checkpoint_path = tmp_path / "exception.ckpt" - assert os.path.isfile(checkpoint_path) - checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) - assert checkpoint["state_dict"] is not None - assert checkpoint["state_dict"] != {} + assert os.path.isfile(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + assert checkpoint["state_dict"] is not None + assert checkpoint["state_dict"] != {} @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time") From c921875d93f12c10b17df78b684ec02a6029b417 Mon Sep 17 00:00:00 2001 From: vsey Date: Mon, 23 Jun 2025 23:20:20 +0200 Subject: [PATCH 31/37] add missing callback hooks for Test Troubled Callback and order them according to documentation --- .../checkpointing/test_model_checkpoint.py | 71 +++++++++++++------ 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index f154b8a7e6b29..43db0788c4221 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -994,6 +994,11 @@ def test_model_checkpoint_on_exception_parametrized(tmp_path, TroubledModel): assert checkpoint["state_dict"] != {} +class TroubledCallbackOnFitEnd(Callback): + def on_fit_end(self, trainer, pl_module): + raise RuntimeError("Trouble!") + + class TroubledCallbackOnTrainBatchStart(Callback): def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if batch_idx == 1: @@ -1018,9 +1023,16 @@ def on_train_epoch_end(self, trainer, pl_module): raise RuntimeError("Trouble!") -class TroubledCallbackOnTrainEnd(Callback): - def on_train_end(self, trainer, pl_module): - raise RuntimeError("Trouble!") +class TroubledCallbackOnValidationEpochStart(Callback): + def on_validation_epoch_start(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledCallbackOnValidationEpochEnd(Callback): + def on_validation_epoch_end(self, trainer, pl_module): + if not trainer.sanity_checking and trainer.current_epoch == 1: + raise RuntimeError("Trouble!") class TroubledCallbackOnValidationBatchStart(Callback): @@ -1035,16 +1047,9 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) raise RuntimeError("Trouble!") -class TroubledCallbackOnValidationEpochStart(Callback): - def on_validation_epoch_start(self, trainer, pl_module): - if not trainer.sanity_checking and trainer.current_epoch == 1: - raise RuntimeError("Trouble!") - - -class TroubledCallbackOnValidationEpochEnd(Callback): - def on_validation_epoch_end(self, trainer, pl_module): - if not trainer.sanity_checking and 
trainer.current_epoch == 1: - raise RuntimeError("Trouble!") +class TroubledCallbackOnTrainEnd(Callback): + def on_train_end(self, trainer, pl_module): + raise RuntimeError("Trouble!") class TroubledCallbackOnValidationStart(Callback): @@ -1059,26 +1064,52 @@ def on_validation_end(self, trainer, pl_module): raise RuntimeError("Trouble!") -class TroubledCallbackOnFitEnd(Callback): - def on_fit_end(self, trainer, pl_module): - raise RuntimeError("Trouble!") +class TroubleCallbackOnBeforeBackward(Callback): + def on_before_backward(self, trainer, pl_module, loss): + if trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubleCallbackOnAfterBackward(Callback): + def on_after_backward(self, trainer, pl_module): + if trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubleCallbackOnBeforeOptimizerStep(Callback): + def on_before_optimizer_step(self, trainer, pl_module, optimizer): + if trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubleCallbackOnBeforeZeroGrad(Callback): + def on_before_zero_grad(self, trainer, pl_module, optimizer): + if trainer.current_epoch == 1: + raise RuntimeError("Trouble!") + + +#### @pytest.mark.parametrize( "TroubledCallback", [ + TroubledCallbackOnFitEnd, TroubledCallbackOnTrainBatchStart, TroubledCallbackOnTrainBatchEnd, TroubledCallbackOnTrainEpochStart, TroubledCallbackOnTrainEpochEnd, - TroubledCallbackOnTrainEnd, - TroubledCallbackOnValidationBatchStart, - TroubledCallbackOnValidationBatchEnd, TroubledCallbackOnValidationEpochStart, TroubledCallbackOnValidationEpochEnd, + TroubledCallbackOnValidationBatchStart, + TroubledCallbackOnValidationBatchEnd, + TroubledCallbackOnTrainEnd, TroubledCallbackOnValidationStart, TroubledCallbackOnValidationEnd, - TroubledCallbackOnFitEnd, + TroubleCallbackOnBeforeBackward, + TroubleCallbackOnAfterBackward, + TroubleCallbackOnBeforeOptimizerStep, + TroubleCallbackOnBeforeZeroGrad, ], ) def test_model_checkpoint_on_exception_in_other_callbacks(tmp_path, TroubledCallback): From fd3de65abcb286dc103342c4cd6a31d655c6a837 Mon Sep 17 00:00:00 2001 From: vsey Date: Mon, 23 Jun 2025 23:43:56 +0200 Subject: [PATCH 32/37] add missing test hooks of lighning module to test save chekpoint on exception --- .../checkpointing/test_model_checkpoint.py | 141 ++++++++++++------ 1 file changed, 92 insertions(+), 49 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 43db0788c4221..9eeb3381d75e6 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -837,26 +837,20 @@ def on_train_batch_start(self, batch, batch_idx) -> None: assert os.path.isfile(tmp_path / "already_saved.ckpt") -class TroubledModelOnTrainEpochStart(BoringModel): - def on_train_epoch_start(self): - if self.current_epoch == 1: - raise RuntimeError("Trouble!") - - -class TroubledModelOnTrainBatchStart(BoringModel): - def on_train_batch_start(self, batch, batch_idx): +class TroubledModelInTrainingStep(BoringModel): + def training_step(self, batch, batch_idx): if batch_idx == 1: raise RuntimeError("Trouble!") -class TroubledModelInTrainingStep(BoringModel): - def training_step(self, batch, batch_idx): - if batch_idx == 1: +class TroubledModelInValidationStep(BoringModel): + def validation_step(self, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: raise RuntimeError("Trouble!") -class 
TroubledModelOnBeforeZeroGrad(BoringModel): - def on_before_zero_grad(self, optimizer): +class TroubledModelBackward(BoringModel): + def backward(self, loss): if self.current_epoch == 1: raise RuntimeError("Trouble!") @@ -873,22 +867,15 @@ def on_after_backward(self): raise RuntimeError("Trouble!") -class TroubledModelOnBeforeOptimizerStep(BoringModel): - def on_before_optimizer_step(self, optimizer): +class TroubledModelOnBeforeZeroGrad(BoringModel): + def on_before_zero_grad(self, optimizer): if self.current_epoch == 1: raise RuntimeError("Trouble!") -class TroubledModelOnTrainBatchEnd(BoringModel): - def on_train_batch_end(self, outputs, batch, batch_idx): - if batch_idx == 1: - raise RuntimeError("Trouble!") - - -class TroubledModelOnTrainEpochEnd(BoringModel): - def on_train_epoch_end(self): - if self.current_epoch == 1: - raise RuntimeError("Trouble!") +class TroubledModelOnFitEnd(BoringModel): + def on_fit_end(self): + raise RuntimeError("Trouble!") class TroubledModelOnTrainEnd(BoringModel): @@ -902,20 +889,38 @@ def on_validation_start(self): raise RuntimeError("Trouble!") -class TroubledModelOnValidationEpochStart(BoringModel): - def on_validation_epoch_start(self): - if not self.trainer.sanity_checking and self.current_epoch == 1: +class TroubledModelOnValidationEnd(BoringModel): + def on_validation_end(self): + if not self.trainer.sanity_checking: raise RuntimeError("Trouble!") -class TroubledModelOnValidationBatchStart(BoringModel): - def on_validation_batch_start(self, batch, batch_idx): - if not self.trainer.sanity_checking and batch_idx == 1: +class TroubledModelOnTrainBatchStart(BoringModel): + def on_train_batch_start(self, batch, batch_idx): + if batch_idx == 1: raise RuntimeError("Trouble!") -class TroubledModelInValidationStep(BoringModel): - def validation_step(self, batch, batch_idx): +class TroubledModelOnTrainBatchEnd(BoringModel): + def on_train_batch_end(self, outputs, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnTrainEpochStart(BoringModel): + def on_train_epoch_start(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnTrainEpochEnd(BoringModel): + def on_train_epoch_end(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationBatchStart(BoringModel): + def on_validation_batch_start(self, batch, batch_idx): if not self.trainer.sanity_checking and batch_idx == 1: raise RuntimeError("Trouble!") @@ -926,44 +931,82 @@ def on_validation_batch_end(self, outputs, batch, batch_idx): raise RuntimeError("Trouble!") +class TroubledModelOnValidationEpochStart(BoringModel): + def on_validation_epoch_start(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") + + class TroubledModelOnValidationEpochEnd(BoringModel): def on_validation_epoch_end(self): if not self.trainer.sanity_checking and self.current_epoch == 1: raise RuntimeError("Trouble!") -class TroubledModelOnValidationEnd(BoringModel): - def on_validation_end(self): - if not self.trainer.sanity_checking: +class TroubledModelOnValidationModelEval(BoringModel): + def on_validation_model_eval(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: raise RuntimeError("Trouble!") -class TroubledModelOnFitEnd(BoringModel): - def on_fit_end(self): - raise RuntimeError("Trouble!") +class TroubledModelOnValidationModelTrain(BoringModel): + def on_validation_model_train(self): + if not self.trainer.sanity_checking 
and self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnBeforeOptimizerStep(BoringModel): + def on_before_optimizer_step(self, optimizer): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelConfigureGradienClipping(BoringModel): + def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOptimizerStep(BoringModel): + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None): + optimizer.step(closure=optimizer_closure) + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOptimizerZeroGrad(BoringModel): + def optimizer_zero_grad(self, epoch, batch_idx, optimizer): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") @pytest.mark.parametrize( "TroubledModel", [ - TroubledModelOnTrainEpochStart, - TroubledModelOnTrainBatchStart, TroubledModelInTrainingStep, - TroubledModelOnBeforeZeroGrad, + TroubledModelInValidationStep, + TroubledModelBackward, TroubledModelOnBeforeBackward, TroubledModelOnAfterBackward, - TroubledModelOnBeforeOptimizerStep, - TroubledModelOnTrainBatchEnd, - TroubledModelOnTrainEpochEnd, + TroubledModelOnBeforeZeroGrad, + TroubledModelOnFitEnd, TroubledModelOnTrainEnd, TroubledModelOnValidationStart, - TroubledModelOnValidationEpochStart, + TroubledModelOnValidationEnd, + TroubledModelOnTrainBatchStart, + TroubledModelOnTrainBatchEnd, + TroubledModelOnTrainEpochStart, + TroubledModelOnTrainEpochEnd, TroubledModelOnValidationBatchStart, - TroubledModelInValidationStep, TroubledModelOnValidationBatchEnd, + TroubledModelOnValidationEpochStart, TroubledModelOnValidationEpochEnd, - TroubledModelOnValidationEnd, - TroubledModelOnFitEnd, + TroubledModelOnValidationModelEval, + TroubledModelOnValidationModelTrain, + TroubledModelOnBeforeOptimizerStep, + TroubledModelConfigureGradienClipping, + TroubledModelOptimizerStep, + TroubledModelOptimizerZeroGrad, ], ) def test_model_checkpoint_on_exception_parametrized(tmp_path, TroubledModel): From 7dfdbc5addc17bfd883b7ca9a6b257c0ccb08e5a Mon Sep 17 00:00:00 2001 From: vsey Date: Tue, 24 Jun 2025 01:39:46 +0200 Subject: [PATCH 33/37] change datatype of on_exception hook from Exception to BaseException in ModelCheckpoint --- src/lightning/pytorch/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 98ec77aa3bad3..27220ed4ba7f5 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -342,7 +342,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_last_checkpoint(trainer, monitor_candidates) @override - def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: Exception) -> None: + def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: """Save a checkpoint when an exception is raised.""" if not self._should_save_on_exception(trainer): return From 048ad8600e1760158d21d2db572908381181aa0c Mon Sep 17 00:00:00 2001 From: vsey Date: Wed, 25 Jun 2025 19:39:25 +0200 Subject: [PATCH 34/37] change default prefix from empty string to Optinal to better convey meaning of unused empty variable in function signutare --- 
src/lightning/pytorch/callbacks/model_checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 27220ed4ba7f5..5c2d77af57904 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -563,7 +563,7 @@ def _format_checkpoint_name(
         self,
         filename: Optional[str],
         metrics: dict[str, Tensor],
-        prefix: str = "",
+        prefix: Optional[str] = None,
         auto_insert_metric_name: bool = True,
     ) -> str:
         if not filename:
@@ -590,7 +590,7 @@ def _format_checkpoint_name(
                 metrics[name] = torch.tensor(0)
         filename = filename.format(metrics)

-        if prefix:
+        if prefix is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join([prefix, filename])

         return filename

From f0cf90b5cd8e2da406b90067d300e9dc786d3d27 Mon Sep 17 00:00:00 2001
From: vsey
Date: Thu, 10 Jul 2025 11:32:56 +0200
Subject: [PATCH 35/37] Added save_on_exception option for ModelCheckpoint to
 changelog

---
 src/lightning/pytorch/CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 5b364ac1c7a3e..60598d2a58dd7 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
+- Added a `save_on_exception` option to `ModelCheckpoint` Callback

 ### Changed

From 921c9ee104dc247ed5b95c616d1257c56b8830c5 Mon Sep 17 00:00:00 2001
From: vsey
Date: Thu, 10 Jul 2025 11:35:46 +0200
Subject: [PATCH 36/37] Fix tense of changelog entry to be in line with the
 rest of the changelog

---
 src/lightning/pytorch/CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 60598d2a58dd7..c583508f52169 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

-- Added a `save_on_exception` option to `ModelCheckpoint` Callback
+- Add a `save_on_exception` option to `ModelCheckpoint` Callback

 ### Changed

From e388ee7a9aadaf95efb09b5995f3c93214732acb Mon Sep 17 00:00:00 2001
From: vsey
Date: Thu, 10 Jul 2025 11:39:52 +0200
Subject: [PATCH 37/37] change changelog entry tense back

---
 src/lightning/pytorch/CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index c583508f52169..f5ea280758092 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

-- Add a `save_on_exception` option to `ModelCheckpoint` Callback
+- Added `save_on_exception` option to `ModelCheckpoint` Callback

 ### Changed
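
For reference, the behaviour added by this series can be exercised as in the following minimal sketch. It is illustrative only and not part of the patches: the output directory, filename template, crash trigger, and trainer settings below are assumptions chosen for the example.

import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


class CrashingModel(BoringModel):
    # Simulates a failure mid-training, similar to the "Troubled" models in the tests.
    def on_train_batch_start(self, batch, batch_idx):
        if batch_idx == 3:
            raise RuntimeError("simulated failure")


# With save_on_exception=True (the default), ModelCheckpoint.on_exception() writes a
# checkpoint before the error propagates out of Trainer.fit().
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",      # illustrative output directory
    filename="{epoch}-{step}",   # illustrative filename template
    save_on_exception=True,
)

trainer = Trainer(max_epochs=2, limit_train_batches=5, callbacks=[checkpoint_callback], logger=False)

try:
    trainer.fit(CrashingModel())
except RuntimeError:
    # The checkpoint written on the exception follows the normal filename template;
    # it loads like any other Lightning checkpoint (weights_only=False mirrors the tests).
    ckpts = [f for f in os.listdir("checkpoints/") if f.endswith(".ckpt")]
    print("checkpoints written before the crash:", ckpts)
    if ckpts:
        state = torch.load(os.path.join("checkpoints/", ckpts[0]), map_location="cpu", weights_only=False)
        print(sorted(state.keys()))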