Changes from all commits
64 commits
074b01e
change weights_only default to True
matsumotosan Aug 14, 2025
65cc1ed
add docs on weights_only arg
matsumotosan Aug 14, 2025
4eaaf58
Merge branch 'master' into weights-only-compatibility
SkafteNicki Aug 15, 2025
f276114
add weights_only arg to checkpoint save. weights_only during test set…
matsumotosan Aug 15, 2025
601e300
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 15, 2025
28f53ae
add weights_only arg to checkpoint_io
matsumotosan Aug 16, 2025
b1cfdf1
woops, reverting changes
matsumotosan Aug 16, 2025
4d96a78
permissions too
matsumotosan Aug 16, 2025
4c39c30
fix link
matsumotosan Aug 16, 2025
861d7e0
fix another link
matsumotosan Aug 16, 2025
12bd0d6
datamodule weights_only args
matsumotosan Aug 17, 2025
5eacb6e
wip: try safe_globals context manager for tests
matsumotosan Aug 17, 2025
0430e22
add weights_only arg to _run_standard_hparams_test
matsumotosan Aug 18, 2025
2abe915
weights_only=False when adding extra_args
matsumotosan Aug 18, 2025
8e0f61e
Merge branch 'master' into weights-only-compatibility
Borda Aug 18, 2025
525d9a8
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 18, 2025
83fd824
switch to lightning_utilities.cli requirements set-oldest (#21077)
Borda Aug 19, 2025
93cbe94
bump: try `deepspeed >=0.14.1,<=0.15.0` (#21076)
Borda Aug 19, 2025
2a53f2f
weights_only=True default for torch>=2.6
matsumotosan Aug 19, 2025
3833892
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 19, 2025
9d8997e
Merge branch 'Lightning-AI:master' into weights-only-compatibility
matsumotosan Aug 19, 2025
561c02c
changelog
matsumotosan Aug 19, 2025
c67c8a3
ignore torch.load futurewarning
matsumotosan Aug 19, 2025
005c439
add .*
matsumotosan Aug 19, 2025
2ab89a2
will this woork
matsumotosan Aug 19, 2025
74e5e5a
weights_only according pl version
matsumotosan Aug 20, 2025
a4c9efe
set env var TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 for pl < 1.5.0
matsumotosan Aug 20, 2025
2c2ab9e
weights_only=False for omegaconf hparams test
matsumotosan Aug 21, 2025
54b859a
default to weights_only=true for loading from state_dict from url
matsumotosan Aug 21, 2025
7ddb4f8
weights_only=False for hydra
matsumotosan Aug 21, 2025
7d6174a
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 23, 2025
906e52e
Update src/lightning/fabric/utilities/cloud_io.py
matsumotosan Aug 27, 2025
377cf11
Merge branch 'Lightning-AI:master' into weights-only-compatibility
matsumotosan Aug 28, 2025
9aadbef
Merge branch 'master' into weights-only-compatibility
Borda Aug 29, 2025
52b3215
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 30, 2025
170fbe0
defaults for weights_only in torch.hub.load_state_dict_from_url
matsumotosan Aug 27, 2025
bf7e284
default to weights_only=False for torch.hub.load_state_dict_from_url
matsumotosan Sep 3, 2025
a5d6da3
Merge branch 'master' into weights-only-compatibility
Borda Sep 3, 2025
a3183ba
Merge branch 'master' into weights-only-compatibility
matsumotosan Sep 5, 2025
5ac9695
add weights_only to trainer.fit, validate, test, predict
matsumotosan Sep 8, 2025
ea066a9
fix tests
matsumotosan Sep 8, 2025
6b711c9
Merge branch 'master' into weights-only-compatibility
matsumotosan Sep 8, 2025
4d65777
add weights_only arg
matsumotosan Sep 8, 2025
087d467
Merge branch 'master' into weights-only-compatibility
matsumotosan Sep 8, 2025
b84a53d
specify weights_only kwarg
matsumotosan Sep 8, 2025
ea6773e
weights_only for fsdp load
matsumotosan Sep 8, 2025
0685799
Merge branch 'master' into weights-only-compatibility
matsumotosan Sep 8, 2025
d5431d4
Merge branch 'master' into weights-only-compatibility
Borda Sep 10, 2025
23c81e8
Apply suggestions from code review
Borda Sep 10, 2025
2de885c
Apply suggestions from code review
Borda Sep 10, 2025
9f709d6
default is none
matsumotosan Sep 10, 2025
a95f9ba
Merge branch 'master' into weights-only-compatibility
Borda Sep 10, 2025
653dd6f
add weights_only args to strategies
matsumotosan Sep 10, 2025
46cc788
trainer default to weights_only=None
matsumotosan Sep 10, 2025
f5f5e58
wip: fix typing dump_checkpoint
matsumotosan Sep 10, 2025
97695ae
Apply suggestions from code review
Borda Sep 10, 2025
75ad865
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
436552d
weights_only as last arg
matsumotosan Sep 10, 2025
d61b9ec
asset called with none
matsumotosan Sep 10, 2025
42559fe
weights_only=False for torch>=2.6 in tests
matsumotosan Sep 10, 2025
605b11c
fix changelog description
matsumotosan Sep 11, 2025
fa3eb90
Merge branch 'master' into weights-only-compatibility
tchaton Sep 11, 2025
36c419b
Empty-Commit
Borda Sep 11, 2025
b8cf6ea
Merge branch 'master' into weights-only-compatibility
Borda Sep 11, 2025
1 change: 1 addition & 0 deletions pyproject.toml
@@ -182,6 +182,7 @@ filterwarnings = [
# "error::DeprecationWarning",
"error::FutureWarning",
"ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated
"ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning",
]
xfail_strict = true
junit_duration_report = "call"
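
The new filter entry silences the FutureWarning that PyTorch emits for `torch.load` calls still relying on `weights_only=False`. A rough stdlib equivalent, shown only to illustrate what the pytest pattern matches (not part of this PR):

import warnings

# Ignore the FutureWarning raised when torch.load falls back to weights_only=False.
warnings.filterwarnings(
    "ignore",
    message=r"You are using `torch.load` with `weights_only=False`.*",
    category=FutureWarning,
)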
4 changes: 1 addition & 3 deletions src/lightning/fabric/CHANGELOG.md
@@ -19,9 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-


- Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))
- Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164))


9 changes: 8 additions & 1 deletion src/lightning/fabric/plugins/io/checkpoint_io.py
@@ -47,13 +47,20 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
"""

@abstractmethod
def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]:
def load_checkpoint(
self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None
) -> dict[str, Any]:
"""Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.

Args:
path: Path to checkpoint
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
locations.
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we
recommend using ``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

Returns: The loaded checkpoint.

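
For third-party `CheckpointIO` plugins, the keyword becomes part of the abstract interface. A minimal sketch of a custom plugin honoring it (the class name and the direct `torch.load` pass-through are illustrative, not taken from this PR):

import os
from typing import Any, Optional

import torch
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO


class SimpleTorchIO(CheckpointIO):
    """Toy plugin that forwards ``weights_only`` to ``torch.load`` unchanged."""

    def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
        torch.save(checkpoint, path)

    def load_checkpoint(
        self, path: str, map_location: Optional[Any] = None, weights_only: Optional[bool] = None
    ) -> dict[str, Any]:
        # Passing weights_only through unchanged mirrors the built-in TorchCheckpointIO;
        # leaving it as None defers to torch.load's own version-dependent default.
        return torch.load(path, map_location=map_location, weights_only=weights_only)

    def remove_checkpoint(self, path: str) -> None:
        os.remove(path)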
12 changes: 10 additions & 2 deletions src/lightning/fabric/plugins/io/torch_io.py
@@ -59,14 +59,22 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio

@override
def load_checkpoint(
self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
self,
path: _PATH,
map_location: Optional[Callable] = lambda storage, loc: storage,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.

Args:
path: Path to checkpoint
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
locations.
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we
recommend using ``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

Returns: The loaded checkpoint.

@@ -80,7 +88,7 @@ def load_checkpoint(
if not fs.exists(path):
raise FileNotFoundError(f"Checkpoint file not found: {path}")

return pl_load(path, map_location=map_location)
return pl_load(path, map_location=map_location, weights_only=weights_only)

@override
def remove_checkpoint(self, path: _PATH) -> None:
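
Calling the built-in plugin directly shows the pass-through (the checkpoint path is a placeholder):

from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO

io = TorchCheckpointIO()
# Restrict unpickling to tensors and primitive types; use weights_only=False only for
# trusted checkpoints that store arbitrary Python objects.
checkpoint = io.load_checkpoint("example.ckpt", weights_only=True)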
3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/deepspeed.py
@@ -458,6 +458,7 @@ def load_checkpoint(
path: _PATH,
state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Load the contents from a checkpoint and restore the state of the given objects.

@@ -483,7 +484,7 @@
# This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from
# a consolidated checkpoint
path = self.broadcast(path)
return super().load_checkpoint(path=path, state=state, strict=strict)
return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)

if not state:
raise ValueError(
3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
@@ -516,6 +516,7 @@ def load_checkpoint(
path: _PATH,
state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Load the contents from a checkpoint and restore the state of the given objects."""
if not state:
@@ -586,7 +587,7 @@
optim.load_state_dict(flattened_osd)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
_validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
for key in requested_metadata_keys:
8 changes: 5 additions & 3 deletions src/lightning/fabric/strategies/model_parallel.py
@@ -275,6 +275,7 @@ def load_checkpoint(
path: _PATH,
state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Load the contents from a checkpoint and restore the state of the given objects."""
if not state:
@@ -295,7 +296,7 @@
f"Loading a single optimizer object from a checkpoint is not supported yet with {type(self).__name__}."
)

return _load_checkpoint(path=path, state=state, strict=strict)
return _load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)

def _setup_distributed(self) -> None:
reset_seed()
@@ -411,6 +412,7 @@ def _load_checkpoint(
state: dict[str, Union[Module, Optimizer, Any]],
strict: bool = True,
optimizer_states_from_list: bool = False,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
@@ -449,7 +451,7 @@
set_optimizer_state_dict(module, optim, optim_state_dict=optim_state[optim_key], options=state_dict_options)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
_validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
for key in requested_metadata_keys:
@@ -461,7 +463,7 @@
return metadata

if _is_full_checkpoint(path):
checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False)
checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=weights_only)
_load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

state_dict_options = StateDictOptions(
3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/strategy.py
@@ -310,6 +310,7 @@ def load_checkpoint(
path: _PATH,
state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Load the contents from a checkpoint and restore the state of the given objects.

@@ -330,7 +331,7 @@

"""
torch.cuda.empty_cache()
checkpoint = self.checkpoint_io.load_checkpoint(path)
checkpoint = self.checkpoint_io.load_checkpoint(path, weights_only=weights_only)
if not state:
return checkpoint

3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/xla_fsdp.py
@@ -516,6 +516,7 @@ def load_checkpoint(
path: _PATH,
state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
weights_only: Optional[bool] = None,
) -> dict[str, Any]:
"""Given a folder, load the contents from a checkpoint and restore the state of the given objects.

@@ -608,7 +609,7 @@
)
if "model" not in state or not isinstance(model := state["model"], torch.nn.Module):
raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.")
full_ckpt = torch.load(path)
full_ckpt = torch.load(path, weights_only=weights_only)
model.load_state_dict(full_ckpt.pop("model"), strict=strict)
return full_ckpt

18 changes: 15 additions & 3 deletions src/lightning/fabric/utilities/cloud_io.py
@@ -17,7 +17,7 @@
import io
import logging
from pathlib import Path
from typing import IO, Any, Union
from typing import IO, Any, Optional, Union

import fsspec
import fsspec.utils
@@ -34,13 +34,18 @@
def _load(
path_or_url: Union[IO, _PATH],
map_location: _MAP_LOCATION_TYPE = None,
weights_only: bool = False,
weights_only: Optional[bool] = None,
) -> Any:
"""Loads a checkpoint.

Args:
path_or_url: Path or URL of the checkpoint.
map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other primitive
types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using
``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

"""
if not isinstance(path_or_url, (str, Path)):
@@ -51,6 +56,13 @@ def _load(
weights_only=weights_only,
)
if str(path_or_url).startswith("http"):
if weights_only is None:
weights_only = False
log.debug(
f"Defaulting to `weights_only=False` for remote checkpoint: {path_or_url}."
f" If loading a checkpoint from an untrusted source, we recommend using `weights_only=True`."
)

return torch.hub.load_state_dict_from_url(
str(path_or_url),
map_location=map_location, # type: ignore[arg-type]
@@ -70,7 +82,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
return fs


def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None:
def _atomic_save(checkpoint: dict[str, Any], filepath: _PATH) -> None:
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

Args:
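
A short sketch of how the helper behaves after this change; the callers above refer to `_load` as `pl_load`, and the path and URL below are placeholders:

from lightning.fabric.utilities.cloud_io import _load

# Local file: weights_only is forwarded to torch.load unchanged.
ckpt = _load("lightning_logs/version_0/checkpoints/epoch=0.ckpt", weights_only=True)

# Remote URL with weights_only left as None: the helper falls back to weights_only=False
# (logging a debug message) before delegating to torch.hub.load_state_dict_from_url.
ckpt = _load("https://example.com/checkpoints/model.ckpt")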
1 change: 1 addition & 0 deletions src/lightning/fabric/utilities/imports.py
@@ -35,6 +35,7 @@
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
_TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
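
The new `_TORCH_GREATER_EQUAL_2_6` flag lets call sites and tests account for PyTorch 2.6 switching `torch.load`'s default to `weights_only=True`. An illustrative branch, not the PR's exact logic:

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6

# Checkpoints that pickle arbitrary objects must opt out of the safe loader once
# torch>=2.6 makes weights_only=True the default; on older torch, None keeps the
# legacy behavior.
weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None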
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))


- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))


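
A usage sketch for the Trainer-level entry above (the model, data, and checkpoint paths are placeholders; leaving `weights_only=None` defers to `torch.load`'s own default):

import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset


class LitRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)

# Resume training from a trusted checkpoint that may contain pickled objects:
trainer.fit(LitRegressor(), loader, ckpt_path="last.ckpt", weights_only=False)

# Evaluate a checkpoint from a less trusted source with the restricted loader:
trainer.test(LitRegressor(), loader, ckpt_path="downloaded.ckpt", weights_only=True)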
7 changes: 7 additions & 0 deletions src/lightning/pytorch/core/datamodule.py
@@ -177,6 +177,7 @@ def load_from_checkpoint(
checkpoint_path: Union[_PATH, IO],
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
weights_only: Optional[bool] = None,
**kwargs: Any,
) -> Self:
r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the
@@ -206,6 +207,11 @@
If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace`
and ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat
``hparams`` as :class:`~dict`.
weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using
``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
\**kwargs: Any extra keyword args needed to init the datamodule. Can also be used to override saved
hyperparameter values.

@@ -242,6 +248,7 @@
map_location=map_location,
hparams_file=hparams_file,
strict=None,
weights_only=weights_only,
**kwargs,
)
return cast(Self, loaded)
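
A hedged usage sketch for the datamodule entry point (the class, its hyperparameter, and the checkpoint path are placeholders):

import lightning.pytorch as pl


class RandomDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()


# Re-creates the datamodule with its saved hyperparameters; weights_only=True keeps
# the underlying torch.load restricted to tensors and primitive types.
dm = RandomDataModule.load_from_checkpoint("example.ckpt", weights_only=True)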
7 changes: 7 additions & 0 deletions src/lightning/pytorch/core/module.py
@@ -1690,6 +1690,7 @@ def load_from_checkpoint(
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
strict: Optional[bool] = None,
weights_only: Optional[bool] = None,
**kwargs: Any,
) -> Self:
r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
@@ -1723,6 +1724,11 @@
strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is
set, in which case it defaults to the value of ``LightningModule.strict_loading``.
weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using
``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
\**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
hyperparameter values.

@@ -1778,6 +1784,7 @@
map_location,
hparams_file,
strict,
weights_only,
**kwargs,
)
return cast(Self, loaded)
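
And the corresponding module-level call (class and checkpoint paths are placeholders):

import torch
import lightning.pytorch as pl


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(hidden, 2)


# Checkpoint from an untrusted source: keep unpickling restricted.
model = LitClassifier.load_from_checkpoint("downloaded.ckpt", weights_only=True)

# Trusted checkpoint whose hyperparameters pickle richer objects (e.g. an nn.Module):
model = LitClassifier.load_from_checkpoint("trusted.ckpt", weights_only=False, strict=True)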
4 changes: 3 additions & 1 deletion src/lightning/pytorch/core/saving.py
@@ -56,11 +56,13 @@ def _load_from_checkpoint(
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
strict: Optional[bool] = None,
weights_only: Optional[bool] = None,
**kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
map_location = map_location or _default_map_location

with pl_legacy_patch():
checkpoint = pl_load(checkpoint_path, map_location=map_location)
checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only)

# convert legacy checkpoints to the new format
checkpoint = _pl_migrate_checkpoint(
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -659,12 +659,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op
)

@override
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
if self.load_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
checkpoint_path = self.broadcast(checkpoint_path)
return super().load_checkpoint(checkpoint_path)
return super().load_checkpoint(checkpoint_path, weights_only)

_validate_checkpoint_directory(checkpoint_path)

4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -583,7 +583,7 @@ def save_checkpoint(
raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

@override
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
# broadcast the path from rank 0 to ensure all the states are loaded from a common path
path = Path(self.broadcast(checkpoint_path))

@@ -624,7 +624,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
optim.load_state_dict(flattened_osd)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
return metadata

if _is_full_checkpoint(path):
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/model_parallel.py
@@ -329,7 +329,7 @@ def save_checkpoint(
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)

@override
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
# broadcast the path from rank 0 to ensure all the states are loaded from a common path
path = Path(self.broadcast(checkpoint_path))
state = {
@@ -342,6 +342,7 @@ def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
state=state,
strict=self.lightning_module.strict_loading,
optimizer_states_from_list=True,
weights_only=weights_only,
)

def _setup_distributed(self) -> None:
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/strategy.py
@@ -363,9 +363,9 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
"""Returns the pure LightningModule without potential wrappers."""
return self._lightning_module

def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)
return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only)

def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
assert self.lightning_module is not None
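
A minimal sketch of the delegation this last hunk wires up, assuming an already-constructed Trainer and a placeholder path: the strategy-level loader forwards the flag to its `CheckpointIO`.

import lightning.pytorch as pl

trainer = pl.Trainer(accelerator="cpu", devices=1)

# Strategy.load_checkpoint(...) -> CheckpointIO.load_checkpoint(path, weights_only=...)
checkpoint = trainer.strategy.load_checkpoint("example.ckpt", weights_only=True)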