|
| 1 | +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import json |
| 16 | +import re |
| 17 | +from collections import defaultdict |
| 18 | +from typing import List, Optional |
| 19 | + |
| 20 | +import paddle |
| 21 | +from paddle.distributed import fleet |
| 22 | +from safetensors import safe_open |
| 23 | + |
| 24 | +# develop: "_layers.<idx>.<rest>" |
| 25 | +_LAYER_RE = re.compile(r"^_layers\.(\d+)(?:\.(.*))?$") |
| 26 | +_EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$") |
| 27 | +_EXPERT_W2_RE = re.compile(r"^mlp\.experts\.(\d+)\.w2(?:\.weight)?$") |
| 28 | + |
| 29 | +custom_name_map = { |
| 30 | + "mlp.router.weight": "mlp.gate.weight", |
| 31 | + "mlp.router.e_score_correction_bias": "mlp.gate.e_score_correction_bias", |
| 32 | +} |
| 33 | + |
| 34 | + |
| 35 | +def _layers_match(name: str): |
| 36 | + return _LAYER_RE.match(name) |
| 37 | + |
| 38 | + |
| 39 | +def simple_safe_call(model, method_name, *args, **kwargs): |
| 40 | + if hasattr(model, method_name): |
| 41 | + return getattr(model, method_name)(*args, **kwargs) |
| 42 | + if hasattr(model, "_layers") and hasattr(model._layers, method_name): |
| 43 | + return getattr(model._layers, method_name)(*args, **kwargs) |
| 44 | + raise AttributeError(f"{type(model).__name__} (or its wrapper) has no method {method_name}") |
| 45 | + |
| 46 | + |
| 47 | +def add_prefix_to_keys(d, prefix): |
| 48 | + print("Input dict:", d) |
| 49 | + |
| 50 | + mappings = {} |
| 51 | + for key, value in d.items(): |
| 52 | + if key == "embed_tokens.weight": |
| 53 | + new_key = "_layers.0.embed_tokens.weight" |
| 54 | + elif key == "lm_head.weight": |
| 55 | + new_key = "_layers.64.weight" |
| 56 | + else: |
| 57 | + new_key = f"{prefix}{key}" |
| 58 | + mappings[new_key] = value |
| 59 | + return mappings |
| 60 | + |
| 61 | + |
| 62 | +def _get_hf_prefix_develop(idx: int) -> str: |
| 63 | + if idx == 0: |
| 64 | + return "model" # embedding |
| 65 | + if idx == 63: |
| 66 | + return "model" # final norm |
| 67 | + if idx == 64: |
| 68 | + return "lm_head" # lm_head |
| 69 | + return f"model.layers.{idx - 1}" # decoder layer |
| 70 | + |
| 71 | + |
| 72 | +def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: |
| 73 | + if m := _EXPERT_W1_RE.match(rest): |
| 74 | + expert_id = int(m.group(1)) |
| 75 | + return [ |
| 76 | + f"{hf_prefix}.mlp.experts.{expert_id}.gate_proj.weight", |
| 77 | + f"{hf_prefix}.mlp.experts.{expert_id}.up_proj.weight", |
| 78 | + ] |
| 79 | + if m := _EXPERT_W2_RE.match(rest): |
| 80 | + expert_id = int(m.group(1)) |
| 81 | + return [ |
| 82 | + f"{hf_prefix}.mlp.experts.{expert_id}.down_proj.weight", |
| 83 | + ] |
| 84 | + return None |
| 85 | + |
| 86 | + |
| 87 | +def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: |
| 88 | + if rest == "mlp.w1": |
| 89 | + return [ |
| 90 | + f"{hf_prefix}.mlp.gate_proj.weight", |
| 91 | + f"{hf_prefix}.mlp.up_proj.weight", |
| 92 | + ] |
| 93 | + if rest == "mlp.w2": |
| 94 | + return [ |
| 95 | + f"{hf_prefix}.mlp.down_proj.weight", |
| 96 | + ] |
| 97 | + return None |
| 98 | + |
| 99 | + |
| 100 | +def paddle_name_to_hf_names(paddle_name: str) -> List[str]: |
| 101 | + """ |
| 102 | + Mapping Function for Paddle Parameter Names to Hugging Face Names |
| 103 | + """ |
| 104 | + m = _layers_match(paddle_name) |
| 105 | + if not m: |
| 106 | + return [] |
| 107 | + idx = int(m.group(1)) |
| 108 | + rest = m.group(2) or "" |
| 109 | + |
| 110 | + hf_prefix = _get_hf_prefix_develop(idx) |
| 111 | + |
| 112 | + # 专项重命名 |
| 113 | + if rest in custom_name_map: |
| 114 | + return [f"{hf_prefix}.{custom_name_map[rest]}"] |
| 115 | + |
| 116 | + # 历史专家 |
| 117 | + if expert_names := _handle_expert_weights(hf_prefix, rest): |
| 118 | + return expert_names |
| 119 | + |
| 120 | + # 历史mlp |
| 121 | + if mlp_names := _handle_mlp_weights(hf_prefix, rest): |
| 122 | + return mlp_names |
| 123 | + |
| 124 | + return [f"{hf_prefix}.{rest}"] if rest else [hf_prefix] |
| 125 | + |
| 126 | + |
| 127 | +def prepare_tensor(tensor, pd_param, tensor_parallel_mappings, mp_degree, dst_shape): |
| 128 | + """ |
| 129 | + Converting weight tensors to match the target model’s shape involves |
| 130 | + automatically adjusting for transposing, concatenating, and slicing by columns or lengths. |
| 131 | + """ |
| 132 | + |
| 133 | + if isinstance(tensor, list): |
| 134 | + tensor = paddle.concat( |
| 135 | + [ |
| 136 | + paddle.transpose(tensor[0], perm=[1, 0]).contiguous(), |
| 137 | + paddle.transpose(tensor[1], perm=[1, 0]).contiguous(), |
| 138 | + ], |
| 139 | + axis=-1, |
| 140 | + ) |
| 141 | + # match for transpose |
| 142 | + if len(tensor.shape) == 2: |
| 143 | + if (tensor.shape[0] == dst_shape[1] or tensor.shape[1] == dst_shape[0]) and tensor.shape != dst_shape: |
| 144 | + tensor = paddle.transpose(tensor, perm=[1, 0]).contiguous() |
| 145 | + print(f"after transpose get hf tensor shape {tensor.shape}, paddle shape {dst_shape}") |
| 146 | + |
| 147 | + if mp_degree > 1 and pd_param in tensor_parallel_mappings: |
| 148 | + tensor = tensor_parallel_mappings[pd_param](tensor) |
| 149 | + if tensor.shape == dst_shape: |
| 150 | + return tensor |
| 151 | + raise ValueError(f"Unexpected tensor shape: got {tensor.shape}, want {dst_shape}") |
| 152 | + |
| 153 | + |
| 154 | +def load_paddle_model_from_safetensors( |
| 155 | + model, |
| 156 | + weight_map_path: str, |
| 157 | + ckpt_pre: str, |
| 158 | + verbose: bool = True, |
| 159 | +): |
| 160 | + """ |
| 161 | + Load safetensors into a Paddle model using the weight mappings outlined in index.json. |
| 162 | + """ |
| 163 | + |
| 164 | + tensor_parallel_mappings = {} |
| 165 | + mp_degree = fleet.get_hybrid_communicate_group().get_model_parallel_world_size() |
| 166 | + print("fuck mp degree!!!!!!!!!", mp_degree) |
| 167 | + |
| 168 | + if mp_degree > 1: |
| 169 | + print("load with mp_degree:", mp_degree) |
| 170 | + tensor_parallel_mappings = simple_safe_call(model, "get_tensor_parallel_mappings", is_split=True) |
| 171 | + tensor_parallel_mappings = add_prefix_to_keys(tensor_parallel_mappings, "_") |
| 172 | + |
| 173 | + for k, v in tensor_parallel_mappings.items(): |
| 174 | + print("tensor_parallel_mappings:", k, v) |
| 175 | + |
| 176 | + with open(weight_map_path, "r") as f: |
| 177 | + weight_map = json.load(f)["weight_map"] |
| 178 | + |
| 179 | + required_files = set() |
| 180 | + file_to_pd_param_name = defaultdict(list) |
| 181 | + pd_param_name_to_file = defaultdict(list) |
| 182 | + |
| 183 | + for pd_name, _ in model.named_parameters(): |
| 184 | + hf_names = paddle_name_to_hf_names(pd_name) |
| 185 | + if verbose: |
| 186 | + print(f"paddle_name_to_hf_names: {pd_name} -> {hf_names}") |
| 187 | + if not hf_names: |
| 188 | + if verbose: |
| 189 | + print(f"Warning: {pd_name} can not be mapped") |
| 190 | + continue |
| 191 | + for i, hf_name in enumerate(hf_names): |
| 192 | + if hf_name in weight_map: |
| 193 | + filename = weight_map[hf_name] |
| 194 | + required_files.add(filename) |
| 195 | + file_to_pd_param_name[filename].append(pd_name) |
| 196 | + if filename not in pd_param_name_to_file[pd_name]: |
| 197 | + pd_param_name_to_file[pd_name].append(filename) |
| 198 | + else: |
| 199 | + if verbose: |
| 200 | + print(f"Warning: {pd_name} -> {hf_name} not found in weight map") |
| 201 | + |
| 202 | + check_list = [] |
| 203 | + if verbose: |
| 204 | + print("---- start load param ----") |
| 205 | + for key, value in tensor_parallel_mappings.items(): |
| 206 | + print(key, value) |
| 207 | + for filename in required_files: |
| 208 | + try: |
| 209 | + with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f: |
| 210 | + pd_params = file_to_pd_param_name[filename] |
| 211 | + for pd_param in pd_params: |
| 212 | + if pd_param in check_list: |
| 213 | + continue |
| 214 | + if verbose: |
| 215 | + print("load for pd_param:", pd_param) |
| 216 | + hf_names = paddle_name_to_hf_names(pd_param) |
| 217 | + if not hf_names: |
| 218 | + continue |
| 219 | + if len(hf_names) == 1: |
| 220 | + tensor = f.get_tensor(hf_names[0]) |
| 221 | + value = prepare_tensor( |
| 222 | + tensor, pd_param, tensor_parallel_mappings, mp_degree, model.state_dict()[pd_param].shape |
| 223 | + ) |
| 224 | + |
| 225 | + model.state_dict()[pd_param].set_value(paddle.cast(value, model.state_dict()[pd_param].dtype)) |
| 226 | + else: |
| 227 | + files = pd_param_name_to_file[pd_param] |
| 228 | + if len(files) == 1: |
| 229 | + tensor0 = f.get_tensor(hf_names[0]) |
| 230 | + tensor1 = f.get_tensor(hf_names[1]) |
| 231 | + else: |
| 232 | + if weight_map[hf_names[0]] == filename: |
| 233 | + tensor0 = f.get_tensor(hf_names[0]) |
| 234 | + with safe_open( |
| 235 | + ckpt_pre + weight_map[hf_names[1]], framework="paddle", device="cpu" |
| 236 | + ) as f2: |
| 237 | + tensor1 = f2.get_tensor(hf_names[1]) |
| 238 | + else: |
| 239 | + with safe_open( |
| 240 | + ckpt_pre + weight_map[hf_names[0]], framework="paddle", device="cpu" |
| 241 | + ) as f2: |
| 242 | + tensor0 = f2.get_tensor(hf_names[0]) |
| 243 | + tensor1 = f.get_tensor(hf_names[1]) |
| 244 | + value = prepare_tensor( |
| 245 | + [tensor0, tensor1], |
| 246 | + pd_param, |
| 247 | + tensor_parallel_mappings, |
| 248 | + mp_degree, |
| 249 | + model.state_dict()[pd_param].shape, |
| 250 | + ) |
| 251 | + model.state_dict()[pd_param].set_value(value) |
| 252 | + check_list.append(pd_param) |
| 253 | + except Exception as e: |
| 254 | + print(f"Error loading {filename}: {str(e)}") |
| 255 | + raise |
| 256 | + |
| 257 | + if verbose: |
| 258 | + print("All parameters loaded.") |
0 commit comments