
[Deepspeed] [performance] inefficient load with from_pretrained w/ zero3 #12273


Description

@stas00

🚀 Feature request

Currently, under DeepSpeed ZeRO stage 3, from_pretrained works as follows (a rough sketch in code appears after this list):

a. loop over each sub-module in zero.Init

  1. init the sub-module
  2. shard and scatter the shards

b. then, to load the pretrained weights, we loop over each sub-module:

  1. gather the shards
  2. load_state_dict for that one sub-module
  3. shard and scatter the shards

c. for any sub-module params that weren't in the pretrained state_dict:

  1. run the postponed module_init, as was done in Pytorch - Lazy initialization of models #11471
  2. shard and scatter the shards. XXX: I actually don't think deepspeed.zero.GatheredParameters was handled here, so these params don't get ZeRO'ed; this needs to be fixed (see [Deepspeed zero3] lazy weights init #12272)
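
For concreteness, here is a rough sketch of that flow (not the literal transformers implementation; MyModel, config and ds_config are placeholders). Note how the params are scattered once at construction time and then gathered and re-scattered again for every sub-module while loading:

```python
import deepspeed
import torch

# a. every sub-module is initialized inside zero.Init, which shards and
#    scatters its params as soon as the sub-module's __init__ returns
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = MyModel(config)

state_dict = torch.load("pytorch_model.bin", map_location="cpu")

for name, module in model.named_modules():
    params = list(module.parameters(recurse=False))
    if not params:
        continue
    # b.1 gather this sub-module's shards on all ranks
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            # b.2 load the pretrained weights for this one sub-module
            module._load_from_state_dict(state_dict, name + ".", {}, True, [], [], [])
    # b.3 on context exit the params are sharded and scattered again
```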

Because we unnecessarily do a scatter/gather/scatter round trip, this takes much longer than just the following (see the sketch after this list):

a. init the modules without allocating any storage, as implemented in pt-1.9.0/1.9.1: https://pytorch.org/tutorials/prototype/skip_param_init.html#implementation-details

b. for each sub-module with pretrained weights:

  1. load_state_dict
  2. shard and scatter the shards

c. for any sub-module params that weren't in the pretrained state_dict:

  1. materialize the params and run module_init
  2. shard and scatter the shards
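
A hypothetical sketch of this proposed flow for a single sub-module follows. The state_dict key prefix is made up, and shard_and_scatter() is an imaginary placeholder: DeepSpeed currently has no public API for sharding an individual, already-materialized sub-module, which is exactly the kind of support discussed below.

```python
import torch
import torch.nn as nn

def shard_and_scatter(module):
    # placeholder for the missing DeepSpeed primitive that would partition this
    # sub-module's (now materialized) params across the ZeRO-3 process group
    raise NotImplementedError

# a. create the params on the meta device: no storage is allocated and no
#    random init is run (the pt-1.9 skip-param-init mechanism)
layer = nn.Linear(1024, 1024, device="meta")

state_dict = torch.load("pytorch_model.bin", map_location="cpu")
prefix = "encoder.layer.0.output.dense."  # made-up key prefix for illustration

if prefix + "weight" in state_dict:
    # b.1 materialize the params directly from the pretrained weights
    layer.weight = nn.Parameter(state_dict[prefix + "weight"])
    layer.bias = nn.Parameter(state_dict[prefix + "bias"])
else:
    # c.1 no pretrained weights: allocate empty storage and run the module's init
    layer.to_empty(device="cpu")
    layer.reset_parameters()

# b.2 / c.2 shard and scatter -- the one and only scatter in this flow
shard_and_scatter(layer)
```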

Solving this will most likely require support from DeepSpeed (deepspeedai/DeepSpeed#1142), or perhaps we can simply try to drop zero.Init when the weights aren't materialized during model creation, so that the very first sharding gets postponed to the load_state_dict stage (and to module_init for the sub-modules that don't have pretrained weights).
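
For illustration only, here is a minimal sketch of that "drop zero.Init" alternative (MyModel, config and ds_config are placeholders; ds_config is assumed to be a ZeRO-3 config that also defines the optimizer). Without zero.Init the params stay unpartitioned until deepspeed.initialize builds the ZeRO-3 engine, so the first and only sharding happens after the pretrained weights have been loaded. The sketch sidesteps the open question of how to avoid materializing the full model on every rank, which is what the linked DeepSpeed issue is about:

```python
import deepspeed
import torch

model = MyModel(config)                              # no zero.Init wrapper
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)      # load before any sharding

# deepspeed.initialize builds the ZeRO-3 engine and partitions the params here,
# i.e. the very first sharding is postponed to after load_state_dict
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```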
