
Commit 2c74bee

littlebullGit and Borda authored
Fix: AsyncCheckpointIO snapshots tensors to avoid race with parameter mutation (#21079)
* Fix: AsyncCheckpointIO snapshots tensors to avoid race with parameter mutation

  Summary
  - Root cause: the background thread serialized live tensor references; the training thread mutated those tensors after scheduling the async save, leading to mixed-step checkpoints.
  - Fix: snapshot all tensors on the main thread before submitting the async save, using `apply_to_collection(..., torch.Tensor, lambda t: t.detach().clone())`.

  Implementation
  - Reproduce the issue in a unit test.
  - Clone all tensors in the checkpoint payload on the caller thread to take a point-in-time snapshot.
  - Support both positional and keyword `checkpoint` parameters.
  - Preserve non-tensor values; handle nested containers.
  - Continue to surface background exceptions on teardown.

* chlog

---------

Co-authored-by: Jirka B <[email protected]>
1 parent 48fa086 commit 2c74bee
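
Before the diff itself, here is a minimal standalone sketch of the snapshotting idea described in the commit message: `apply_to_collection` walks nested containers and applies `detach().clone()` to every tensor, so the payload handed to the background thread no longer aliases the live parameters. The `_snapshot` helper and the example payload below are illustrative only, not part of the patch.

```python
import torch
from lightning_utilities.core.apply_func import apply_to_collection


def _snapshot(checkpoint: dict) -> dict:
    # Point-in-time copy: detach each tensor from autograd and clone its storage,
    # leaving non-tensor values and the container structure untouched.
    return apply_to_collection(checkpoint, torch.Tensor, lambda t: t.detach().clone())


w = torch.tensor([0.0])
ckpt = {"state_dict": {"w": w}, "epoch": 3}

snap = _snapshot(ckpt)
w.add_(1.0)  # a later in-place update (e.g. an optimizer step) ...

print(snap["state_dict"]["w"])  # ... does not leak into the snapshot: tensor([0.])
```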

File tree: 3 files changed (+67 -1 lines)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed `AsyncCheckpointIO` snapshots tensors to avoid race with parameter mutation ([#21079](https://github.com/Lightning-AI/pytorch-lightning/pull/21079))
 
 
 - Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))

src/lightning/pytorch/plugins/io/async_plugin.py

Lines changed: 13 additions & 0 deletions
@@ -15,6 +15,8 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Optional
 
+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
 from typing_extensions import override
 
 from lightning.fabric.plugins import CheckpointIO
@@ -41,6 +43,17 @@ def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
     def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
         """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
 
+        # snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
+        def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
+            # detach to avoid autograd history and clone to take a point-in-time copy
+            return t.detach().clone()
+
+        # rebuild args/kwargs with a cloned checkpoint (supports positional or kw form)
+        if "checkpoint" in kwargs:
+            kwargs = {**kwargs, "checkpoint": apply_to_collection(kwargs["checkpoint"], torch.Tensor, _clone_tensor)}
+        elif len(args) >= 1:
+            args = (apply_to_collection(args[0], torch.Tensor, _clone_tensor), *args[1:])
+
         def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
             try:
                 assert self.checkpoint_io is not None
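
As a usage note (not part of this diff), the plugin is typically handed to the `Trainer` via `plugins`; below is a minimal sketch, assuming the default Torch-based checkpoint IO underneath.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.io import AsyncCheckpointIO, TorchCheckpointIO

# AsyncCheckpointIO offloads the actual write to a background thread; with this fix,
# the checkpoint tensors are cloned on the caller thread before the save is scheduled,
# so a subsequent optimizer step cannot corrupt the checkpoint being written.
async_io = AsyncCheckpointIO(checkpoint_io=TorchCheckpointIO())
trainer = Trainer(plugins=[async_io])
```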
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import time
+from typing import Any, Optional
+
+import pytest
+import torch
+
+from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
+from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+
+
+class _CaptureCheckpointIO(CheckpointIO):
+    def __init__(self) -> None:
+        self.saved: Optional[dict[str, Any]] = None
+
+    def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
+        # Simulate some delay to increase race window
+        time.sleep(0.05)
+        # Store the received checkpoint object (not a deep copy) to inspect tensor values
+        self.saved = checkpoint
+
+    def load_checkpoint(self, path: str, map_location: Optional[Any] = None) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def remove_checkpoint(self, path: str) -> None:
+        pass
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+def test_async_checkpoint_should_snapshot_values_before_mutation():
+    base = _CaptureCheckpointIO()
+    async_io = AsyncCheckpointIO(checkpoint_io=base)
+
+    # a tensor that we will mutate after scheduling the save
+    t = torch.tensor([0.0])
+    ckpt = {"w": t}
+
+    # schedule async save
+    async_io.save_checkpoint(ckpt, path="unused")
+
+    # mutate immediately afterward to mimic training thread stepping params
+    t.add_(1.0)
+
+    # ensure background thread finished
+    async_io.teardown()
+
+    assert base.saved is not None, "Async save did not run"
+
+    # EXPECTATION: AsyncCheckpointIO should have captured value 0.0 (pre-mutation)
+    # CURRENT BEHAVIOR (bug): it captures 1.0 because the dict holds references
+    assert torch.allclose(base.saved["w"], torch.tensor([0.0])), (
+        "AsyncCheckpointIO must snapshot the checkpoint (clone tensors) on the main thread "
+        "to avoid races with parameter mutation; got mutated value instead"
+    )
