Skip to content

Commit 11ce634

Browse files
allowlist WeightWithDynamicFloat8CastTensor for deserialization for checkpointing (#2573)
1 parent 3460951 commit 11ce634

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchao/float8/fsdp_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,7 @@ def fsdp_post_all_gather(
266266
self._linear_mm_config,
267267
gemm_input_role=GemmInputRole.WEIGHT,
268268
), (data,)
269+
270+
271+
# Needed to allowlist this subclass for deserialization used for restoring checkpoints.
272+
torch.serialization.add_safe_globals([WeightWithDynamicFloat8CastTensor])

0 commit comments

Comments
 (0)