|
| 1 | +import time |
| 2 | +from typing import Any, Optional |
| 3 | + |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | + |
| 7 | +from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO |
| 8 | +from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO |
| 9 | + |
| 10 | + |
class _CaptureCheckpointIO(CheckpointIO):
    """Test double that records the checkpoint dict it is asked to save.

    Deliberately keeps a *reference* (not a copy) to the checkpoint so the
    test can observe whether the caller snapshotted tensors before handing
    them off.
    """

    def __init__(self) -> None:
        # Most recently "saved" checkpoint; None until save_checkpoint runs.
        self.saved: Optional[dict[str, Any]] = None

    def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
        """Pretend to persist *checkpoint*, stashing a reference for inspection."""
        # Sleep widens the race window so a concurrent in-place mutation
        # on the training thread becomes observable.
        time.sleep(0.05)
        self.saved = checkpoint

    def load_checkpoint(self, path: str, map_location: Optional[Any] = None) -> dict[str, Any]:
        # Loading is irrelevant to this test double.
        raise NotImplementedError

    def remove_checkpoint(self, path: str) -> None:
        # Nothing to remove; saves never touch disk.
        return None
| 26 | + |
| 27 | + |
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_async_checkpoint_should_snapshot_values_before_mutation():
    """AsyncCheckpointIO must clone tensor values on the calling thread.

    Schedules an async save, then immediately mutates a tensor held by the
    checkpoint dict — mimicking an optimizer step racing the background
    writer. The captured checkpoint should still hold the pre-mutation value.
    """
    capture = _CaptureCheckpointIO()
    async_io = AsyncCheckpointIO(checkpoint_io=capture)

    # Tensor the "training thread" will mutate right after scheduling the save.
    weight = torch.tensor([0.0])

    # Kick off the background save, then mutate in place while it may still
    # be running (the capture IO sleeps to widen the race window).
    async_io.save_checkpoint({"w": weight}, path="unused")
    weight.add_(1.0)

    # teardown() joins the background worker, guaranteeing the save finished.
    async_io.teardown()

    assert capture.saved is not None, "Async save did not run"

    # EXPECTATION: the snapshot holds 0.0 (pre-mutation).
    # CURRENT BEHAVIOR (bug): 1.0 is observed because the dict passed to the
    # worker thread still references `weight`.
    assert torch.allclose(capture.saved["w"], torch.tensor([0.0])), (
        "AsyncCheckpointIO must snapshot the checkpoint (clone tensors) on the main thread "
        "to avoid races with parameter mutation; got mutated value instead"
    )