 from typing import List, Union
 
 import paddle
-from paddle.distributed.checkpoint.load_state_dict import (
+from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
     _load_state_dict,
     get_rank_to_read_files,
 )
-from paddle.distributed.checkpoint.metadata import (
+from paddle.distributed.flex_checkpoint.dcp.metadata import (
     LocalTensorIndex,
     LocalTensorMetadata,
     Metadata,
 )
-from paddle.distributed.checkpoint.utils import flatten_state_dict
-from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
 
 MODEL_WEIGHT_SUFFIX = ".pdparams"
 OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
@@ -206,7 +206,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 global_offset = [0] * self.tp_degree
                 for item in shard_info:
                     tp_rank = item[0]["tp_rank"]
-                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
                     local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2])
                     local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],))
                     global_offset[tp_rank] += item[1][0]
@@ -225,7 +225,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 renamed_state_dict = {}
                 (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
                 for state_name, state_value in state_dict.items():
-                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
                     renamed_state_dict[state_name_with_tp_rank] = state_value
 
                 source_state_dict_for_merge_sharding[file_name] = renamed_state_dict
@@ -235,7 +235,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             sharding_metas_keys = []
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
+                    sharding_metas_keys.append(f"tp{i:02d}_pp{j:02d}")
             for key in sharding_metas_keys:
                 param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                 for param_name, param_shape_and_dtype in param_meta.items():
@@ -253,7 +253,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             all_param_meta = {}
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    key = "tp{:02d}_pp{:02d}".format(i, j)
+                    key = f"tp{i:02d}_pp{j:02d}"
                     param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                     for param_name, param_shape_and_dtype in param_meta.items():
                         all_param_meta[param_name] = param_shape_and_dtype
@@ -269,7 +269,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
                 for key in cur_rank_need_load_model_state_keys:
                     for tp_rank in range(self.tp_degree):
-                        tp_rank_suffix = "_tp{:02d}".format(tp_rank)
+                        tp_rank_suffix = f"_tp{tp_rank:02d}"
                         optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
                             (param_flattened_shapes[key],), "float32"
                         )
@@ -353,7 +353,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 else:
                     concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0]
 
-            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in concat_optimier_state_dict.items():
@@ -472,7 +472,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                     reshaped_v = v.reshape(shape)
                     target_state_dict[k] = reshaped_v
 
-            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in target_state_dict.items():
@@ -911,7 +911,7 @@ def rename_using_model_meta(self, file_name):
             self.model_meta = json.load(file)
 
         (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
-        dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank)
+        dist_strategy_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"
         # Map model weight names to their corresponding names of master_weights in the optimizer state.
         if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX):
             structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"]