paddlenlp/trainer/utils/ckpt_converter.py: 24 changes (12 additions, 12 deletions)
@@ -19,17 +19,17 @@
 from typing import List, Union

 import paddle
-from paddle.distributed.checkpoint.load_state_dict import (
+from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
     _load_state_dict,
     get_rank_to_read_files,
 )
-from paddle.distributed.checkpoint.metadata import (
+from paddle.distributed.flex_checkpoint.dcp.metadata import (
     LocalTensorIndex,
     LocalTensorMetadata,
     Metadata,
 )
-from paddle.distributed.checkpoint.utils import flatten_state_dict
-from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict

 MODEL_WEIGHT_SUFFIX = ".pdparams"
 OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
@@ -206,7 +206,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
 global_offset = [0] * self.tp_degree
 for item in shard_info:
     tp_rank = item[0]["tp_rank"]
-    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
     local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2])
     local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],))
     global_offset[tp_rank] += item[1][0]
@@ -225,7 +225,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
 renamed_state_dict = {}
 (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
 for state_name, state_value in state_dict.items():
-    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
     renamed_state_dict[state_name_with_tp_rank] = state_value

 source_state_dict_for_merge_sharding[file_name] = renamed_state_dict
@@ -235,7 +235,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
 sharding_metas_keys = []
 for i in range(self.tp_degree):
     for j in range(self.pp_degree):
-        sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
+        sharding_metas_keys.append(f"tp{i:02d}_pp{j:02d}")
 for key in sharding_metas_keys:
     param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
     for param_name, param_shape_and_dtype in param_meta.items():
@@ -253,7 +253,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
 all_param_meta = {}
 for i in range(self.tp_degree):
     for j in range(self.pp_degree):
-        key = "tp{:02d}_pp{:02d}".format(i, j)
+        key = f"tp{i:02d}_pp{j:02d}"
         param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
         for param_name, param_shape_and_dtype in param_meta.items():
             all_param_meta[param_name] = param_shape_and_dtype
@@ -269,7 +269,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
 with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
     for key in cur_rank_need_load_model_state_keys:
         for tp_rank in range(self.tp_degree):
-            tp_rank_suffix = "_tp{:02d}".format(tp_rank)
+            tp_rank_suffix = f"_tp{tp_rank:02d}"
             optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
                 (param_flattened_shapes[key],), "float32"
             )
@@ -353,7 +353,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
 else:
     concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0]

-fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
 local_tensor_meta_data = {}
 local_tensor_index = {}
 for k, v in concat_optimier_state_dict.items():
@@ -472,7 +472,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
     reshaped_v = v.reshape(shape)
     target_state_dict[k] = reshaped_v

-fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
 local_tensor_meta_data = {}
 local_tensor_index = {}
 for k, v in target_state_dict.items():
@@ -911,7 +911,7 @@ def rename_using_model_meta(self, file_name):
     self.model_meta = json.load(file)

 (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
-dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank)
+dist_strategy_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"
 # Map model weight names to their corresponding names of master_weights in the optimizer state.
 if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX):
     structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"]
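Aside from the import moves in the first hunk, every change in this file swaps a str.format call for an equivalent f-string. A minimal standalone sketch (not part of the PR) confirming that the two spellings render identical zero-padded rank keys:

# Standalone sketch, not from the PR: the old str.format call and the new
# f-string produce the same zero-padded tp/pp key.
tp_rank, pp_rank = 3, 12

old_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank)
new_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"

assert old_key == new_key == "tp03_pp12"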
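The import changes retarget the converter from paddle.distributed.checkpoint to paddle.distributed.flex_checkpoint.dcp. The PR switches to the new path unconditionally; if compatibility with older Paddle builds that still ship the previous layout were ever needed, a guarded import along these lines is one option (a sketch under that assumption, not something the PR does):

# Sketch only, assuming older Paddle builds still expose the previous
# paddle.distributed.checkpoint layout; the PR itself imports the new
# flex_checkpoint path directly.
try:
    from paddle.distributed.flex_checkpoint.dcp.metadata import (
        LocalTensorIndex,
        LocalTensorMetadata,
        Metadata,
    )
    from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
except ImportError:  # older Paddle releases
    from paddle.distributed.checkpoint.metadata import (
        LocalTensorIndex,
        LocalTensorMetadata,
        Metadata,
    )
    from paddle.distributed.checkpoint.utils import flatten_state_dict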