
Commit 34abba7

xingmingyyjpkuzyc authored and committed
1 parent d1a3d88 commit 34abba7

File tree: 1 file changed, +12 -12 lines changed

paddlenlp/trainer/utils/ckpt_converter.py

Lines changed: 12 additions & 12 deletions
@@ -19,17 +19,17 @@
 from typing import List, Union
 
 import paddle
-from paddle.distributed.checkpoint.load_state_dict import (
+from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
     _load_state_dict,
     get_rank_to_read_files,
 )
-from paddle.distributed.checkpoint.metadata import (
+from paddle.distributed.flex_checkpoint.dcp.metadata import (
     LocalTensorIndex,
     LocalTensorMetadata,
     Metadata,
 )
-from paddle.distributed.checkpoint.utils import flatten_state_dict
-from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
 
 MODEL_WEIGHT_SUFFIX = ".pdparams"
 OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
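This hunk only relocates imports: the loading helpers, metadata classes, and flatten_state_dict that previously came from paddle.distributed.checkpoint are now imported from paddle.distributed.flex_checkpoint.dcp, with the logger import moved up alongside them. For code that has to run against Paddle builds from before and after the relocation, a compatibility shim along these lines is one option; this is an illustrative sketch, not part of this commit, and which of the two paths actually exists depends on the installed Paddle version:

    # Illustrative sketch only, not part of this commit: prefer the new
    # flex_checkpoint module path and fall back to the pre-relocation path.
    try:
        from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
            _load_state_dict,
            get_rank_to_read_files,
        )
        from paddle.distributed.flex_checkpoint.dcp.metadata import (
            LocalTensorIndex,
            LocalTensorMetadata,
            Metadata,
        )
        from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
    except ImportError:
        from paddle.distributed.checkpoint.load_state_dict import (
            _load_state_dict,
            get_rank_to_read_files,
        )
        from paddle.distributed.checkpoint.metadata import (
            LocalTensorIndex,
            LocalTensorMetadata,
            Metadata,
        )
        from paddle.distributed.checkpoint.utils import flatten_state_dict
    from paddle.distributed.fleet.utils.log_util import logger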
@@ -206,7 +206,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             global_offset = [0] * self.tp_degree
             for item in shard_info:
                 tp_rank = item[0]["tp_rank"]
-                state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+                state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
                 local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2])
                 local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],))
                 global_offset[tp_rank] += item[1][0]
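The remaining hunks are a mechanical style change: zero-padded str.format calls are rewritten as equivalent f-strings, with no behavioral difference. A quick sanity check of the equivalence (illustrative only, not from the patch):

    # Both spellings zero-pad the rank to two digits and produce the same string.
    tp_rank = 3
    assert "{:02d}".format(tp_rank) == f"{tp_rank:02d}" == "03"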
@@ -225,7 +225,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 renamed_state_dict = {}
                 (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
                 for state_name, state_value in state_dict.items():
-                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
                     renamed_state_dict[state_name_with_tp_rank] = state_value
 
                 source_state_dict_for_merge_sharding[file_name] = renamed_state_dict
@@ -235,7 +235,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             sharding_metas_keys = []
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
+                    sharding_metas_keys.append(f"tp{i:02d}_pp{j:02d}")
             for key in sharding_metas_keys:
                 param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                 for param_name, param_shape_and_dtype in param_meta.items():
@@ -253,7 +253,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             all_param_meta = {}
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    key = "tp{:02d}_pp{:02d}".format(i, j)
+                    key = f"tp{i:02d}_pp{j:02d}"
                     param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                     for param_name, param_shape_and_dtype in param_meta.items():
                         all_param_meta[param_name] = param_shape_and_dtype
@@ -269,7 +269,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
         with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
             for key in cur_rank_need_load_model_state_keys:
                 for tp_rank in range(self.tp_degree):
-                    tp_rank_suffix = "_tp{:02d}".format(tp_rank)
+                    tp_rank_suffix = f"_tp{tp_rank:02d}"
                     optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
                         (param_flattened_shapes[key],), "float32"
                     )
@@ -353,7 +353,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 else:
                     concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0]
 
-            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in concat_optimier_state_dict.items():
@@ -472,7 +472,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 reshaped_v = v.reshape(shape)
                 target_state_dict[k] = reshaped_v
 
-            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in target_state_dict.items():
@@ -911,7 +911,7 @@ def rename_using_model_meta(self, file_name):
            self.model_meta = json.load(file)
 
        (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
-        dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank)
+        dist_strategy_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"
        # Map model weight names to their corresponding names of master_weights in the optimizer state.
        if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX):
            structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"]
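For reference, the rewritten expression still produces the same tpXX_ppYY keys used to index the sharding_metas entries of model_meta; a small check (illustrative only, not from the patch):

    # dist_strategy_key is unchanged by the f-string rewrite.
    tp_rank, pp_rank = 1, 0
    assert "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}" == "tp01_pp00"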
