From 21f9f28dad927c1cfeeab34190a6f06729e9b09d Mon Sep 17 00:00:00 2001 From: default Date: Tue, 1 Jul 2025 16:30:52 +0000 Subject: [PATCH] Make asyncio checkpointing work if validate/fit is called more than once. --- src/lightning/pytorch/CHANGELOG.md | 2 +- .../pytorch/plugins/io/async_plugin.py | 23 ++++++++++++++----- .../plugins/test_checkpoint_io_plugin.py | 4 ++++ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 5b364ac1c7a3e..39ef22c9694a4 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed `AsyncCheckpointIO` threadpool exception when calling fit or validate more than once. --- diff --git a/src/lightning/pytorch/plugins/io/async_plugin.py b/src/lightning/pytorch/plugins/io/async_plugin.py index 67c02189c541e..0c1e3e55c03cb 100644 --- a/src/lightning/pytorch/plugins/io/async_plugin.py +++ b/src/lightning/pytorch/plugins/io/async_plugin.py @@ -33,14 +33,23 @@ class AsyncCheckpointIO(_WrappingCheckpointIO): def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None: super().__init__(checkpoint_io) - - self._executor = ThreadPoolExecutor(max_workers=1) + self._executor: Optional[ThreadPoolExecutor] = None self._error: Optional[BaseException] = None + # CheckpointIO doesn't have a setup method, so we have to do something like this. + # We can't do setup in __init__ because if train or validate is called more than once, the + # teardown method deletes the executor. 
+ def _ensure_setup(self) -> None: + if self._executor is None: + self._executor = ThreadPoolExecutor(max_workers=1) + self._error: Optional[BaseException] = None + @override def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``.""" + self._ensure_setup() + def _save_checkpoint(*args: Any, **kwargs: Any) -> None: try: assert self.checkpoint_io is not None @@ -58,8 +67,10 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None: @override def teardown(self) -> None: """This method is called to close the threads.""" - self._executor.shutdown(wait=True) + if self._executor is not None: + self._executor.shutdown(wait=True) + self._executor = None - # if an error was raised anytime in any of the `executor.submit` calls - if self._error: - raise self._error + # if an error was raised anytime in any of the `executor.submit` calls + if self._error: + raise self._error diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 0f62eeae69ef8..f7a76079cfca2 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -127,6 +127,10 @@ def on_fit_start(self): enable_progress_bar=False, enable_model_summary=False, ) + + # We add a validate step to test that async works when fit or validate is called multiple times. + trainer.validate(model) + trainer.fit(model) assert checkpoint_plugin.save_checkpoint.call_count == 3