From 074b01ebe96709eed1f63907be7aaed3b34c3a87 Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Thu, 14 Aug 2025 14:22:18 -0400
Subject: [PATCH 01/45] change weights_only default to True

---
 src/lightning/fabric/utilities/cloud_io.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py
index 637dfcd9b1671..16eca9943aac1 100644
--- a/src/lightning/fabric/utilities/cloud_io.py
+++ b/src/lightning/fabric/utilities/cloud_io.py
@@ -34,7 +34,7 @@ def _load(
     path_or_url: Union[IO, _PATH],
     map_location: _MAP_LOCATION_TYPE = None,
-    weights_only: bool = False,
+    weights_only: bool = True,
 ) -> Any:
     """Loads a checkpoint.

@@ -70,7 +70,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
     return fs


-def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None:
+def _atomic_save(checkpoint: dict[str, Any], filepath: _PATH) -> None:
     """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

     Args:

From 65cc1ed9360366259e671eebeb8308f5b19f16e5 Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Thu, 14 Aug 2025 14:54:41 -0400
Subject: [PATCH 02/45] add docs on weights_only arg

---
 src/lightning/fabric/utilities/cloud_io.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py
index 16eca9943aac1..23d45f84dcae1 100644
--- a/src/lightning/fabric/utilities/cloud_io.py
+++ b/src/lightning/fabric/utilities/cloud_io.py
@@ -41,6 +41,11 @@ def _load(
     Args:
         path_or_url: Path or URL of the checkpoint.
         map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
+        weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other primitive
+            types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
+            ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
+            ``weights_only=True``. For more information, please refer to the
+            `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`__.

    """
    if not isinstance(path_or_url, (str, Path)):

From f276114e0c90f7396e4bd5c2b5b724d83cebe00b Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Fri, 15 Aug 2025 18:01:57 -0400
Subject: [PATCH 03/45] add weights_only arg to checkpoint save.
weights_only during test set based on ckpt version --- src/lightning/pytorch/core/saving.py | 8 +++++++- .../trainer/connectors/checkpoint_connector.py | 4 +++- .../checkpointing/test_legacy_checkpoints.py | 15 +++++++++++++-- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 21fd3912f7849..81e766751ea15 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -56,11 +56,17 @@ def _load_from_checkpoint( map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, strict: Optional[bool] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: map_location = map_location or _default_map_location + + if weights_only is None: + log.debug("`weights_only` not specified, defaulting to `True`.") + weights_only = True + with pl_legacy_patch(): - checkpoint = pl_load(checkpoint_path, map_location=map_location) + checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only) # convert legacy checkpoints to the new format checkpoint = _pl_migrate_checkpoint( diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 7f97a2f54bf19..52fb0e3230a82 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -414,7 +414,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. Args: - weights_only: saving model weights only + weights_only: If True, only saves model and loops state_dict objects. If False, + additionally saves callbacks, optimizers, schedulers, and precision plugin states. 
+ Return: structured dictionary: { 'epoch': training epoch diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 006a123356c98..bed597c31b885 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -18,6 +18,7 @@ import pytest import torch +from packaging.version import Version import lightning.pytorch as pl from lightning.pytorch import Callback, Trainer @@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] - model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24) + # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) + if pl_version == "local": + pl_version = pl.__version__ + weights_only = not Version(pl_version) < Version("1.5.0") + + model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only) trainer = Trainer(default_root_dir=tmp_path) dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8) res = trainer.test(model, datamodule=dm) @@ -73,13 +79,18 @@ def test_legacy_ckpt_threading(pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] + # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) + if pl_version == "local": + pl_version = pl.__version__ + weights_only = not Version(pl_version) < Version("1.5.0") + def load_model(): import torch from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(path_ckpt, weights_only=False) + _ = torch.load(path_ckpt, weights_only=weights_only) with patch("sys.path", [PATH_LEGACY] + sys.path): t1 = ThreadExceptionHandler(target=load_model) From 28f53ae9e3f7056e9e5e9945d45a55fc04bb99cc Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Fri, 15 Aug 2025 20:35:25 -0400 Subject: [PATCH 04/45] add weights_only arg to checkpoint_io --- src/lightning/fabric/plugins/io/checkpoint_io.py | 9 ++++++++- src/lightning/fabric/plugins/io/torch_io.py | 4 ++-- tests/legacy/generate_checkpoints.sh | 10 ++++++---- .../tests_pytorch/plugins/test_checkpoint_io_plugin.py | 4 +++- 4 files changed, 19 insertions(+), 8 deletions(-) mode change 100644 => 100755 tests/legacy/generate_checkpoints.sh diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index 3a33dac3335d1..102af722a09f0 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -47,13 +47,20 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio """ @abstractmethod - def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint( + self, path: _PATH, map_location: Optional[Any] = None, weights_only: bool = True + ) -> dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. 
If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
+                ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
+                ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`__.

         Returns: The loaded checkpoint.

diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py
index 90a5f62ba7413..c6ac50b71de61 100644
--- a/src/lightning/fabric/plugins/io/torch_io.py
+++ b/src/lightning/fabric/plugins/io/torch_io.py
@@ -59,7 +59,7 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio

     @override
     def load_checkpoint(
-        self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
+        self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage, weights_only: bool = True
     ) -> dict[str, Any]:
         """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.

@@ -80,7 +80,7 @@ def load_checkpoint(
         if not fs.exists(path):
             raise FileNotFoundError(f"Checkpoint file not found: {path}")

-        return pl_load(path, map_location=map_location)
+        return pl_load(path, map_location=map_location, weights_only=weights_only)

     @override
     def remove_checkpoint(self, path: _PATH) -> None:

diff --git a/tests/legacy/generate_checkpoints.sh b/tests/legacy/generate_checkpoints.sh
old mode 100644
new mode 100755
index 1d083a2a8e052..b342d0801d49e
--- a/tests/legacy/generate_checkpoints.sh
+++ b/tests/legacy/generate_checkpoints.sh
@@ -16,9 +16,9 @@ printf "PYTHONPATH: $PYTHONPATH"
 rm -rf $ENV_PATH

 function create_and_save_checkpoint {
-  python --version
-  python -m pip --version
-  python -m pip list
+  # python --version
+  # python -m pip --version
+  # python -m pip list

   python $LEGACY_FOLDER/simple_classif_training.py $pl_ver

@@ -52,10 +52,12 @@ done
 if [[ -z "$@" ]]; then
   printf "\n\n processing local version\n"
-  python -m pip install \
+  # python -m pip install \
+  uv pip install \
     -r $LEGACY_FOLDER/requirements.txt \
     -r "$(dirname $TESTS_FOLDER)/requirements/pytorch/test.txt" \
     -f https://download.pytorch.org/whl/cpu/torch_stable.html
   pl_ver="local"
+  # pl_ver=$(python -c "import lightning.pytorch as pl; print(pl.__version__)")
   create_and_save_checkpoint
 fi

diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
index 0f62eeae69ef8..de1e20b679a39 100644
--- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
+++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
@@ -32,7 +32,9 @@ class CustomCheckpointIO(CheckpointIO):
     def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         torch.save(checkpoint, path)

-    def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]:
+    def load_checkpoint(
+        self, path: _PATH, storage_options: Optional[Any] = None, weights_only: bool = True
+    ) -> dict[str, Any]:
         return torch.load(path, weights_only=True)

     def remove_checkpoint(self, path: _PATH) -> None:

From b1cfdf1f82a097e987c6f2f0a7c02571014681d9 Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Fri, 15 Aug 2025 20:48:15 -0400
Subject: [PATCH 05/45] woops, reverting changes

---
 tests/legacy/generate_checkpoints.sh | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/tests/legacy/generate_checkpoints.sh b/tests/legacy/generate_checkpoints.sh
index 
b342d0801d49e..1d083a2a8e052 100755
--- a/tests/legacy/generate_checkpoints.sh
+++ b/tests/legacy/generate_checkpoints.sh
@@ -16,9 +16,9 @@ printf "PYTHONPATH: $PYTHONPATH"
 rm -rf $ENV_PATH

 function create_and_save_checkpoint {
-  # python --version
-  # python -m pip --version
-  # python -m pip list
+  python --version
+  python -m pip --version
+  python -m pip list

   python $LEGACY_FOLDER/simple_classif_training.py $pl_ver

@@ -52,12 +52,10 @@ done
 if [[ -z "$@" ]]; then
   printf "\n\n processing local version\n"
-  # python -m pip install \
-  uv pip install \
+  python -m pip install \
     -r $LEGACY_FOLDER/requirements.txt \
     -r "$(dirname $TESTS_FOLDER)/requirements/pytorch/test.txt" \
     -f https://download.pytorch.org/whl/cpu/torch_stable.html
   pl_ver="local"
-  # pl_ver=$(python -c "import lightning.pytorch as pl; print(pl.__version__)")
   create_and_save_checkpoint
 fi

From 4d96a78bbab93bd863502b10bf6e0f69654619dd Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Fri, 15 Aug 2025 20:49:17 -0400
Subject: [PATCH 06/45] permissions too

---
 tests/legacy/generate_checkpoints.sh | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 mode change 100755 => 100644 tests/legacy/generate_checkpoints.sh

diff --git a/tests/legacy/generate_checkpoints.sh b/tests/legacy/generate_checkpoints.sh
old mode 100755
new mode 100644

From 4c39c30a3a27563fee9653251cd7363d800b702f Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Fri, 15 Aug 2025 23:24:24 -0400
Subject: [PATCH 07/45] fix link

---
 src/lightning/fabric/plugins/io/checkpoint_io.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py
index 102af722a09f0..ee625652fdf6d 100644
--- a/src/lightning/fabric/plugins/io/checkpoint_io.py
+++ b/src/lightning/fabric/plugins/io/checkpoint_io.py
@@ -60,7 +60,7 @@ def load_checkpoint(
             primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
             ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
             ``weights_only=True``. For more information, please refer to the
-            `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`__.
+            `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`_.

         Returns: The loaded checkpoint.

From 861d7e0c103b4040654448d23f59f3a3d0f93537 Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Sat, 16 Aug 2025 13:04:35 -0400
Subject: [PATCH 08/45] fix another link

---
 src/lightning/fabric/utilities/cloud_io.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py
index 23d45f84dcae1..d9ea6f98b24a1 100644
--- a/src/lightning/fabric/utilities/cloud_io.py
+++ b/src/lightning/fabric/utilities/cloud_io.py
@@ -45,7 +45,7 @@ def _load(
             types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
             ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
             ``weights_only=True``. For more information, please refer to the
-            `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`__.
+            `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`_.
""" if not isinstance(path_or_url, (str, Path)): From 12bd0d6cc39ca44d637da89508db79f9118f6ec0 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Sun, 17 Aug 2025 18:46:20 -0400 Subject: [PATCH 09/45] datamodule weights_only args --- src/lightning/pytorch/core/datamodule.py | 7 +++++++ src/lightning/pytorch/core/module.py | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index ff84c2fd8b199..07ec02ef87bd8 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -177,6 +177,7 @@ def load_from_checkpoint( checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the @@ -206,6 +207,11 @@ def load_from_checkpoint( If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace` and ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat ``hparams`` as :class:`~dict`. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. \**kwargs: Any extra keyword args needed to init the datamodule. Can also be used to override saved hyperparameter values. @@ -242,6 +248,7 @@ def load_from_checkpoint( map_location=map_location, hparams_file=hparams_file, strict=None, + weights_only=weights_only, **kwargs, ) return cast(Self, loaded) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 85cb032d3f80b..37b07f025f8e9 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1724,6 +1724,11 @@ def load_from_checkpoint( strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is set, in which case it defaults to the value of ``LightningModule.strict_loading``. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. 
From 5eacb6ea3037476559c482c1f76c79fdc6dd9f85 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Sun, 17 Aug 2025 18:46:54 -0400 Subject: [PATCH 10/45] wip: try safe_globals context manager for tests --- tests/tests_pytorch/models/test_hparams.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 575bcadadc404..2fb555022879a 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -109,6 +109,9 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr # make sure the raw checkpoint saved the properties raw_checkpoint_path = _raw_checkpoint_path(trainer) raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) + # with torch.serialization.safe_globals([Container, DictConfig]): + # raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=True) + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 @@ -175,8 +178,10 @@ def test_omega_conf_hparams(tmp_path, cls): assert isinstance(obj.hparams, Container) # run standard test suite + # with torch.serialization.safe_globals([Container, DictConfig]): raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule) obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + assert isinstance(obj2.hparams, Container) # config specific tests From 0430e226e40da63fa0a161cb9e9a858197532a50 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Sun, 17 Aug 2025 21:49:54 -0400 Subject: [PATCH 11/45] add weights_only arg to _run_standard_hparams_test --- tests/tests_pytorch/models/test_hparams.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 2fb555022879a..a173cec4e0394 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -94,7 +94,7 @@ def __init__(self, hparams, *my_args, **my_kwargs): # ------------------------- # STANDARD TESTS # ------------------------- -def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False): +def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only=True): """Tests for the existence of an arg 'test_arg=14'.""" obj = datamodule if issubclass(cls, LightningDataModule) else model @@ -108,22 +108,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr # make sure the raw checkpoint saved the properties raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) - # with torch.serialization.safe_globals([Container, DictConfig]): - # raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=True) + raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 # verify that model loads correctly - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert obj2.hparams.test_arg == 14 assert isinstance(obj2.hparams, hparam_type) if try_overwrite: # verify that we can overwrite the property - obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) + obj3 = 
cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only) assert obj3.hparams.test_arg == 78 return raw_checkpoint_path @@ -178,9 +176,8 @@ def test_omega_conf_hparams(tmp_path, cls): assert isinstance(obj.hparams, Container) # run standard test suite - # with torch.serialization.safe_globals([Container, DictConfig]): - raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule) - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False) assert isinstance(obj2.hparams, Container) From 2abe91598a764692ec203bb40a7808073f3890e9 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Sun, 17 Aug 2025 22:30:19 -0400 Subject: [PATCH 12/45] weights_only=False when adding extra_args --- tests/tests_pytorch/models/test_hparams.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index a173cec4e0394..fa189938759fb 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -369,13 +369,17 @@ class DictConfSubClassBoringModel: ... BoringModelWithMixinAndInit, ], ) -def test_collect_init_arguments(tmp_path, cls): +def test_collect_init_arguments(tmp_path, cls: BoringModel): """Test that the model automatically saves the arguments passed into the constructor.""" extra_args = {} + weights_only = True + if cls is AggSubClassBoringModel: extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss()) + weights_only = False elif cls is DictConfSubClassBoringModel: extra_args.update(dict_conf=OmegaConf.create({"my_param": "anything"})) + weights_only = False model = cls(**extra_args) assert model.hparams.batch_size == 64 @@ -394,12 +398,12 @@ def test_collect_init_arguments(tmp_path, cls): raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) + raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["batch_size"] == 179 # verify that model loads correctly - model = cls.load_from_checkpoint(raw_checkpoint_path) + model = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert model.hparams.batch_size == 179 if isinstance(model, AggSubClassBoringModel): @@ -410,7 +414,7 @@ def test_collect_init_arguments(tmp_path, cls): assert model.hparams.dict_conf["my_param"] == "anything" # verify that we can overwrite whatever we want - model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99) + model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99, weights_only=weights_only) assert model.hparams.batch_size == 99 From 83fd824d8bd1a67db4cbfb9ff3e0173ba164099b Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 19 Aug 2025 13:15:58 +0200 Subject: [PATCH 13/45] switch to lightning_utilities.cli requirements set-oldest (#21077) --- .actions/assistant.py | 14 -------------- .azure/gpu-tests-fabric.yml | 4 +++- .azure/gpu-tests-pytorch.yml | 4 +++- .github/workflows/ci-tests-fabric.yml | 4 +++- .github/workflows/ci-tests-pytorch.yml | 4 +++- .lightning/workflows/fabric.yml | 5 ++++- .lightning/workflows/pytorch.yml | 
5 ++++- 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/.actions/assistant.py b/.actions/assistant.py index 7b2d49423d622..0109a988d5692 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -368,20 +368,6 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None: print(final) path.write_text("\n".join(final) + "\n") - @staticmethod - def _replace_min(fname: str) -> None: - with open(fname, encoding="utf-8") as fopen: - req = fopen.read().replace(">=", "==") - with open(fname, "w", encoding="utf-8") as fwrite: - fwrite.write(req) - - @staticmethod - def replace_oldest_ver(requirement_fnames: Sequence[str] = REQUIREMENT_FILES_ALL) -> None: - """Replace the min package version by fixed one.""" - for fname in requirement_fnames: - print(fname) - AssistantCLI._replace_min(fname) - @staticmethod def copy_replace_imports( source_dir: str, diff --git a/.azure/gpu-tests-fabric.yml b/.azure/gpu-tests-fabric.yml index c2d492e5b9564..f506aa2008df9 100644 --- a/.azure/gpu-tests-fabric.yml +++ b/.azure/gpu-tests-fabric.yml @@ -99,7 +99,9 @@ jobs: displayName: "Image info & NVIDIA" - bash: | - python .actions/assistant.py replace_oldest_ver + cd requirements/fabric + pip install -U "lightning-utilities[cli]" + python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt']" pip install "cython<3.0" wheel # for compatibility condition: contains(variables['Agent.JobName'], 'oldest') displayName: "setting oldest dependencies" diff --git a/.azure/gpu-tests-pytorch.yml b/.azure/gpu-tests-pytorch.yml index 16ced045ddade..68e99f2f6285a 100644 --- a/.azure/gpu-tests-pytorch.yml +++ b/.azure/gpu-tests-pytorch.yml @@ -103,7 +103,9 @@ jobs: displayName: "Image info & NVIDIA" - bash: | - python .actions/assistant.py replace_oldest_ver + cd requirements/pytorch + pip install -U "lightning-utilities[cli]" + python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt']" pip install "cython<3.0" wheel # for compatibility condition: contains(variables['Agent.JobName'], 'oldest') displayName: "setting oldest dependencies" diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index ed2301d878a10..888de9bdb4a09 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -94,7 +94,9 @@ jobs: - name: Set min. dependencies if: ${{ matrix.requires == 'oldest' }} run: | - python .actions/assistant.py replace_oldest_ver + cd requirements/fabric + pip install -U "lightning-utilities[cli]" + python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt']" pip install "cython<3.0" wheel pip install "pyyaml==5.4" --no-build-isolation diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index 49ff08df2eab1..e527961590c3a 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -99,7 +99,9 @@ jobs: - name: Set min. 
dependencies if: ${{ matrix.requires == 'oldest' }} run: | - python .actions/assistant.py replace_oldest_ver + cd requirements/pytorch + pip install -U "lightning-utilities[cli]" + python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt']" pip install "cython<3.0" wheel pip install "pyyaml==5.4" --no-build-isolation diff --git a/.lightning/workflows/fabric.yml b/.lightning/workflows/fabric.yml index 265de22b714fb..438f56ef7fe94 100644 --- a/.lightning/workflows/fabric.yml +++ b/.lightning/workflows/fabric.yml @@ -43,7 +43,10 @@ run: | if [ "${TORCH_VER}" == "2.1" ]; then echo "Set oldest versions" - python .actions/assistant.py replace_oldest_ver + cd requirements/fabric + pip install -U "lightning-utilities[cli]" + python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt']" + cd ../.. pip install "cython<3.0" wheel # for compatibility fi diff --git a/.lightning/workflows/pytorch.yml b/.lightning/workflows/pytorch.yml index b87ab98a8c0cd..5c92bf881d969 100644 --- a/.lightning/workflows/pytorch.yml +++ b/.lightning/workflows/pytorch.yml @@ -43,7 +43,10 @@ run: | if [ "${TORCH_VER}" == "2.1" ]; then recho "Set oldest versions" - python .actions/assistant.py replace_oldest_ver + cd requirements/pytorch + pip install -U "lightning-utilities[cli]" + python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt']" + cd ../.. pip install "cython<3.0" wheel # for compatibility fi From 93cbe940fad84f0bf7d99cdfa08e505ef94bb33d Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 19 Aug 2025 15:13:19 +0200 Subject: [PATCH 14/45] bump: try `deepspeed >=0.14.1,<=0.15.0` (#21076) * try `deepspeed >=0.14.1,<=0.15.0` * drop from oldest * pip uninstall -y deepspeed * error::DeprecationWarning --- .actions/assistant.py | 27 ------------------- .azure/gpu-tests-fabric.yml | 10 +++++-- .azure/gpu-tests-pytorch.yml | 10 +++++-- dockers/release/Dockerfile | 10 ++++++- pyproject.toml | 1 + requirements/fabric/strategies.txt | 2 +- requirements/pytorch/strategies.txt | 2 +- src/lightning/fabric/strategies/deepspeed.py | 6 +---- .../test_multiprocessing_integration.py | 1 + tests/tests_pytorch/utilities/test_compile.py | 2 +- 10 files changed, 31 insertions(+), 40 deletions(-) diff --git a/.actions/assistant.py b/.actions/assistant.py index 0109a988d5692..e54e69e4860e7 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -341,33 +341,6 @@ def create_mirror_package(source_dir: str, package_mapping: dict[str, str]) -> N class AssistantCLI: - @staticmethod - def requirements_prune_pkgs(packages: Sequence[str], req_files: Sequence[str] = REQUIREMENT_FILES_ALL) -> None: - """Remove some packages from given requirement files.""" - if isinstance(req_files, str): - req_files = [req_files] - for req in req_files: - AssistantCLI._prune_packages(req, packages) - - @staticmethod - def _prune_packages(req_file: str, packages: Sequence[str]) -> None: - """Remove some packages from given requirement files.""" - path = Path(req_file) - assert path.exists() - text = path.read_text() - lines = text.splitlines() - final = [] - for line in lines: - ln_ = line.strip() - if not ln_ or ln_.startswith("#"): - final.append(line) - continue - req = list(_parse_requirements([ln_]))[0] - if req.name not in packages: - final.append(line) - print(final) - path.write_text("\n".join(final) + "\n") - 
@staticmethod def copy_replace_imports( source_dir: str, diff --git a/.azure/gpu-tests-fabric.yml b/.azure/gpu-tests-fabric.yml index f506aa2008df9..b2f8ab0447a20 100644 --- a/.azure/gpu-tests-fabric.yml +++ b/.azure/gpu-tests-fabric.yml @@ -99,10 +99,16 @@ jobs: displayName: "Image info & NVIDIA" - bash: | - cd requirements/fabric + set -ex + pip install "cython<3.0" wheel # for compatibility pip install -U "lightning-utilities[cli]" + cd requirements/fabric + # replace range by pin minimal requirements python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt']" - pip install "cython<3.0" wheel # for compatibility + # drop deepspeed since it is not supported by our minimal Torch requirements + python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt + # uninstall deepspeed since some older docker images have it pre-installed + pip uninstall -y deepspeed condition: contains(variables['Agent.JobName'], 'oldest') displayName: "setting oldest dependencies" diff --git a/.azure/gpu-tests-pytorch.yml b/.azure/gpu-tests-pytorch.yml index 68e99f2f6285a..d3c4951a22336 100644 --- a/.azure/gpu-tests-pytorch.yml +++ b/.azure/gpu-tests-pytorch.yml @@ -103,10 +103,16 @@ jobs: displayName: "Image info & NVIDIA" - bash: | - cd requirements/pytorch + set -ex + pip install "cython<3.0" wheel # for compatibility pip install -U "lightning-utilities[cli]" + cd requirements/pytorch + # replace range by pin minimal requirements python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt']" - pip install "cython<3.0" wheel # for compatibility + # drop deepspeed since it is not supported by our minimal Torch requirements + python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt + # uninstall deepspeed since some older docker images have it pre-installed + pip uninstall -y deepspeed condition: contains(variables['Agent.JobName'], 'oldest') displayName: "setting oldest dependencies" diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile index c7a7a1e0c8470..b80c23dfc73f3 100644 --- a/dockers/release/Dockerfile +++ b/dockers/release/Dockerfile @@ -21,6 +21,7 @@ FROM pytorchlightning/pytorch_lightning:base-cuda${CUDA_VERSION}-py${PYTHON_VERS LABEL maintainer="Lightning-AI " ARG LIGHTNING_VERSION="" +ARG PYTORCH_VERSION COPY ./ /home/pytorch-lightning/ @@ -39,7 +40,14 @@ RUN \ fi && \ # otherwise there is collision with folder name and pkg name on Pypi cd pytorch-lightning && \ - pip install setuptools==75.6.0 && \ + # pip install setuptools==75.6.0 && \ + pip install -U "lightning-utilities[cli]" && \ + # drop deepspeed since it is not supported by our minimal Torch requirements \ + echo "PYTORCH_VERSION is: '$PYTORCH_VERSION'" && \ + if [[ "$PYTORCH_VERSION" =~ ^(2\.1|2\.2|2\.3|2\.4)$ ]]; then \ + python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files requirements/fabric/strategies.txt ; \ + python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files requirements/pytorch/strategies.txt ; \ + fi && \ PACKAGE_NAME=lightning pip install '.[extra,loggers,strategies]' --no-cache-dir && \ PACKAGE_NAME=pytorch pip install '.[extra,loggers,strategies]' --no-cache-dir && \ cd .. 
&& \ diff --git a/pyproject.toml b/pyproject.toml index a63da5f246392..b4d5d0b1638f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -179,6 +179,7 @@ markers = [ "cloud: Run the cloud tests for example", ] filterwarnings = [ + # "error::DeprecationWarning", "error::FutureWarning", "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated ] diff --git a/requirements/fabric/strategies.txt b/requirements/fabric/strategies.txt index bea30b37fa5f8..7856db1df2eec 100644 --- a/requirements/fabric/strategies.txt +++ b/requirements/fabric/strategies.txt @@ -5,5 +5,5 @@ # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 -deepspeed >=0.9.3, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict +deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin" diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt index 1f7296798b551..89392d6006d38 100644 --- a/requirements/pytorch/strategies.txt +++ b/requirements/pytorch/strategies.txt @@ -3,4 +3,4 @@ # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 -deepspeed >=0.9.3, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict +deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 48333455240cf..c11ae8589d1ff 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -47,7 +47,6 @@ from torch.optim.lr_scheduler import _LRScheduler _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") -_DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1") # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced. 
@@ -503,10 +502,7 @@ def load_checkpoint( ) engine = engines[0] - if _DEEPSPEED_GREATER_EQUAL_0_14_1: - from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer - else: - from deepspeed.runtime import DeepSpeedOptimizer + from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values()) diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py index 2eaf1d23572c8..85688ef8fb489 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py @@ -30,6 +30,7 @@ def __init__(self): @RunIf(skip_windows=True) +@pytest.mark.flaky(reruns=3) @pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp_fork"]) def test_memory_sharing_disabled(strategy): """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py index a053c847dfd6c..f90cd5e3ef3fa 100644 --- a/tests/tests_pytorch/utilities/test_compile.py +++ b/tests/tests_pytorch/utilities/test_compile.py @@ -32,7 +32,7 @@ # https://github.com/pytorch/pytorch/issues/95708 @pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found") -@RunIf(dynamo=True) +@RunIf(dynamo=True, deepspeed=True) @mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt") def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0): trainer_kwargs = { From 2a53f2fc70e63d64dff393a3d00e3d155f501044 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Tue, 19 Aug 2025 11:17:32 -0400 Subject: [PATCH 15/45] weights_only=True default for torch>=2.6 --- src/lightning/fabric/plugins/io/checkpoint_io.py | 10 +++++----- src/lightning/fabric/plugins/io/torch_io.py | 10 +++++++++- src/lightning/fabric/utilities/cloud_io.py | 10 ++++++++-- src/lightning/fabric/utilities/imports.py | 1 + src/lightning/pytorch/core/saving.py | 4 ---- .../checkpointing/test_legacy_checkpoints.py | 7 +------ tests/tests_pytorch/models/test_hparams.py | 10 +++++----- 7 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index ee625652fdf6d..db7578d9ca8c0 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -48,7 +48,7 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio @abstractmethod def load_checkpoint( - self, path: _PATH, map_location: Optional[Any] = None, weights_only: bool = True + self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None ) -> dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. @@ -56,10 +56,10 @@ def load_checkpoint( path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations. - weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other - primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use - ``weights_only=False``. 
If loading checkpoint from an untrusted source, we recommend using
-                ``weights_only=True``. For more information, please refer to the
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
+                an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
+                recommend using ``weights_only=True``. For more information, please refer to the
                 `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`_.

         Returns: The loaded checkpoint.

diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py
index c6ac50b71de61..c52ad6913e1e2 100644
--- a/src/lightning/fabric/plugins/io/torch_io.py
+++ b/src/lightning/fabric/plugins/io/torch_io.py
@@ -59,7 +59,10 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio

     @override
     def load_checkpoint(
-        self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage, weights_only: bool = True
+        self,
+        path: _PATH,
+        map_location: Optional[Callable] = lambda storage, loc: storage,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.

@@ -67,6 +70,11 @@ def load_checkpoint(
             path: Path to checkpoint
             map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                 locations.
+            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
+                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
+                an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
+                recommend using ``weights_only=True``. For more information, please refer to the
+                `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`_.

         Returns: The loaded checkpoint.

diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py
index d9ea6f98b24a1..8506ddcdf9f19 100644
--- a/src/lightning/fabric/utilities/cloud_io.py
+++ b/src/lightning/fabric/utilities/cloud_io.py
@@ -17,7 +17,7 @@
 import io
 import logging
 from pathlib import Path
-from typing import IO, Any, Union
+from typing import IO, Any, Optional, Union

 import fsspec
 import fsspec.utils
@@ -26,6 +26,7 @@
 from fsspec.implementations.local import AbstractFileSystem
 from lightning_utilities.core.imports import module_available

+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH

 log = logging.getLogger(__name__)
@@ -34,7 +35,7 @@ def _load(
     path_or_url: Union[IO, _PATH],
     map_location: _MAP_LOCATION_TYPE = None,
-    weights_only: bool = True,
+    weights_only: Optional[bool] = None,
 ) -> Any:
     """Loads a checkpoint.

@@ -48,6 +49,11 @@ def _load(
             `PyTorch Developer Notes on Serialization Semantics <https://pytorch.org/docs/stable/notes/serialization.html>`_.
""" + # default to `weights_only=True` for torch>=2.6 + if weights_only is None and _TORCH_GREATER_EQUAL_2_6: + log.debug("Defaulting to `weights_only=True` for torch>=2.6.") + weights_only = True + if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similar return torch.load( diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 70239baac0e6d..5655f2674638e 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -35,6 +35,7 @@ _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") +_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 81e766751ea15..391e9dd5d0f25 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -61,10 +61,6 @@ def _load_from_checkpoint( ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: map_location = map_location or _default_map_location - if weights_only is None: - log.debug("`weights_only` not specified, defaulting to `True`.") - weights_only = True - with pl_legacy_patch(): checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only) diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index bed597c31b885..61816c1a12877 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -46,12 +46,7 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] - # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) - if pl_version == "local": - pl_version = pl.__version__ - weights_only = not Version(pl_version) < Version("1.5.0") - - model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only) + model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24) trainer = Trainer(default_root_dir=tmp_path) dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8) res = trainer.test(model, datamodule=dm) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index fa189938759fb..da5d75273453d 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -94,7 +94,7 @@ def __init__(self, hparams, *my_args, **my_kwargs): # ------------------------- # STANDARD TESTS # ------------------------- -def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only=True): +def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False): """Tests for the existence of an arg 'test_arg=14'.""" obj = datamodule if issubclass(cls, LightningDataModule) else model @@ -108,20 +108,20 @@ def 
_run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr # make sure the raw checkpoint saved the properties raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) + raw_checkpoint = torch.load(raw_checkpoint_path) assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 # verify that model loads correctly - obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path) assert obj2.hparams.test_arg == 14 assert isinstance(obj2.hparams, hparam_type) if try_overwrite: # verify that we can overwrite the property - obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only) + obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) assert obj3.hparams.test_arg == 78 return raw_checkpoint_path @@ -176,7 +176,7 @@ def test_omega_conf_hparams(tmp_path, cls): assert isinstance(obj.hparams, Container) # run standard test suite - raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False) + raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule) obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False) assert isinstance(obj2.hparams, Container) From 561c02c63554769c622bbde4340803472980313c Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Tue, 19 Aug 2025 12:48:22 -0400 Subject: [PATCH 16/45] changelog --- src/lightning/fabric/CHANGELOG.md | 2 +- src/lightning/pytorch/CHANGELOG.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 20907df7c874a..2032a0205561a 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- +- Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) ### Fixed diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index e431c8eccd365..f73e053bb84f3 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - Default to RichProgressBar and RichModelSummary if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580)) +- Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. 
([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) ### Removed From c67c8a36d97635c1eaedbc8fa3f9e52aac6fe6e9 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Tue, 19 Aug 2025 15:07:06 -0400 Subject: [PATCH 17/45] ignore torch.load futurewarning --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b4d5d0b1638f5..6c3a8c1ce9845 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,6 +182,7 @@ filterwarnings = [ # "error::DeprecationWarning", "error::FutureWarning", "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated + "ignore:You are using `torch.load` with `weights_only=False` (the current default value).*:FutureWarning", ] xfail_strict = true junit_duration_report = "call" From 005c43985a873156338f8e65e0b870e17098c406 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Tue, 19 Aug 2025 15:51:14 -0400 Subject: [PATCH 18/45] add .* --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6c3a8c1ce9845..46f40d75b1d6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,7 +182,7 @@ filterwarnings = [ # "error::DeprecationWarning", "error::FutureWarning", "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated - "ignore:You are using `torch.load` with `weights_only=False` (the current default value).*:FutureWarning", + "ignore:.*You are using `torch.load` with `weights_only=False` (the current default value).*:FutureWarning", ] xfail_strict = true junit_duration_report = "call" From 2ab89a27f508b75de64e494cf92ff5cfcc6c969a Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Tue, 19 Aug 2025 16:17:06 -0400 Subject: [PATCH 19/45] will this woork --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 46f40d75b1d6b..499fcbf8d0c37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,7 +182,7 @@ filterwarnings = [ # "error::DeprecationWarning", "error::FutureWarning", "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated - "ignore:.*You are using `torch.load` with `weights_only=False` (the current default value).*:FutureWarning", + "ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning", ] xfail_strict = true junit_duration_report = "call" From 74e5e5a3b2ddc7faafa7544bd34a68c7d5fe4872 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 20 Aug 2025 14:07:00 -0400 Subject: [PATCH 20/45] weights_only according pl version --- .../tests_pytorch/checkpointing/test_legacy_checkpoints.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 61816c1a12877..8367f4e57f927 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -46,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] - model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24) + if pl_version == "local": + pl_version = pl.__version__ + + weights_only = Version(pl_version) >= Version("1.5.0") + + model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only) trainer = Trainer(default_root_dir=tmp_path) dm = 
ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8) res = trainer.test(model, datamodule=dm) From a4c9efef7b34d1346981f73035c9ec586e3bcd58 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 20 Aug 2025 18:18:08 -0400 Subject: [PATCH 21/45] set env var TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 for pl < 1.5.0 --- src/lightning/fabric/utilities/cloud_io.py | 6 ------ .../tests_pytorch/checkpointing/test_legacy_checkpoints.py | 7 ++++++- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 8506ddcdf9f19..51e4729b1062f 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -26,7 +26,6 @@ from fsspec.implementations.local import AbstractFileSystem from lightning_utilities.core.imports import module_available -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH log = logging.getLogger(__name__) @@ -49,11 +48,6 @@ def _load( `PyTorch Developer Notes on Serialization Semantics `_. """ - # default to `weights_only=True` for torch>=2.6 - if weights_only is None and _TORCH_GREATER_EQUAL_2_6: - log.debug("Defaulting to `weights_only=True` for torch>=2.6.") - weights_only = True - if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similar return torch.load( diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 8367f4e57f927..a5ad77cf25c1a 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -105,9 +105,14 @@ def load_model(): @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @RunIf(sklearn=True) -def test_resume_legacy_checkpoints(tmp_path, pl_version: str): +def test_resume_legacy_checkpoints(monkeypatch, tmp_path, pl_version: str): PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) with patch("sys.path", [PATH_LEGACY] + sys.path): + if pl_version == "local": + pl_version = pl.__version__ + if Version(pl_version) < Version("1.5.0"): + monkeypatch.setenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1") + path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] From 2c2ab9e930ea4b5f422fd2349fcec01e34c3a5f5 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 20 Aug 2025 20:32:03 -0400 Subject: [PATCH 22/45] weights_only=False for omegaconf hparams test --- tests/tests_pytorch/models/test_hparams.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index da5d75273453d..5f333314327fb 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -19,6 +19,7 @@ from argparse import Namespace from dataclasses import dataclass, field from enum import Enum +from typing import Optional from unittest import mock import cloudpickle @@ -94,7 +95,9 @@ def __init__(self, hparams, *my_args, **my_kwargs): # ------------------------- # STANDARD TESTS # ------------------------- -def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False): +def _run_standard_hparams_test( + 
tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only: Optional[bool] = None +): """Tests for the existence of an arg 'test_arg=14'.""" obj = datamodule if issubclass(cls, LightningDataModule) else model @@ -108,20 +111,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr # make sure the raw checkpoint saved the properties raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path) + raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 # verify that model loads correctly - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert obj2.hparams.test_arg == 14 assert isinstance(obj2.hparams, hparam_type) if try_overwrite: # verify that we can overwrite the property - obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) + obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only) assert obj3.hparams.test_arg == 78 return raw_checkpoint_path @@ -176,7 +179,8 @@ def test_omega_conf_hparams(tmp_path, cls): assert isinstance(obj.hparams, Container) # run standard test suite - raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule) + # weights_only=False as omegaconf.DictConfig is not an allowed global by default + raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False) obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False) assert isinstance(obj2.hparams, Container) From 54b859aaea4ac676ea7cac4549504183345a704f Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Thu, 21 Aug 2025 00:48:47 -0400 Subject: [PATCH 23/45] default to weights_only=true for loading from state_dict from url --- src/lightning/fabric/utilities/cloud_io.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 51e4729b1062f..33b6d8647fd8c 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -56,6 +56,9 @@ def _load( weights_only=weights_only, ) if str(path_or_url).startswith("http"): + if weights_only is None: + weights_only = True + log.debug(f"Default to `weights_only=True` for remote checkpoint: {path_or_url}") return torch.hub.load_state_dict_from_url( str(path_or_url), map_location=map_location, # type: ignore[arg-type] From 7ddb4f80fc218930322009b87d8148872011382c Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Thu, 21 Aug 2025 00:49:35 -0400 Subject: [PATCH 24/45] weights_only=False for hydra --- tests/tests_pytorch/models/test_hparams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 5f333314327fb..d0c72721ce1be 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -791,7 +791,7 @@ def __init__(self, args_0, args_1, args_2, kwarg_1=None): logger=False, ) trainer.fit(model) - _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) + _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path, weights_only=False) @pytest.mark.parametrize("ignore", ["arg2", ("arg2", "arg3")]) From 
906e52e13ea5d5c678d02eda21f8e6b4afd364f3 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 27 Aug 2025 16:49:45 -0400 Subject: [PATCH 25/45] Update src/lightning/fabric/utilities/cloud_io.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/fabric/utilities/cloud_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 33b6d8647fd8c..a8c4007376b2b 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -56,7 +56,7 @@ def _load( weights_only=weights_only, ) if str(path_or_url).startswith("http"): - if weights_only is None: + if weights_only is None and _TORCH_GREATER_EQUAL_2_6: weights_only = True log.debug(f"Default to `weights_only=True` for remote checkpoint: {path_or_url}") return torch.hub.load_state_dict_from_url( From 170fbe06c1f7fd18690645928e2eb43b0de9b6ca Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 27 Aug 2025 16:53:46 -0400 Subject: [PATCH 26/45] defaults for weights_only in torch.hub.load_state_dict_from_url --- src/lightning/fabric/utilities/cloud_io.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index a8c4007376b2b..ab8efb11db31b 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -26,6 +26,7 @@ from fsspec.implementations.local import AbstractFileSystem from lightning_utilities.core.imports import module_available +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH log = logging.getLogger(__name__) @@ -56,9 +57,14 @@ def _load( weights_only=weights_only, ) if str(path_or_url).startswith("http"): - if weights_only is None and _TORCH_GREATER_EQUAL_2_6: - weights_only = True - log.debug(f"Default to `weights_only=True` for remote checkpoint: {path_or_url}") + if weights_only is None: + if _TORCH_GREATER_EQUAL_2_6: + weights_only = True + log.debug(f"Default to `weights_only=True` for remote checkpoint for torch>=2.6: {path_or_url}") + else: + weights_only = False + log.debug(f"Default to `weights_only=False` for remote checkpoint for torch<2.6: {path_or_url}") + return torch.hub.load_state_dict_from_url( str(path_or_url), map_location=map_location, # type: ignore[arg-type] From bf7e284968202f32ff3e20f659f5274fa8e056cc Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Tue, 2 Sep 2025 23:17:16 -0400 Subject: [PATCH 27/45] default to weights_only=False for torch.hub.load_state_dict_from_url --- src/lightning/fabric/utilities/cloud_io.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index ab8efb11db31b..54b18fb6ce3b0 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -26,7 +26,6 @@ from fsspec.implementations.local import AbstractFileSystem from lightning_utilities.core.imports import module_available -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH log = logging.getLogger(__name__) @@ -58,12 +57,11 @@ def _load( ) if str(path_or_url).startswith("http"): if weights_only is None: - if _TORCH_GREATER_EQUAL_2_6: - weights_only = 
True
-                log.debug(f"Default to `weights_only=True` for remote checkpoint for torch>=2.6: {path_or_url}")
-            else:
-                weights_only = False
-                log.debug(f"Default to `weights_only=False` for remote checkpoint for torch<2.6: {path_or_url}")
+            weights_only = False
+            log.debug(
+                f"Defaulting to `weights_only=False` for remote checkpoint: {path_or_url}."
+                f" If loading a checkpoint from an untrusted source, we recommend using `weights_only=True`."
+            )

         return torch.hub.load_state_dict_from_url(
             str(path_or_url),

From 5ac969539e47fc9e064b26d4d1087172c9b8cc31 Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Sun, 7 Sep 2025 21:01:59 -0400
Subject: [PATCH 28/45] add weights_only to trainer.fit, validate, test, predict

---
 src/lightning/pytorch/strategies/strategy.py  |  4 +--
 .../connectors/checkpoint_connector.py        | 10 +++---
 src/lightning/pytorch/trainer/trainer.py      | 35 +++++++++++++------
 3 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py
index 16b16a4927513..4e5dc33f62672 100644
--- a/src/lightning/pytorch/strategies/strategy.py
+++ b/src/lightning/pytorch/strategies/strategy.py
@@ -363,9 +363,9 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
         """Returns the pure LightningModule without potential wrappers."""
         return self._lightning_module

-    def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
+    def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool) -> dict[str, Any]:
         torch.cuda.empty_cache()
-        return self.checkpoint_io.load_checkpoint(checkpoint_path)
+        return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only)

     def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
         assert self.lightning_module is not None
diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
index 52fb0e3230a82..863fd265a6cfd 100644
--- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
@@ -64,7 +64,7 @@ def _hpc_resume_path(self) -> Optional[str]:
             return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt"
         return None

-    def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
+    def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False) -> None:
        """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

        1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`.
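
Once this patch applies, the flag introduced above surfaces on the public ``Trainer.fit`` entry point. A minimal usage sketch (the model and checkpoint path are placeholders, not taken from the diffs):

    import lightning.pytorch as pl
    from lightning.pytorch.demos.boring_classes import BoringModel

    trainer = pl.Trainer(max_epochs=1)
    # A trusted checkpoint that pickles more than tensors needs the explicit opt-out:
    trainer.fit(BoringModel(), ckpt_path="path/to/trusted.ckpt", weights_only=False)
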
@@ -80,7 +80,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}") with pl_legacy_patch(): - loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) + loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path, weights_only) self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path) def _select_ckpt_path( @@ -403,9 +403,11 @@ def restore_lr_schedulers(self) -> None: for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers): config.scheduler.load_state_dict(lrs_state) - def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None: + def _restore_modules_and_callbacks( + self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False + ) -> None: # restore modules after setup - self.resume_start(checkpoint_path) + self.resume_start(checkpoint_path, weights_only) self.restore_model() self.restore_datamodule() self.restore_callbacks() diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 5768c507e2e3f..739ecc369609f 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -526,6 +526,7 @@ def fit( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, ) -> None: r"""Runs the full optimization routine. @@ -573,7 +574,14 @@ def fit( self.training = True self.should_stop = False call._call_and_handle_interrupt( - self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path + self, + self._fit_impl, + model, + train_dataloaders, + val_dataloaders, + datamodule, + ckpt_path, + weights_only, ) def _fit_impl( @@ -583,6 +591,7 @@ def _fit_impl( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, ) -> None: log.debug(f"{self.__class__.__name__}: trainer fit stage") @@ -610,7 +619,7 @@ def _fit_impl( model_provided=True, model_connected=self.lightning_module is not None, ) - self._run(model, ckpt_path=ckpt_path) + self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) assert self.state.stopped self.training = False @@ -621,6 +630,7 @@ def validate( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: @@ -676,7 +686,7 @@ def validate( self.state.status = TrainerStatus.RUNNING self.validating = True return call._call_and_handle_interrupt( - self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule + self, self._validate_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule ) def _validate_impl( @@ -684,6 +694,7 @@ def _validate_impl( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: @@ -717,7 +728,7 @@ def _validate_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, 
model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) # remove the tensors from the validation results results = convert_tensors_to_scalars(results) @@ -731,6 +742,7 @@ def test( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: @@ -787,7 +799,7 @@ def test( self.state.status = TrainerStatus.RUNNING self.testing = True return call._call_and_handle_interrupt( - self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule + self, self._test_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule ) def _test_impl( @@ -795,6 +807,7 @@ def _test_impl( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: @@ -828,7 +841,7 @@ def _test_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) # remove the tensors from the test results results = convert_tensors_to_scalars(results) @@ -844,6 +857,7 @@ def predict( datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, ) -> Optional[_PREDICT_OUTPUT]: r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks. 
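
The evaluation and prediction entry points take the same flag, so the loading mode can be pinned per call. A sketch continuing the placeholder example above (the checkpoint path is again an assumption, not from the diffs):

    # `weights_only=True` restricts the unpickler to tensors and primitive types,
    # the safer setting for a checkpoint of unknown provenance:
    model = BoringModel()
    trainer.validate(model, ckpt_path="downloaded.ckpt", weights_only=True)
    trainer.test(model, ckpt_path="downloaded.ckpt", weights_only=True)
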
@@ -899,7 +913,7 @@ def predict( self.state.status = TrainerStatus.RUNNING self.predicting = True return call._call_and_handle_interrupt( - self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path + self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path, weights_only ) def _predict_impl( @@ -909,6 +923,7 @@ def _predict_impl( datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, ) -> Optional[_PREDICT_OUTPUT]: # -------------------- # SETUP HOOK @@ -939,7 +954,7 @@ def _predict_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) assert self.state.stopped self.predicting = False @@ -947,7 +962,7 @@ def _predict_impl( return results def _run( - self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None + self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None, weights_only: bool = False ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( @@ -992,7 +1007,7 @@ def _run( # check if we should delay restoring checkpoint till later if not self.strategy.restore_checkpoint_after_setup: log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}") - self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path) + self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path, weights_only) # reset logger connector self._logger_connector.reset_results() From ea066a993d76b4210e194991b464ea220ed09630 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Sun, 7 Sep 2025 22:12:15 -0400 Subject: [PATCH 29/45] fix tests --- tests/tests_pytorch/accelerators/test_cpu.py | 4 +-- .../plugins/test_checkpoint_io_plugin.py | 4 +-- tests/tests_pytorch/test_cli.py | 26 ++++++++++++------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index ec2aabf559dc7..420f711678808 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -53,9 +53,9 @@ def setup(self, trainer: "pl.Trainer") -> None: def restore_checkpoint_after_setup(self) -> bool: return restore_after_pre_setup - def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: Union[str, Path], weights_only: bool) -> dict[str, Any]: assert self.setup_called == restore_after_pre_setup - return super().load_checkpoint(checkpoint_path) + return super().load_checkpoint(checkpoint_path, weights_only) model = BoringModel() trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index ef4c652d08660..c6be86baf3bc6 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -69,7 +69,7 @@ def test_checkpoint_plugin_called(tmp_path): assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) - 
checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt")) + checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt"), weights_only=False) checkpoint_plugin.reset_mock() ck = ModelCheckpoint(dirpath=tmp_path, save_last=True) @@ -97,7 +97,7 @@ def test_checkpoint_plugin_called(tmp_path): trainer.test(model, ckpt_path=ck.last_model_path) checkpoint_plugin.load_checkpoint.assert_called_once() - checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt")) + checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt"), weights_only=False) @pytest.mark.flaky(reruns=3) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 248852f4cf1f3..c25f75658341b 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1222,7 +1222,7 @@ def test_lightning_cli_model_short_arguments(): ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) - run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) + run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), @@ -1250,7 +1250,7 @@ def test_lightning_cli_datamodule_short_arguments(): ): cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.datamodule, BoringDataModule) - run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY) + run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), @@ -1271,7 +1271,7 @@ def test_lightning_cli_datamodule_short_arguments(): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, BoringDataModule) - run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY) + run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), @@ -1447,7 +1447,7 @@ def test_lightning_cli_config_with_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=False) assert cli.trainer.limit_test_batches == 1 @@ -1463,7 +1463,9 @@ def test_lightning_cli_config_before_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with( + cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=False + ) assert cli.trainer.limit_test_batches == 1 save_config_callback = cli.trainer.callbacks[0] @@ -1476,7 +1478,7 @@ def test_lightning_cli_config_before_subcommand(): ): cli = LightningCLI(BoringModel) - validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") + validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=False) assert cli.trainer.limit_val_batches == 1 save_config_callback = cli.trainer.callbacks[0] @@ -1494,7 +1496,9 @@ def test_lightning_cli_config_before_subcommand_two_configs(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") 
+ test_mock.assert_called_once_with( + cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=False + ) assert cli.trainer.limit_test_batches == 1 with ( @@ -1503,7 +1507,7 @@ def test_lightning_cli_config_before_subcommand_two_configs(): ): cli = LightningCLI(BoringModel) - validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") + validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=False) assert cli.trainer.limit_val_batches == 1 @@ -1515,7 +1519,7 @@ def test_lightning_cli_config_after_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=False) assert cli.trainer.limit_test_batches == 1 @@ -1528,7 +1532,9 @@ def test_lightning_cli_config_before_and_after_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar") + test_mock.assert_called_once_with( + cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar", weights_only=False + ) assert cli.trainer.limit_test_batches == 1 assert cli.trainer.fast_dev_run == 1 From 4d65777d7a5cf775a0600865acd410b1f9f9017a Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Sun, 7 Sep 2025 22:53:11 -0400 Subject: [PATCH 30/45] add weights_only arg --- src/lightning/fabric/strategies/strategy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 0b2b373acb7bc..985d68c686cc2 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -310,6 +310,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: bool = False, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. 
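
For context, the flag being threaded through these strategy signatures is ultimately forwarded to ``torch.load``. A standalone illustration of the behaviour it controls (the file name is a placeholder):

    import torch

    torch.save({"weight": torch.zeros(2, 2)}, "plain.ckpt")
    # Succeeds: the payload contains only tensors and primitive types.
    state = torch.load("plain.ckpt", weights_only=True)
    # A checkpoint pickling arbitrary Python objects would instead raise an
    # UnpicklingError here and would require an explicit weights_only=False.
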
@@ -330,7 +331,7 @@ def load_checkpoint( """ torch.cuda.empty_cache() - checkpoint = self.checkpoint_io.load_checkpoint(path) + checkpoint = self.checkpoint_io.load_checkpoint(path, weights_only) if not state: return checkpoint From b84a53d1354eb97736c15e843f0c3916aba9a0d9 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Mon, 8 Sep 2025 08:33:45 -0400 Subject: [PATCH 31/45] specify weights_only kwarg --- src/lightning/fabric/strategies/strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 985d68c686cc2..fe766e0432b8e 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -331,7 +331,7 @@ def load_checkpoint( """ torch.cuda.empty_cache() - checkpoint = self.checkpoint_io.load_checkpoint(path, weights_only) + checkpoint = self.checkpoint_io.load_checkpoint(path, weights_only=weights_only) if not state: return checkpoint From ea6773e672845392315bc39551fa2799f4606bfc Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Mon, 8 Sep 2025 17:16:07 -0400 Subject: [PATCH 32/45] weights_only for fsdp load --- src/lightning/fabric/strategies/model_parallel.py | 5 +++-- src/lightning/pytorch/strategies/model_parallel.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 0d49ddf91a0bc..cc47fb3f7e360 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -411,6 +411,7 @@ def _load_checkpoint( state: dict[str, Union[Module, Optimizer, Any]], strict: bool = True, optimizer_states_from_list: bool = False, + weights_only: bool = False, ) -> dict[str, Any]: from torch.distributed.checkpoint.state_dict import ( StateDictOptions, @@ -449,7 +450,7 @@ def _load_checkpoint( set_optimizer_state_dict(module, optim, optim_state_dict=optim_state[optim_key], options=state_dict_options) # Load metadata (anything not a module or optimizer) - metadata = torch.load(path / _METADATA_FILENAME) + metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only) requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict) for key in requested_metadata_keys: @@ -461,7 +462,7 @@ def _load_checkpoint( return metadata if _is_full_checkpoint(path): - checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False) + checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=weights_only) _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict) state_dict_options = StateDictOptions( diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index e0286dbe2e0e6..def87061a9b73 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -329,7 +329,7 @@ def save_checkpoint( return super().save_checkpoint(checkpoint=checkpoint, filepath=path) @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool = False) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) state = { @@ -342,6 +342,7 @@ def 
load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: state=state, strict=self.lightning_module.strict_loading, optimizer_states_from_list=True, + weights_only=weights_only, ) def _setup_distributed(self) -> None: From 23c81e8d01cbb5138f6b2ebbed1adfe7b395b7f5 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 10 Sep 2025 13:27:13 +0200 Subject: [PATCH 33/45] Apply suggestions from code review --- src/lightning/fabric/CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index f6e4047c04649..406b3a687f14c 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -9,7 +9,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060)) - @@ -21,7 +20,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) -- let `_get_default_process_group_backend_for_device` support more hardware platforms ([#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093)) - Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164)) From 2de885c6300a3d913f5ad7b011da40319c359bba Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 10 Sep 2025 13:27:45 +0200 Subject: [PATCH 34/45] Apply suggestions from code review --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c2d754b2250fc..499fcbf8d0c37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -183,7 +183,6 @@ filterwarnings = [ "error::FutureWarning", "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated "ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning", - "ignore:The pynvml package is deprecated:FutureWarning", # Ignore pynvml deprecation warning, since it is not installed by PL directly ] xfail_strict = true junit_duration_report = "call" From 9f709d64e98d33f7fc55eaeec0f7943511c35704 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 10 Sep 2025 07:53:18 -0400 Subject: [PATCH 35/45] default is none --- src/lightning/fabric/strategies/model_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index cc47fb3f7e360..f31c923e23f6f 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -411,7 +411,7 @@ def _load_checkpoint( state: dict[str, Union[Module, Optimizer, Any]], strict: bool = True, optimizer_states_from_list: bool = False, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: from torch.distributed.checkpoint.state_dict import ( StateDictOptions, From 653dd6f8f81d0deb434b70558cf68677f0ee4b3d Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 10 Sep 2025 09:15:06 -0400 Subject: [PATCH 36/45] add weights_only args to strategies --- src/lightning/fabric/strategies/deepspeed.py | 3 ++- 
src/lightning/fabric/strategies/fsdp.py | 3 ++- src/lightning/fabric/strategies/model_parallel.py | 3 ++- src/lightning/fabric/strategies/strategy.py | 2 +- src/lightning/fabric/strategies/xla_fsdp.py | 3 ++- src/lightning/pytorch/strategies/deepspeed.py | 4 ++-- src/lightning/pytorch/strategies/fsdp.py | 4 ++-- src/lightning/pytorch/strategies/strategy.py | 2 +- 8 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index fe72db20e2b85..209ddaaf57548 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -458,6 +458,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. @@ -483,7 +484,7 @@ def load_checkpoint( # This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from # a consolidated checkpoint path = self.broadcast(path) - return super().load_checkpoint(path=path, state=state, strict=strict) + return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only) if not state: raise ValueError( diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index baaee74af0ec9..f42ade7484395 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -516,6 +516,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: @@ -586,7 +587,7 @@ def load_checkpoint( optim.load_state_dict(flattened_osd) # Load metadata (anything not a module or optimizer) - metadata = torch.load(path / _METADATA_FILENAME) + metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only) requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict) for key in requested_metadata_keys: diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index f31c923e23f6f..677584668975e 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -275,6 +275,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: @@ -295,7 +296,7 @@ def load_checkpoint( f"Loading a single optimizer object from a checkpoint is not supported yet with {type(self).__name__}." 
) - return _load_checkpoint(path=path, state=state, strict=strict) + return _load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only) def _setup_distributed(self) -> None: reset_seed() diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index fe766e0432b8e..b368f626c3b11 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -310,7 +310,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 3fa9e40f4b4bd..51b528eff26ff 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -516,6 +516,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Given a folder, load the contents from a checkpoint and restore the state of the given objects. @@ -608,7 +609,7 @@ def load_checkpoint( ) if "model" not in state or not isinstance(model := state["model"], torch.nn.Module): raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.") - full_ckpt = torch.load(path) + full_ckpt = torch.load(path, weights_only=weights_only) model.load_state_dict(full_ckpt.pop("model"), strict=strict) return full_ckpt diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index c5253f77cdedb..369360590878d 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -659,12 +659,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op ) @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing checkpoint_path = self.broadcast(checkpoint_path) - return super().load_checkpoint(checkpoint_path) + return super().load_checkpoint(checkpoint_path, weights_only) _validate_checkpoint_directory(checkpoint_path) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 3fbd0f9cd5f0a..9706c8a64e61b 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -583,7 +583,7 @@ def save_checkpoint( raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) @@ -624,7 +624,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: optim.load_state_dict(flattened_osd) # Load metadata (anything not a module or optimizer) - 
metadata = torch.load(path / _METADATA_FILENAME) + metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only) return metadata if _is_full_checkpoint(path): diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 4e5dc33f62672..0a00cb28af15e 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -363,7 +363,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return self._lightning_module - def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only) From 46cc788f3d0019e695a441aec015bfacfe0c4afa Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 10 Sep 2025 09:48:50 -0400 Subject: [PATCH 37/45] trainer default to weights_only=None --- .../pytorch/strategies/model_parallel.py | 2 +- .../connectors/checkpoint_connector.py | 4 ++-- src/lightning/pytorch/trainer/trainer.py | 23 +++++++++++-------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index def87061a9b73..f3165a08e6bdd 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -329,7 +329,7 @@ def save_checkpoint( return super().save_checkpoint(checkpoint=checkpoint, filepath=path) @override - def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool = False) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) state = { diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 863fd265a6cfd..b99185e1ada3e 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -64,7 +64,7 @@ def _hpc_resume_path(self) -> Optional[str]: return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt" return None - def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False) -> None: + def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`. 
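
The switch to ``Optional[bool]`` across these signatures is a deliberate tri-state: ``None`` defers to the default of the code path doing the load, while ``True`` and ``False`` are explicit caller choices. A simplified sketch of that resolution (the helper name is hypothetical and does not appear in the diffs):

    from typing import Optional

    def _resolve_weights_only(weights_only: Optional[bool], path_default: bool) -> bool:
        # None means "not specified by the caller": fall back to the default of
        # this particular code path (e.g. True for regular checkpoint loads,
        # False for remote state dicts, as in the earlier hunks).
        return path_default if weights_only is None else weights_only
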
@@ -404,7 +404,7 @@ def restore_lr_schedulers(self) -> None: config.scheduler.load_state_dict(lrs_state) def _restore_modules_and_callbacks( - self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False + self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None ) -> None: # restore modules after setup self.resume_start(checkpoint_path, weights_only) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 739ecc369609f..5e57f702076ff 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -526,7 +526,7 @@ def fit( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[_PATH] = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> None: r"""Runs the full optimization routine. @@ -591,7 +591,7 @@ def _fit_impl( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[_PATH] = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> None: log.debug(f"{self.__class__.__name__}: trainer fit stage") @@ -630,7 +630,7 @@ def validate( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: @@ -694,7 +694,7 @@ def _validate_impl( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: @@ -742,7 +742,7 @@ def test( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: @@ -807,7 +807,7 @@ def _test_impl( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: @@ -857,7 +857,7 @@ def predict( datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[_PATH] = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> Optional[_PREDICT_OUTPUT]: r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks. 
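
With ``None`` as the trainer-level default, omitting the argument preserves each code path's existing behaviour, and passing it overrides per call; roughly, continuing the earlier placeholder example:

    trainer.predict(model, ckpt_path="best.ckpt")                     # defer to defaults
    trainer.predict(model, ckpt_path="best.ckpt", weights_only=True)  # explicit opt-in
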
@@ -923,7 +923,7 @@ def _predict_impl( datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[_PATH] = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> Optional[_PREDICT_OUTPUT]: # -------------------- # SETUP HOOK @@ -962,7 +962,10 @@ def _predict_impl( return results def _run( - self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None, weights_only: bool = False + self, + model: "pl.LightningModule", + ckpt_path: Optional[_PATH] = None, + weights_only: Optional[bool] = None, ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( @@ -1401,7 +1404,7 @@ def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None: self._checkpoint_connector._user_managed = bool(ckpt_path) def save_checkpoint( - self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None + self, filepath: _PATH, weights_only: Optional[bool] = None, storage_options: Optional[Any] = None ) -> None: r"""Runs routine to create a checkpoint. From f5f5e58b075c58cfa5bd1e4e87aa079a59db8ef1 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 10 Sep 2025 13:01:15 -0400 Subject: [PATCH 38/45] wip: fix typing dump_checkpoint --- .../pytorch/trainer/connectors/checkpoint_connector.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index b99185e1ada3e..26cfd2024447a 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -412,7 +412,7 @@ def _restore_modules_and_callbacks( self.restore_datamodule() self.restore_callbacks() - def dump_checkpoint(self, weights_only: bool = False) -> dict: + def dump_checkpoint(self, weights_only: Optional[bool] = None) -> dict: """Creating a model checkpoint dictionary object from various component states. 
Args: @@ -450,6 +450,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: "loops": self._get_loops_state_dict(), } + if weights_only is None: + weights_only = False + log.info("`weights_only` was not set, defaulting to `False`.") + if not weights_only: # dump callbacks checkpoint["callbacks"] = call._call_callbacks_state_dict(trainer) From 97695ae1169a01ee5ec8fdfdd6a730023f7926c2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 10 Sep 2025 22:26:23 +0200 Subject: [PATCH 39/45] Apply suggestions from code review --- .../pytorch/trainer/connectors/checkpoint_connector.py | 4 ++-- src/lightning/pytorch/trainer/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 26cfd2024447a..8158cf406544e 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -80,7 +80,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: Op rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}") with pl_legacy_patch(): - loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path, weights_only) + loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path, weights_only=weights_only) self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path) def _select_ckpt_path( @@ -407,7 +407,7 @@ def _restore_modules_and_callbacks( self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None ) -> None: # restore modules after setup - self.resume_start(checkpoint_path, weights_only) + self.resume_start(checkpoint_path, weights_only=weights_only) self.restore_model() self.restore_datamodule() self.restore_callbacks() diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 5e57f702076ff..226271c12f079 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -913,7 +913,7 @@ def predict( self.state.status = TrainerStatus.RUNNING self.predicting = True return call._call_and_handle_interrupt( - self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path, weights_only + self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path, weights_only=weights_only ) def _predict_impl( From 75ad8657a2174f1a0c51f0f2c8574f166eabde21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 20:27:15 +0000 Subject: [PATCH 40/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/trainer/trainer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 226271c12f079..2b0b714539e5d 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -913,7 +913,14 @@ def predict( self.state.status = TrainerStatus.RUNNING self.predicting = True return call._call_and_handle_interrupt( - self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path, weights_only=weights_only + self, + self._predict_impl, + model, + dataloaders, + datamodule, + 
return_predictions, + ckpt_path, + weights_only=weights_only, ) def _predict_impl( From 436552d00e898012aefc2bf986fcf96539fc3315 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 10 Sep 2025 17:13:02 -0400 Subject: [PATCH 41/45] weights_only as last arg --- .../connectors/checkpoint_connector.py | 4 +- src/lightning/pytorch/trainer/trainer.py | 38 +++++++++++++++---- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 8158cf406544e..ae5038b2022d2 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -230,7 +230,7 @@ def resume_end(self) -> None: # wait for all to catch up self.trainer.strategy.barrier("_CheckpointConnector.resume_end") - def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: + def restore(self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None) -> None: """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore, in this priority: @@ -244,7 +244,7 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: checkpoint_path: Path to a PyTorch Lightning checkpoint file. """ - self.resume_start(checkpoint_path) + self.resume_start(checkpoint_path, weights_only=weights_only) # restore module states self.restore_datamodule() diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 2b0b714539e5d..f2f59e396ab23 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -557,6 +557,12 @@ def fit( - ``'registry:version:v2'``: uses the default model set with ``Trainer(..., model_registry="my-model")`` and version 'v2' + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. + Raises: TypeError: If ``model`` is not :class:`~lightning.pytorch.core.LightningModule` for torch version less than @@ -630,9 +636,9 @@ def validate( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + weights_only: Optional[bool] = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the validation set. @@ -653,6 +659,12 @@ def validate( datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. 
+ For more information about multiple dataloaders, see this :ref:`section `. Returns: @@ -686,7 +698,7 @@ def validate( self.state.status = TrainerStatus.RUNNING self.validating = True return call._call_and_handle_interrupt( - self, self._validate_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule + self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule, weights_only ) def _validate_impl( @@ -694,9 +706,9 @@ def _validate_impl( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + weights_only: Optional[bool] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # -------------------- # SETUP HOOK @@ -742,9 +754,9 @@ def test( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + weights_only: Optional[bool] = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to. @@ -766,6 +778,12 @@ def test( datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. + For more information about multiple dataloaders, see this :ref:`section `. Returns: @@ -799,7 +817,7 @@ def test( self.state.status = TrainerStatus.RUNNING self.testing = True return call._call_and_handle_interrupt( - self, self._test_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule + self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule, weights_only ) def _test_impl( @@ -807,9 +825,9 @@ def _test_impl( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, - weights_only: Optional[bool] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + weights_only: Optional[bool] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # -------------------- # SETUP HOOK @@ -880,6 +898,12 @@ def predict( Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. + For more information about multiple dataloaders, see this :ref:`section `. 
Returns:

@@ -920,7 +944,7 @@ def predict(
             datamodule,
             return_predictions,
             ckpt_path,
-            weights_only=weights_only,
+            weights_only,
         )

     def _predict_impl(
From d61b9ecfa25bf808ff0a37cce9737b4b539893ea Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Wed, 10 Sep 2025 17:44:42 -0400
Subject: [PATCH 42/45] assert called with None

---
 .../plugins/test_checkpoint_io_plugin.py      |  4 ++--
 tests/tests_pytorch/test_cli.py               | 18 +++++++-----------
 2 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
index c6be86baf3bc6..0f66f215f6864 100644
--- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
+++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
@@ -69,7 +69,7 @@ def test_checkpoint_plugin_called(tmp_path):
     assert checkpoint_plugin.remove_checkpoint.call_count == 1

     trainer.test(model, ckpt_path=ck.last_model_path)
-    checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt"), weights_only=False)
+    checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt"), weights_only=None)
     checkpoint_plugin.reset_mock()

     ck = ModelCheckpoint(dirpath=tmp_path, save_last=True)
@@ -97,7 +97,7 @@ def test_checkpoint_plugin_called(tmp_path):
     trainer.test(model, ckpt_path=ck.last_model_path)

     checkpoint_plugin.load_checkpoint.assert_called_once()
-    checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt"), weights_only=False)
+    checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt"), weights_only=None)


 @pytest.mark.flaky(reruns=3)
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index a65092502014b..6812bc42c534a 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -1447,7 +1447,7 @@ def test_lightning_cli_config_with_subcommand():
     ):
         cli = LightningCLI(BoringModel)

-    test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=False)
+    test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=None)
     assert cli.trainer.limit_test_batches == 1


@@ -1463,9 +1463,7 @@ def test_lightning_cli_config_before_subcommand():
     ):
         cli = LightningCLI(BoringModel)

-    test_mock.assert_called_once_with(
-        cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=False
-    )
+    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=None)
     assert cli.trainer.limit_test_batches == 1

     save_config_callback = cli.trainer.callbacks[0]
@@ -1478,7 +1476,7 @@
     ):
         cli = LightningCLI(BoringModel)

-    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=False)
+    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=None)
     assert cli.trainer.limit_val_batches == 1

     save_config_callback = cli.trainer.callbacks[0]
@@ -1496,9 +1494,7 @@ def
test_lightning_cli_config_before_subcommand_two_configs(): ): cli = LightningCLI(BoringModel) - validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=False) + validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=None) assert cli.trainer.limit_val_batches == 1 @@ -1519,7 +1515,7 @@ def test_lightning_cli_config_after_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=False) + test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=None) assert cli.trainer.limit_test_batches == 1 @@ -1533,7 +1529,7 @@ def test_lightning_cli_config_before_and_after_subcommand(): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with( - cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar", weights_only=False + cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar", weights_only=None ) assert cli.trainer.limit_test_batches == 1 assert cli.trainer.fast_dev_run == 1 From 42559fe29f23f3e94cfd98505e5bfb14ce70d746 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 10 Sep 2025 19:16:14 -0400 Subject: [PATCH 43/45] weights_only=False for torch>=2.6 in tests --- tests/tests_pytorch/callbacks/test_callbacks.py | 4 +++- .../callbacks/test_stochastic_weight_avg.py | 9 ++++++--- tests/tests_pytorch/models/test_hparams.py | 6 ++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_callbacks.py b/tests/tests_pytorch/callbacks/test_callbacks.py index 34749087bfb97..fec835f199e0b 100644 --- a/tests/tests_pytorch/callbacks/test_callbacks.py +++ b/tests/tests_pytorch/callbacks/test_callbacks.py @@ -18,6 +18,7 @@ import pytest from lightning_utilities.test.warning import no_warning_call +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -132,7 +133,8 @@ def test_resume_callback_state_saved_by_type_stateful(tmp_path): callback = OldStatefulCallback(state=222) trainer = Trainer(default_root_dir=tmp_path, max_steps=2, callbacks=[callback]) - trainer.fit(model, ckpt_path=ckpt_path) + weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None + trainer.fit(model, ckpt_path=ckpt_path, weights_only=weights_only) assert callback.state == 111 diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index abcd302149fcf..5786a3339a7fa 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -25,6 +25,7 @@ from torch.optim.swa_utils import SWALR from torch.utils.data import DataLoader +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.pytorch import Trainer from lightning.pytorch.callbacks import StochasticWeightAveraging from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset @@ -173,8 +174,9 @@ def train_with_swa( devices=devices, ) + weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None with _backward_patch(trainer): - trainer.fit(model) + trainer.fit(model, weights_only=weights_only) # check the model is the expected assert trainer.lightning_module == model 
@@ -307,8 +309,9 @@ def _swa_resume_training_from_checkpoint(tmp_path, model, resume_model, ddp=Fals
     }
     trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
 
+    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
     with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
-        trainer.fit(model)
+        trainer.fit(model, weights_only=weights_only)
 
     checkpoint_dir = Path(tmp_path) / "checkpoints"
     checkpoint_files = os.listdir(checkpoint_dir)
@@ -318,7 +321,7 @@ def _swa_resume_training_from_checkpoint(tmp_path, model, resume_model, ddp=Fals
     trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
 
     with _backward_patch(trainer):
-        trainer.fit(resume_model, ckpt_path=ckpt_path)
+        trainer.fit(resume_model, ckpt_path=ckpt_path, weights_only=weights_only)
 
 
 class CustomSchedulerModel(SwaTestModel):
diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index d0c72721ce1be..da68a240ec4ac 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -30,6 +30,7 @@
 from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import DataLoader
 
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.core.datamodule import LightningDataModule
@@ -748,8 +749,9 @@ def test_model_with_fsspec_as_parameter(tmp_path):
     trainer = Trainer(
         default_root_dir=tmp_path, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1
     )
-    trainer.fit(model)
-    trainer.test()
+    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
+    trainer.fit(model, weights_only=weights_only)
+    trainer.test(weights_only=weights_only)
 
 
 @pytest.mark.xfail(

From 605b11c4a75408f25fd5b4e7eaf1a36085c3a22c Mon Sep 17 00:00:00 2001
From: Shion Matsumoto
Date: Wed, 10 Sep 2025 20:21:50 -0400
Subject: [PATCH 44/45] fix changelog description

---
 src/lightning/fabric/CHANGELOG.md  | 4 +---
 src/lightning/pytorch/CHANGELOG.md | 2 +-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 406b3a687f14c..307975e1619bd 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -19,9 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))
-
-
+- Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))
 
 - Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164))
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 099f091c3a379..92dff1de1a6ed 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -27,7 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))
+- Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
 

From 36c419ba41d41fb9924c89e3dfb199e33ce561c1 Mon Sep 17 00:00:00 2001
From: Jirka B
Date: Thu, 11 Sep 2025 15:53:07 +0200
Subject: [PATCH 45/45] Empty-Commit
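Usage sketch (not part of the patches above): the series exposes `weights_only` on `Trainer.{fit,validate,test,predict}` and forwards it unchanged to the checkpoint-loading path, so leaving it as `None` lets `torch.load` apply its own default (`True` on `torch>=2.6`, `False` on older releases). The snippet below is a minimal illustration of how calling code is expected to choose a value, assuming the signatures shown in the diffs; `BoringModel` is reused from the tests, and the checkpoint paths are invented placeholders.

    from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel

    model = BoringModel()
    trainer = Trainer(max_epochs=1)

    # Leave weights_only=None (the default): the value is passed through and
    # torch.load decides, i.e. weights_only=True on torch>=2.6.
    trainer.fit(model)

    # Untrusted checkpoint: restrict unpickling to tensors and primitive
    # types regardless of the installed torch version.
    trainer.validate(model, ckpt_path="downloaded.ckpt", weights_only=True)  # hypothetical path

    # Trusted checkpoint that pickles arbitrary objects (e.g. custom callback
    # state, as in the resume tests above): opt out explicitly on torch>=2.6,
    # where the restricted loader would otherwise reject such objects.
    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
    trainer.fit(model, ckpt_path="example.ckpt", weights_only=weights_only)  # hypothetical path

This mirrors the `False if _TORCH_GREATER_EQUAL_2_6 else None` pattern the test suite adopts for legacy checkpoints that carry pickled callback state.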