
Commit d1a3d88

add warm load (PaddlePaddle#11029)
1 parent adc2f36 commit d1a3d88

File tree

4 files changed: +298 additions, −5 deletions

paddlenlp/trainer/trainer.py
paddlenlp/trainer/training_args.py
paddlenlp/trainer/utils/load_utils.py
paddlenlp/transformers/deepseek_v2/modeling_pp.py


paddlenlp/trainer/trainer.py

Lines changed: 28 additions & 0 deletions
@@ -189,6 +189,7 @@
     nested_numpify,
     nested_truncate,
 )
+from .utils.load_utils import load_paddle_model_from_safetensors
 from .utils.sharding_io import ShardingIO
 
 DEFAULT_CALLBACKS = [DefaultFlowCallback]
@@ -1108,6 +1109,13 @@ def _inner_training_loop(
         if self.args.ignore_data_skip:
             self.timers and self.timers("read-data").start()
 
+        if self.args.hf_ckpt_dir is not None:
+            print("Start loading the Hugging Face model with warm start")
+            weight_map_path = os.path.join(self.args.hf_ckpt_dir, "model.safetensors.index.json")
+            ckpt_pre = self.args.hf_ckpt_dir
+
+            load_paddle_model_from_safetensors(model, weight_map_path, ckpt_pre, verbose=True)
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
                 train_dataloader.batch_sampler, DistributedBatchSampler
@@ -1343,8 +1351,28 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                             f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                         )
                     elif isinstance(self.optimizer, HybridParallelOptimizer):
+                        # print("hack for moe grad")
+                        # for p in parameters_list:
+                        #     if getattr(p, 'is_moe_param', False):
+                        #         if p.grad is not None:
+                        #             # print(p.name, p.grad)
+                        #             p.grad /= 8
+                        #         if p.main_grad is not None:
+                        #             # print(p.name, p.main_grad)
+                        #             p.main_grad /= 8
+
                         self.optimizer._step(parameters_list)
                     else:
+                        # print("hack for moe grad")
+                        # for p in parameters_list:
+                        #     if getattr(p, 'is_moe_param', False):
+                        #         if p.grad is not None:
+                        #             print(p.name, p.grad)
+                        #             p.grad /= 4
+                        #         if p.main_grad is not None:
+                        #             print(p.name, p.main_grad)
+                        #             p.main_grad /= 4
+
                         self.optimizer.step()
 
                     if self.args.offload_optim:
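The hook above is opt-in: it only fires when the new hf_ckpt_dir training argument (added in training_args.py below) is set. A minimal, hedged usage sketch, assuming the public paddlenlp.trainer API; paths are placeholders and model/dataset setup is omitted. Note that the loader concatenates the directory prefix directly with each shard filename, so passing the directory with a trailing separator is the safe choice.

from paddlenlp.trainer import Trainer, TrainingArguments

# Placeholder paths; any Trainer built with these args will call
# load_paddle_model_from_safetensors() once, right before the first training epoch.
args = TrainingArguments(
    output_dir="./checkpoints",
    hf_ckpt_dir="/path/to/hf_checkpoint/",  # must contain model.safetensors.index.json
)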

paddlenlp/trainer/training_args.py

Lines changed: 6 additions & 1 deletion
@@ -1081,12 +1081,17 @@ class TrainingArguments:
     nccl_comm_group_config: Optional[str] = field(
         default=None, metadata={"help": "Path to the config file for fine-grained control of NCCL communication groups; defaults to None, which disables this configuration"}
     )
-
+
     pre_alloc_memory: int = field(
         default=0,
         metadata={"help": "pre allocate memory size GB"},
     )
 
+    hf_ckpt_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Hugging Face checkpoint directory used for warm-start loading"},
+    )
+
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
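For reference, hf_ckpt_dir is expected to point at a sharded Hugging Face safetensors checkpoint, i.e. a directory containing model.safetensors.index.json plus the shard files it references. A hedged sketch of how the loader below consumes that index; shard names are illustrative:

import json
import os

hf_ckpt_dir = "/path/to/hf_checkpoint"  # placeholder
with open(os.path.join(hf_ckpt_dir, "model.safetensors.index.json")) as f:
    index = json.load(f)

# Only the "weight_map" section is used: it maps each Hugging Face parameter
# name to the shard file that stores it, for example
#   "model.embed_tokens.weight"           -> "model-00001-of-00055.safetensors"
#   "model.layers.0.mlp.gate_proj.weight" -> "model-00002-of-00055.safetensors"
weight_map = index["weight_map"]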
paddlenlp/trainer/utils/load_utils.py

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from collections import defaultdict
from typing import List, Optional

import paddle
from paddle.distributed import fleet
from safetensors import safe_open

# develop: "_layers.<idx>.<rest>"
_LAYER_RE = re.compile(r"^_layers\.(\d+)(?:\.(.*))?$")
_EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$")
_EXPERT_W2_RE = re.compile(r"^mlp\.experts\.(\d+)\.w2(?:\.weight)?$")

custom_name_map = {
    "mlp.router.weight": "mlp.gate.weight",
    "mlp.router.e_score_correction_bias": "mlp.gate.e_score_correction_bias",
}


def _layers_match(name: str):
    return _LAYER_RE.match(name)


def simple_safe_call(model, method_name, *args, **kwargs):
    if hasattr(model, method_name):
        return getattr(model, method_name)(*args, **kwargs)
    if hasattr(model, "_layers") and hasattr(model._layers, method_name):
        return getattr(model._layers, method_name)(*args, **kwargs)
    raise AttributeError(f"{type(model).__name__} (or its wrapper) has no method {method_name}")


def add_prefix_to_keys(d, prefix):
    print("Input dict:", d)

    mappings = {}
    for key, value in d.items():
        if key == "embed_tokens.weight":
            new_key = "_layers.0.embed_tokens.weight"
        elif key == "lm_head.weight":
            new_key = "_layers.64.weight"
        else:
            new_key = f"{prefix}{key}"
        mappings[new_key] = value
    return mappings


def _get_hf_prefix_develop(idx: int) -> str:
    if idx == 0:
        return "model"  # embedding
    if idx == 63:
        return "model"  # final norm
    if idx == 64:
        return "lm_head"  # lm_head
    return f"model.layers.{idx - 1}"  # decoder layer


def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]:
    if m := _EXPERT_W1_RE.match(rest):
        expert_id = int(m.group(1))
        return [
            f"{hf_prefix}.mlp.experts.{expert_id}.gate_proj.weight",
            f"{hf_prefix}.mlp.experts.{expert_id}.up_proj.weight",
        ]
    if m := _EXPERT_W2_RE.match(rest):
        expert_id = int(m.group(1))
        return [
            f"{hf_prefix}.mlp.experts.{expert_id}.down_proj.weight",
        ]
    return None


def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]:
    if rest == "mlp.w1":
        return [
            f"{hf_prefix}.mlp.gate_proj.weight",
            f"{hf_prefix}.mlp.up_proj.weight",
        ]
    if rest == "mlp.w2":
        return [
            f"{hf_prefix}.mlp.down_proj.weight",
        ]
    return None


def paddle_name_to_hf_names(paddle_name: str) -> List[str]:
    """
    Map a Paddle parameter name to the corresponding Hugging Face parameter name(s).
    """
    m = _layers_match(paddle_name)
    if not m:
        return []
    idx = int(m.group(1))
    rest = m.group(2) or ""

    hf_prefix = _get_hf_prefix_develop(idx)

    # special-case renames
    if rest in custom_name_map:
        return [f"{hf_prefix}.{custom_name_map[rest]}"]

    # legacy expert weights
    if expert_names := _handle_expert_weights(hf_prefix, rest):
        return expert_names

    # legacy MLP weights
    if mlp_names := _handle_mlp_weights(hf_prefix, rest):
        return mlp_names

    return [f"{hf_prefix}.{rest}"] if rest else [hf_prefix]


def prepare_tensor(tensor, pd_param, tensor_parallel_mappings, mp_degree, dst_shape):
    """
    Convert a loaded weight tensor so it matches the target parameter's shape,
    handling transposition, concatenation, and tensor-parallel splitting as needed.
    """
    # A list means two Hugging Face tensors map to one fused Paddle weight:
    # transpose each and concatenate along the last axis.
    if isinstance(tensor, list):
        tensor = paddle.concat(
            [
                paddle.transpose(tensor[0], perm=[1, 0]).contiguous(),
                paddle.transpose(tensor[1], perm=[1, 0]).contiguous(),
            ],
            axis=-1,
        )
    # match for transpose
    if len(tensor.shape) == 2:
        if (tensor.shape[0] == dst_shape[1] or tensor.shape[1] == dst_shape[0]) and tensor.shape != dst_shape:
            tensor = paddle.transpose(tensor, perm=[1, 0]).contiguous()
            print(f"after transpose get hf tensor shape {tensor.shape}, paddle shape {dst_shape}")

    if mp_degree > 1 and pd_param in tensor_parallel_mappings:
        tensor = tensor_parallel_mappings[pd_param](tensor)
    if tensor.shape == dst_shape:
        return tensor
    raise ValueError(f"Unexpected tensor shape: got {tensor.shape}, want {dst_shape}")


def load_paddle_model_from_safetensors(
    model,
    weight_map_path: str,
    ckpt_pre: str,
    verbose: bool = True,
):
    """
    Load safetensors weights into a Paddle model using the weight map from the checkpoint's index JSON.
    """
    tensor_parallel_mappings = {}
    mp_degree = fleet.get_hybrid_communicate_group().get_model_parallel_world_size()
    print("model parallel degree:", mp_degree)

    if mp_degree > 1:
        print("load with mp_degree:", mp_degree)
        tensor_parallel_mappings = simple_safe_call(model, "get_tensor_parallel_mappings", is_split=True)
        tensor_parallel_mappings = add_prefix_to_keys(tensor_parallel_mappings, "_")

    for k, v in tensor_parallel_mappings.items():
        print("tensor_parallel_mappings:", k, v)

    with open(weight_map_path, "r") as f:
        weight_map = json.load(f)["weight_map"]

    required_files = set()
    file_to_pd_param_name = defaultdict(list)
    pd_param_name_to_file = defaultdict(list)

    # Work out which safetensors shard(s) hold each Paddle parameter.
    for pd_name, _ in model.named_parameters():
        hf_names = paddle_name_to_hf_names(pd_name)
        if verbose:
            print(f"paddle_name_to_hf_names: {pd_name} -> {hf_names}")
        if not hf_names:
            if verbose:
                print(f"Warning: {pd_name} can not be mapped")
            continue
        for i, hf_name in enumerate(hf_names):
            if hf_name in weight_map:
                filename = weight_map[hf_name]
                required_files.add(filename)
                file_to_pd_param_name[filename].append(pd_name)
                if filename not in pd_param_name_to_file[pd_name]:
                    pd_param_name_to_file[pd_name].append(filename)
            else:
                if verbose:
                    print(f"Warning: {pd_name} -> {hf_name} not found in weight map")

    check_list = []
    if verbose:
        print("---- start load param ----")
    for key, value in tensor_parallel_mappings.items():
        print(key, value)
    for filename in required_files:
        try:
            with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f:
                pd_params = file_to_pd_param_name[filename]
                for pd_param in pd_params:
                    if pd_param in check_list:
                        continue
                    if verbose:
                        print("load for pd_param:", pd_param)
                    hf_names = paddle_name_to_hf_names(pd_param)
                    if not hf_names:
                        continue
                    if len(hf_names) == 1:
                        tensor = f.get_tensor(hf_names[0])
                        value = prepare_tensor(
                            tensor, pd_param, tensor_parallel_mappings, mp_degree, model.state_dict()[pd_param].shape
                        )

                        model.state_dict()[pd_param].set_value(paddle.cast(value, model.state_dict()[pd_param].dtype))
                    else:
                        # Two Hugging Face tensors feed one fused Paddle parameter;
                        # the second tensor may live in a different shard file.
                        files = pd_param_name_to_file[pd_param]
                        if len(files) == 1:
                            tensor0 = f.get_tensor(hf_names[0])
                            tensor1 = f.get_tensor(hf_names[1])
                        else:
                            if weight_map[hf_names[0]] == filename:
                                tensor0 = f.get_tensor(hf_names[0])
                                with safe_open(
                                    ckpt_pre + weight_map[hf_names[1]], framework="paddle", device="cpu"
                                ) as f2:
                                    tensor1 = f2.get_tensor(hf_names[1])
                            else:
                                with safe_open(
                                    ckpt_pre + weight_map[hf_names[0]], framework="paddle", device="cpu"
                                ) as f2:
                                    tensor0 = f2.get_tensor(hf_names[0])
                                tensor1 = f.get_tensor(hf_names[1])
                        value = prepare_tensor(
                            [tensor0, tensor1],
                            pd_param,
                            tensor_parallel_mappings,
                            mp_degree,
                            model.state_dict()[pd_param].shape,
                        )
                        model.state_dict()[pd_param].set_value(value)
                    check_list.append(pd_param)
        except Exception as e:
            print(f"Error loading {filename}: {str(e)}")
            raise

    if verbose:
        print("All parameters loaded.")

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 6 additions & 4 deletions
@@ -170,7 +170,7 @@ def forward(self, args):
         batch_size, seq_length, _ = inputs_embeds.shape
 
         if self.sequence_parallel:
-            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
             # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
             # inputs_embeds = paddle.reshape(inputs_embeds, [-1, inputs_embeds.shape[-1]])
             # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
@@ -185,7 +185,7 @@ def forward(self, args):
                 axis=1,
             )
             if self.sequence_parallel:
-                inputs_embeds_mtp = paddle.transpose(inputs_embeds_mtp, [1, 0, 2])  # [B, S, H] --> [S, B, H]
+                inputs_embeds_mtp = paddle.transpose(inputs_embeds_mtp, [1, 0, 2])  # [B, S, H] --> [S, B, H]
                 # inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]])
                 inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp)
             embeds_res.append(inputs_embeds_mtp)
@@ -197,7 +197,7 @@ def forward(self, args):
             return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
         else:
             if self.sequence_parallel:
-                inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
+                inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
                 # inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
                 inputs_embeds = ScatterOp.apply(inputs_embeds)
                 return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
@@ -270,7 +270,6 @@ def forward(self, args):
 class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-
         hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
         hidden_states_main_model = hidden_states_list[0]
         inputs_embeds_cur_depth_list = hidden_states_list[1:]
@@ -525,3 +524,6 @@ def get_hcg():
 
     def get_loss_fn(self, config):
         return DeepseekV2PretrainingCriterionPipe(config)
+
+    def get_tensor_parallel_mappings(self, is_split=True):
+        return type(self)._get_tensor_parallel_mappings(self.config, is_split)
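This new get_tensor_parallel_mappings hook is what load_paddle_model_from_safetensors reaches through simple_safe_call when the model-parallel degree is greater than one. A hedged sketch of that path, assuming the module location above and that the DeepseekV2 pipeline model may be wrapped so its layers sit under model._layers; the helper name is hypothetical:

from paddlenlp.trainer.utils.load_utils import add_prefix_to_keys, simple_safe_call

def tp_split_functions(model):
    """Mirror of the loader's mp_degree > 1 branch: fetch the tensor-parallel
    split functions exposed by the new hook and rename their keys to the
    pipeline parameter names ("embed_tokens.weight" -> "_layers.0.embed_tokens.weight",
    "lm_head.weight" -> "_layers.64.weight", everything else gets a "_" prefix)."""
    mappings = simple_safe_call(model, "get_tensor_parallel_mappings", is_split=True)
    return add_prefix_to_keys(mappings, "_")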
