Skip to content

Make asyncio checkpointing work if validate/fit is called more than once #20952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `AsyncCheckpointIO` threadpool exception if calling fit or validate more than one.


---
Expand Down
23 changes: 17 additions & 6 deletions src/lightning/pytorch/plugins/io/async_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,23 @@ class AsyncCheckpointIO(_WrappingCheckpointIO):

def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
super().__init__(checkpoint_io)

self._executor = ThreadPoolExecutor(max_workers=1)
self._executor: Optional[ThreadPoolExecutor] = None
self._error: Optional[BaseException] = None

# CheckpointIO doesn't have a setup method so we have to do something like.
# We can't do setup in __init__ because if train or validate is called more than once the
# teardown method deletes the executor.
def _ensure_setup(self) -> None:
if self._executor is None:
self._executor = ThreadPoolExecutor(max_workers=1)
self._error: Optional[BaseException] = None

@override
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
"""Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""

self._ensure_setup()

def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
try:
assert self.checkpoint_io is not None
Expand All @@ -58,8 +67,10 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
@override
def teardown(self) -> None:
"""This method is called to close the threads."""
self._executor.shutdown(wait=True)
if self._executor is not None:
self._executor.shutdown(wait=True)
self._executor = None

# if an error was raised anytime in any of the `executor.submit` calls
if self._error:
raise self._error
# if an error was raised anytime in any of the `executor.submit` calls
if self._error:
raise self._error
4 changes: 4 additions & 0 deletions tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def on_fit_start(self):
enable_progress_bar=False,
enable_model_summary=False,
)

# We add a validate step to test that async works when fit or validate is called multiple times.
trainer.validate(model)

trainer.fit(model)

assert checkpoint_plugin.save_checkpoint.call_count == 3
Expand Down
Loading