Description
System Info
accelerate version==1.7.0, OS==Linux, python version==3.10.16
"deepspeed_config": {
    "zero_stage": 2,
    "offload_optimizer": {
        "device": "cpu",
        "pin_memory": true
    }
}
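For reference, a minimal sketch of what I believe is the equivalent programmatic setup through Accelerate's DeepSpeedPlugin (an assumption on my part; pin_memory is omitted for brevity):
from accelerate import Accelerator, DeepSpeedPlugin

# ZeRO stage 2 with optimizer state offloaded to CPU, mirroring the config above.
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, offload_optimizer_device="cpu")
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)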
Information
- The official example scripts
- My own modified scripts
Tasks
- One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
- My own task or dataset (give details below)
Reproduction
After preparing the optimizer with accelerator.prepare(), I'm unable to properly save and load its state:
1. When trying to save the prepared optimizer directly:
optimizer = torch.optim.Adam(model.parameters())
optimizer = accelerator.prepare(model, optimizer)[1]
# Saving
torch.save({
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pt')
# Loading
checkpoint = torch.load('checkpoint.pt')
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # Fails
This fails with:
KeyError: 'param_groups'
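For comparison, a minimal sketch of the checkpoint round trip through Accelerate itself instead of calling state_dict() on the wrapper directly; accelerator.save_state/load_state hand checkpointing off to the prepared objects, and the directory name is a placeholder:
# Save everything Accelerate tracks (model, optimizer, scheduler, RNG states) to a directory.
accelerator.save_state("checkpoint_dir")
# ... later, after rebuilding and re-preparing the same objects:
accelerator.load_state("checkpoint_dir")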
2. I've tried looking at the saved optimizer state:
ckpt = torch.load('checkpoint.pt', map_location='cpu')
print(ckpt['optimizer_state_dict'].keys())
Shows DeepSpeed format keys:
dict_keys(['loss_scaler', 'dynamic_loss_scale', 'overflow', 'clip_grad', 'base_optimizer_state', 'single_partition_of_fp32_groups', 'zero_stage', 'group_paddings', 'partition_count', 'ds_version', 'param_slice_mappings'])
So I tried to load from base_optimizer_state:
optimizer.load_state_dict(checkpoint['optimizer_state_dict']['base_optimizer_state'])
Fails with:
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group
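A small inspection sketch (reusing ckpt from above); under ZeRO stage 2 the stored state is presumably this rank's partition, which would explain why the group sizes cannot match the unsharded Adam instance. The exact layout is DeepSpeed-internal and version-dependent:
base = ckpt['optimizer_state_dict']['base_optimizer_state']
print(type(base))
# Show only the top-level structure, whichever form it takes.
print(list(base.keys()) if isinstance(base, dict) else f"{len(base)} per-group entries")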
3. When trying to unwrap the optimizer first:
unwrapped_optimizer = accelerator.unwrap_model(optimizer) # Fails
This fails with:
AttributeError: 'DeepSpeedOptimizerWrapper' object has no attribute '_modules'
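Presumably unwrap_model only applies to nn.Module objects. As a sketch of the kind of unwrapping I was after, assuming the wrapper keeps the underlying optimizer on its .optimizer attribute (an internal detail, not a documented API):
inner = optimizer.optimizer
print(type(inner))  # the DeepSpeed-managed optimizer, not the original torch.optim.Adam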
Expected behavior
It should be possible to save and load the optimizer state when using Accelerate with DeepSpeed, similar to regular PyTorch training.
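For reference, the plain-PyTorch round trip that this expectation refers to (a self-contained toy example, not my actual training script):
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
opt.step()

torch.save({'optimizer_state_dict': opt.state_dict()}, 'plain_opt.pt')
ckpt = torch.load('plain_opt.pt')
opt.load_state_dict(ckpt['optimizer_state_dict'])  # round-trips without errors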