Commit 525d9a8
Merge branch 'master' into weights-only-compatibility
2 parents 8e0f61e + 2c74bee commit 525d9a8

4 files changed: +69 −2 lines changed

requirements/fabric/test.txt

Lines changed: 2 additions & 1 deletion
```diff
@@ -5,5 +5,6 @@ pytest-cov ==6.2.1
 pytest-timeout ==2.4.0
 pytest-rerunfailures ==15.1
 pytest-random-order ==1.2.0
-click ==8.1.8
+click ==8.1.8; python_version < "3.11"
+click ==8.2.1; python_version > "3.10"
 tensorboardX >=2.2, <2.7.0  # min version is set by torch.onnx missing attribute
```
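The split `click` pins rely on pip environment markers, so which version gets installed depends on the interpreter running the install. A minimal sketch of how such a marker is evaluated, using the `packaging` library (not part of this commit, shown only for illustration):

```python
# Sketch only: evaluates an environment marker like the one added above.
# `packaging` is the library pip itself uses for marker evaluation.
from packaging.markers import Marker

marker = Marker('python_version < "3.11"')

# Evaluated against the running interpreter: True on Python 3.10, False on 3.11+.
print(marker.evaluate())

# Evaluated against an explicit environment, e.g. to preview the 3.11 behaviour.
print(marker.evaluate({"python_version": "3.11"}))  # False -> the click ==8.2.1 pin applies instead
```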

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed `AsyncCheckpointIO` snapshots tensors to avoid race with parameter mutation ([#21079](https://github.com/Lightning-AI/pytorch-lightning/pull/21079))


 - Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
```

src/lightning/pytorch/plugins/io/async_plugin.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -15,6 +15,8 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Optional

+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
 from typing_extensions import override

 from lightning.fabric.plugins import CheckpointIO
@@ -41,6 +43,17 @@ def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
     def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
         """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""

+        # snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
+        def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
+            # detach to avoid autograd history and clone to take a point-in-time copy
+            return t.detach().clone()
+
+        # rebuild args/kwargs with a cloned checkpoint (supports positional or kw form)
+        if "checkpoint" in kwargs:
+            kwargs = {**kwargs, "checkpoint": apply_to_collection(kwargs["checkpoint"], torch.Tensor, _clone_tensor)}
+        elif len(args) >= 1:
+            args = (apply_to_collection(args[0], torch.Tensor, _clone_tensor), *args[1:])
+
         def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
             try:
                 assert self.checkpoint_io is not None
```
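The added code clones every tensor in the checkpoint on the calling thread before the save is handed off to the executor, so later in-place parameter updates cannot leak into the background write. A minimal, standalone sketch of the same technique (the payload and names below are illustrative, not taken from the diff):

```python
import torch
from lightning_utilities.core.apply_func import apply_to_collection

# A nested, checkpoint-like payload; apply_to_collection walks dicts/lists/tuples
# and applies the function to every value of the given dtype (here torch.Tensor).
checkpoint = {"state_dict": {"w": torch.tensor([0.0])}, "epoch": 3}
snapshot = apply_to_collection(checkpoint, torch.Tensor, lambda t: t.detach().clone())

# Simulate the training thread mutating parameters after the save was scheduled.
checkpoint["state_dict"]["w"].add_(1.0)

print(checkpoint["state_dict"]["w"])  # tensor([1.])
print(snapshot["state_dict"]["w"])    # tensor([0.]) -- the point-in-time copy is unaffected
```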
Lines changed: 53 additions & 0 deletions
```diff
@@ -0,0 +1,53 @@
+import time
+from typing import Any, Optional
+
+import pytest
+import torch
+
+from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
+from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
+
+
+class _CaptureCheckpointIO(CheckpointIO):
+    def __init__(self) -> None:
+        self.saved: Optional[dict[str, Any]] = None
+
+    def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
+        # Simulate some delay to increase race window
+        time.sleep(0.05)
+        # Store the received checkpoint object (not a deep copy) to inspect tensor values
+        self.saved = checkpoint
+
+    def load_checkpoint(self, path: str, map_location: Optional[Any] = None) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def remove_checkpoint(self, path: str) -> None:
+        pass
+
+
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
+def test_async_checkpoint_should_snapshot_values_before_mutation():
+    base = _CaptureCheckpointIO()
+    async_io = AsyncCheckpointIO(checkpoint_io=base)
+
+    # a tensor that we will mutate after scheduling the save
+    t = torch.tensor([0.0])
+    ckpt = {"w": t}
+
+    # schedule async save
+    async_io.save_checkpoint(ckpt, path="unused")
+
+    # mutate immediately afterward to mimic training thread stepping params
+    t.add_(1.0)
+
+    # ensure background thread finished
+    async_io.teardown()
+
+    assert base.saved is not None, "Async save did not run"
+
+    # EXPECTATION: AsyncCheckpointIO should have captured value 0.0 (pre-mutation)
+    # CURRENT BEHAVIOR (bug): it captures 1.0 because the dict holds references
+    assert torch.allclose(base.saved["w"], torch.tensor([0.0])), (
+        "AsyncCheckpointIO must snapshot the checkpoint (clone tensors) on the main thread "
+        "to avoid races with parameter mutation; got mutated value instead"
+    )
```
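As a usage note, the test's expectation only holds if the tensors themselves are cloned; a shallow copy of the checkpoint dict would still share the underlying tensor storage with the training loop. A small illustration of that distinction (not part of the commit):

```python
import torch

ckpt = {"w": torch.tensor([0.0])}

shallow = dict(ckpt)                        # new dict, but the very same tensor object inside
cloned = {"w": ckpt["w"].detach().clone()}  # point-in-time copy of the tensor data

ckpt["w"].add_(1.0)  # in-place update, as an optimizer step would do

print(shallow["w"])  # tensor([1.]) -- shares storage, sees the mutation
print(cloned["w"])   # tensor([0.]) -- snapshot preserved
```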
