## 🚀 Feature request
Currently under Deepspeed stage3 with `from_pretrained` we:

a. loop over each sub-module in `zero.Init`:
   - init the sub-module
   - shard and scatter the shards

b. then, to load pre-trained weights, we loop over each sub-module:
   - gather the shards
   - `load_state_dict` for the one layer
   - shard and scatter the shards

c. for any sub-module params that weren't in the pretrained state_dict:
   - run the postponed `module_init`, as it was done in Pytorch - Lazy initialization of models #11471
   - shard and scatter the shards

XXX: I actually don't think `deepspeed.zero.GatheredParameters` was handled here, so these params don't get ZeRO'ed - need to fix that: [Deepspeed zero3] lazy weights init #12272
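To make the scatter/gather/scatter cost concrete, here is a rough sketch of the current flow. The DeepSpeed context managers are real APIs, but `MyModel`, `ds_config` and the per-module loop are illustrative approximations of the transformers internals, not the actual code:

```python
import torch
import deepspeed

ds_config = {"zero_optimization": {"stage": 3}}  # minimal illustrative config

# a. model creation: zero.Init shards and scatters each sub-module's params
#    as soon as that sub-module's __init__ returns
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = MyModel()  # hypothetical model class

# b. loading pretrained weights: every freshly created shard has to be
#    gathered back, overwritten, then re-sharded and scattered again
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
for name, module in model.named_modules():
    params = list(module.parameters(recurse=False))
    if not params:
        continue
    prefix = name + "." if name else ""
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        # inside the context rank 0 sees the full params; on exit they are
        # re-partitioned and scattered across ranks (assumes a distributed
        # launch with the process group already initialized)
        if torch.distributed.get_rank() == 0:
            module._load_from_state_dict(state_dict, prefix, {}, True, [], [], [])
```

Step (c) would be a similar gather/init/re-scatter loop over the params missing from `state_dict` - that is the part the XXX note above suspects is currently skipped.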
Because we unnecessarily do scatter/gather/scatter, this takes much longer than just:
a. init the modules w/o allocating any storage, as has been implemented in pt-1.9.0/1.9.1: https://pytorch.org/tutorials/prototype/skip_param_init.html#implementation-details

b. for each sub-module with pretrained weights:
   - `load_state_dict`
   - shard and scatter the shards

c. for any sub-module params that weren't in the pretrained state_dict:
   - materialize and `module_init`
   - shard and scatter the shards
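A minimal sketch of that proposed flow, assuming the pt-1.9 meta-device / deferred-init mechanics from the tutorial above. `MyModel`, the `materialize` helper and `shard_and_scatter` are hypothetical, and the single shard+scatter at the end is exactly the step that needs DeepSpeed support:

```python
import torch

def materialize(module):
    # replace this module's own meta params with uninitialized real storage
    # (illustrative stand-in for Module.to_empty(), which recurses into children)
    for pname, param in list(module.named_parameters(recurse=False)):
        setattr(module, pname,
                torch.nn.Parameter(torch.empty_like(param, device="cpu")))

# a. build the model without allocating any storage
model = MyModel(device="meta")  # assumes the layers forward device="meta"

state_dict = torch.load("pytorch_model.bin", map_location="cpu")

for name, module in model.named_modules():
    own_params = dict(module.named_parameters(recurse=False))
    if not own_params:
        continue
    prefix = name + "." if name else ""
    materialize(module)
    if any(prefix + p in state_dict for p in own_params):
        # b. pretrained sub-module: load the weights...
        module._load_from_state_dict(state_dict, prefix, {}, True, [], [], [])
    else:
        # c. no pretrained weights: run the postponed module_init
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
    # ...then shard and scatter exactly once (hypothetical helper - this is
    # the post-hoc partitioning that deepspeedai/DeepSpeed#1142 is about)
    shard_and_scatter(module)
```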
Solving this will most likely require support from Deepspeed (deepspeedai/DeepSpeed#1142), or perhaps we can just try to remove `zero.Init` if the weights aren't materialized during model creation. Then the very first sharding gets postponed to the `load_state_dict` stage (and to `module_init` for the sub-modules that don't have pre-trained weights).
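In which case the whole thing might collapse to something like this sketch. Treating an already-built model via `zero.Init`'s `module=` argument for the one-shot partitioning is an assumption about what deepspeedai/DeepSpeed#1142 would provide or confirm, and `load_pretrained_and_materialize` is a hypothetical stand-in for steps b/c above:

```python
import deepspeed

# no zero.Init around model creation: meta device means no storage anyway
model = MyModel(device="meta")          # hypothetical model class
load_pretrained_and_materialize(model)  # hypothetical: steps b/c above

# single shard+scatter at the very end, once all weights are final
deepspeed.zero.Init(module=model, config_dict_or_path=ds_config)
```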