@@ -122,7 +122,8 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
     self.model_name = model_name
     self.config = config
     self.use_standalone_mappings = use_standalone_mappings
-    self._sharding_knowledge_map = _SHARDING_KNOWLEDGE_MAP
+    self._sharding_knowledge_map = _SHARDING_KNOWLEDGE_MAP[self.model_name.split("-")[0]]
+    self._maxtext_keys_to_vllm_keys = _MAXTEXT_TO_VLLM_KEY_MAP[self.model_name.split("-")[0]]

   def to_hf_mapping(self):
     """Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
@@ -154,44 +155,29 @@ def lora_to_hf_mappings(self):
     return None

   def _generalize_maxtext_key(self, maxtext_key):
-    """Generalizes the MaxText key to a common vLLM format."""
-    # 'params-decoder-layers_0-mlp-...' -> 'base.decoder.layers_0.mlp....'
+    """
+    Universal generalizer for Qwen3, DeepSeek, and Llama3.1 keys.
+    Converts raw MaxText keys to a standardized 'base' format for sharding maps.
+    """
+    # 1. Standardize separators and prefix
+    # 'params-decoder-...' -> 'base.decoder....'
+    # 'thinker.params-decoder-...' -> 'thinker.base.decoder....' (Qwen3-Omni)
     generic_key = maxtext_key.replace("params-", "base.").replace("-", ".")
-    # 'base.decoder.dense_layers.mlp....' -> 'base.decoder.layers.mlp....'
-    generic_key = re.sub(r"\.dense_layers\.", ".layers.", generic_key)
-    # 'base.decoder.moe_layers.mlp....' -> 'base.decoder.layers.mlp....'
-    generic_key = re.sub(r"\.moe_layers\.", ".layers.", generic_key)
-    # '...layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0' -> '...layers.moe_block.wi_0'
-    generic_key = re.sub(r"DeepSeekMoeBlock_0\.MoeBlock_0\.", "moe_block.", generic_key)
-    # Handle shared experts
-    generic_key = re.sub(
-        r"DeepSeekMoeBlock_0\.shared_experts\.",
-        "moe_block.shared_experts.",
-        generic_key,
-    )
-    # Keep original rule for other models
-    generic_key = re.sub(r"layers_(\d+)\.", "layers.", generic_key)
+
+    # 2. Normalize standard layer indexing (crucial for unscanned models)
+    # Llama/Qwen unscanned: 'base.decoder.layers_5.self_attention...' -> 'base.decoder.layers.self_attention...'
+    generic_key = re.sub(r"\.layers_\d+\.", ".layers.", generic_key)
+
+    # 3. Normalize DeepSeek-specific layer indexing (if unscanned support is added later)
+    # Preserves the distinction between 'dense_layers' and 'moe_layers'
+    generic_key = re.sub(r"\.dense_layers_\d+\.", ".dense_layers.", generic_key)
+    generic_key = re.sub(r"\.moe_layers_\d+\.", ".moe_layers.", generic_key)
+
     return generic_key

   def _generalize_hf_value(self, hf_value):
     """Extracts and generalizes the Hugging Face name from the hf_value."""
-    first_name = ""
-    if isinstance(hf_value, str):
-      first_name = hf_value
-    elif isinstance(hf_value, list):
-      if not hf_value:
-        return None
-      if isinstance(hf_value[0], list):
-        first_name = hf_value[0][0]  # Scanned MoE
-      else:
-        first_name = hf_value[0]  # Scanned Dense / Unscanned MoE
-    else:
-      raise TypeError(f"Unknown value type in map: {type(hf_value)}")
-
-    # Replace layer and expert indices with wildcards
-    wildcard_name = re.sub(r"layers\.(\d+)\.", "layers.*.", first_name)
-    wildcard_name = re.sub(r"experts\.(\d+)\.", "experts.*.", wildcard_name)
-    return wildcard_name
+    return self._maxtext_keys_to_vllm_keys(hf_value)

   def _correct_hf_wildcard_name(self, wildcard_name):
     """Corrects the generated Hugging Face wildcard name."""
@@ -229,13 +215,10 @@ def convert_hf_map_to_sharding_map(self, hf_mapping):
       generic_key = self._generalize_maxtext_key(maxtext_key)

       # 2. Generalize the Hugging Face (HF) value name
-      wildcard_name = self._generalize_hf_value(hf_value)
-      if wildcard_name is None:
+      corrected_value = self._generalize_hf_value(hf_value)
+      if corrected_value is None:
         continue

-      # 3. Correct the generated wildcard name
-      corrected_name = self._correct_hf_wildcard_name(wildcard_name)
-
       # 4. Look up the sharding tuple
       sharding_tuple = self._sharding_knowledge_map.get(generic_key)

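One loop iteration of `convert_hf_map_to_sharding_map` now reads as follows; the `hf_mapping` entry is an assumed example for a scanned dense MLP weight, with `wi_0` taken to correspond to HF's `gate_proj`:

```python
# Assumed example entry traced through steps 1, 2, 4 and 5 of the loop.
maxtext_key = "params-decoder-layers-mlp-wi_0-kernel"
hf_value = [f"model.layers.{i}.mlp.gate_proj.weight" for i in range(4)]  # scanned dense

generic_key = "base.decoder.layers.mlp.wi_0.kernel"       # step 1 (_generalize_maxtext_key)
corrected_value = GENERAL_HF_KEYS_TO_VLLM_KEYS(hf_value)  # step 2 -> "model.layers.*.mlp.gate_proj.kernel"
sharding_tuple = GENERAL_SHARDING_MAP[generic_key]        # step 4 -> (None, "layer", "model")
entry = (corrected_value, sharding_tuple)                 # step 5
```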
@@ -246,6 +229,268 @@ def convert_hf_map_to_sharding_map(self, hf_mapping):
         print(f"Warning: No sharding rule found for key: {generic_key}")
         continue
       # 5. Assemble the final map entry
-      sharding_map[generic_key] = (corrected_name, sharding_tuple)
+      sharding_map[generic_key] = (corrected_value, sharding_tuple)

     return sharding_map
+
+
+def GENERAL_HF_KEYS_TO_VLLM_KEYS(hf_value):
+  """Converts a concrete HF key (or list of keys) to a vLLM template string."""
+  first_name = ""
+  if isinstance(hf_value, str):
+    first_name = hf_value
+  elif isinstance(hf_value, list):
+    if not hf_value:
+      return None
+    if isinstance(hf_value[0], list):
+      first_name = hf_value[0][0]  # Scanned MoE
+    else:
+      first_name = hf_value[0]  # Scanned Dense / Unscanned MoE
+  else:
+    raise TypeError(f"Unknown value type in map: {type(hf_value)}")
+
+  # Replace layer and expert indices with wildcards
+  wildcard_name = re.sub(r"layers\.(\d+)\.", "layers.*.", first_name)
+  wildcard_name = re.sub(r"experts\.(\d+)\.", "experts.*.", wildcard_name)
+
+  if "layernorm.weight" in wildcard_name or "_norm.weight" in wildcard_name:
+    # Fix all layer norms
+    wildcard_name = wildcard_name.replace(".weight", ".scale")
+  elif wildcard_name == "model.embed_tokens.weight":
+    wildcard_name = "model.embed.embedding"
+  elif wildcard_name == "lm_head.weight":
+    wildcard_name = "model.lm_head"
+  elif wildcard_name == "model.norm.weight":
+    wildcard_name = "model.norm.scale"
+  elif wildcard_name.endswith(".weight"):
+    # Fix all other weights (MLP, Attn)
+    wildcard_name = wildcard_name.replace(".weight", ".kernel")
+  return wildcard_name
+
+
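A few representative inputs and the templates `GENERAL_HF_KEYS_TO_VLLM_KEYS` returns for them (the HF key shapes here are assumed examples of scanned and unscanned checkpoints):

```python
# Assumed example HF keys for the Llama3/Qwen3 path.
GENERAL_HF_KEYS_TO_VLLM_KEYS("model.embed_tokens.weight")
# -> "model.embed.embedding"
GENERAL_HF_KEYS_TO_VLLM_KEYS(
    ["model.layers.0.input_layernorm.weight", "model.layers.1.input_layernorm.weight"]
)
# -> "model.layers.*.input_layernorm.scale"
GENERAL_HF_KEYS_TO_VLLM_KEYS([["model.layers.0.mlp.experts.0.gate_proj.weight"]])
# -> "model.layers.*.mlp.experts.*.gate_proj.kernel"
```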
+def DEEPSEEK_HF_KEYS_TO_VLLM_KEYS(hf_input):
+  """
+  Converts a concrete HF key (or list of keys) to a vLLM template string.
+  Handles both single strings and lists of strings.
+  """
+  if not hf_input:
+    return None
+
+  # 1. Standardize input to a single representative sample string
+  if isinstance(hf_input, str):
+    sample_key = hf_input
+  elif isinstance(hf_input, list):
+    if not hf_input:
+      return None
+    if isinstance(hf_input[0], list):
+      sample_key = hf_input[0][0]  # Scanned MoE
+    else:
+      sample_key = hf_input[0]  # Scanned Dense / Unscanned MoE
+  else:
+    raise TypeError(f"Unknown value type in map: {type(hf_input)}")
+
+  # 2. Structural generalization (convert specific indices to wildcards)
+  # Replace 'model.layers.{N}.' with 'layers.*.'
+  template = re.sub(r"^model\.layers\.\d+\.", "layers.*.", sample_key)
+
+  # 3. Leaf node renaming (HF -> vLLM intermediate names)
+  leaf_replacements = [
+      # --- Globals ---
+      (r"^model\.norm\.weight$", "final_norm.scale"),
+      (r"^model\.embed_tokens\.weight$", "embedder.input_embedding_table_VD"),
+      (r"^lm_head\.weight$", "lm_head.input_embedding_table_DV"),
+      # --- MoE: Router (Gate) ---
+      # Specific to DeepSeek's 'mlp.gate'
+      (r"mlp\.gate\.weight$", "custom_module.router.kernel_DE"),
+      (r"mlp\.gate\.e_score_correction_bias$", "custom_module.router.bias_E"),
+      # --- MoE: Shared Experts ---
+      # Specific to DeepSeek's 'mlp.shared_experts'
+      (r"mlp\.shared_experts\.gate_proj\.weight$", "shared_experts.kernel_gating_DF"),
+      (r"mlp\.shared_experts\.up_proj\.weight$", "shared_experts.kernel_up_proj_DF"),
+      (r"mlp\.shared_experts\.down_proj\.weight$", "shared_experts.kernel_down_proj_FD"),
+      # --- MoE: Routed Experts (individual experts) ---
+      # Matches 'mlp.experts.0.gate_proj.weight' etc.; all routed experts collapse
+      # onto a single per-projection template key.
+      (r"mlp\.experts\.\d+\.gate_proj\.weight$", "custom_module.kernel_gating_EDF"),
+      (r"mlp\.experts\.\d+\.up_proj\.weight$", "custom_module.kernel_up_proj_EDF"),
+      (r"mlp\.experts\.\d+\.down_proj\.weight$", "custom_module.kernel_down_proj_EFD"),
+      # --- Standard Dense MLP (fallback for non-MoE layers) ---
+      (r"mlp\.gate_proj\.weight$", "custom_module.kernel_gating_DF"),
+      (r"mlp\.up_proj\.weight$", "custom_module.kernel_up_proj_DF"),
+      (r"mlp\.down_proj\.weight$", "custom_module.kernel_down_proj_FD"),
+      # --- Attention & Norms (Standard) ---
+      (r"input_layernorm\.weight$", "pre_attention_norm.scale"),
+      (r"post_attention_layernorm\.weight$", "pre_mlp_norm.scale"),
+      (r"self_attn\.q_a_layernorm\.weight$", "attn.q_rms_norm.scale"),
+      (r"self_attn\.kv_a_layernorm\.weight$", "attn.kv_rms_norm.scale"),
+      (r"self_attn\.q_proj\.weight$", "attn.kernel_q_proj"),
+      (r"self_attn\.q_a_proj\.weight$", "attn.kernel_q_down_proj_DA"),
+      (r"self_attn\.q_b_proj\.weight$", "attn.kernel_q_up_proj_ANH"),
+      (r"self_attn\.kv_a_proj_with_mqa\.weight$", "attn.kernel_kv_down_proj_DA"),
+      (r"self_attn\.kv_b_proj\.weight$", "attn.kernel_kv_up_proj_ANH"),
+      (r"self_attn\.o_proj\.weight$", "attn.kernel_o_proj_NHD"),
+  ]
+
+  for pattern, replacement in leaf_replacements:
+    template = re.sub(pattern, replacement, template)
+
+  return template
+
+
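And the DeepSeek counterpart on a few assumed keys, showing the MLA and routed-expert renames:

```python
# Assumed example HF keys for the DeepSeek path.
DEEPSEEK_HF_KEYS_TO_VLLM_KEYS("model.embed_tokens.weight")
# -> "embedder.input_embedding_table_VD"
DEEPSEEK_HF_KEYS_TO_VLLM_KEYS("model.layers.3.self_attn.kv_a_proj_with_mqa.weight")
# -> "layers.*.attn.kernel_kv_down_proj_DA"
DEEPSEEK_HF_KEYS_TO_VLLM_KEYS(["model.layers.10.mlp.experts.0.up_proj.weight"])
# -> "layers.*.custom_module.kernel_up_proj_EDF"
```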
+GENERAL_SHARDING_MAP = {
+    # Non-layer parameters
+    "base.token_embedder.embedding": ("model", None),
+    "base.decoder.decoder_norm.scale": (None,),
+    "base.decoder.logits_dense.kernel": (None, "model"),
+    # --- Attention (generic for scanned/unscanned) ---
+    "base.decoder.layers.pre_self_attention_layer_norm.scale": (None, "layer"),
+    "base.decoder.layers.self_attention.query.kernel": (
+        None,
+        "layer",
+        "model",
+        None,
+    ),
+    "base.decoder.layers.self_attention.key.kernel": (
+        None,
+        "layer",
+        "model",
+        None,
+    ),
+    "base.decoder.layers.self_attention.value.kernel": (
+        None,
+        "layer",
+        "model",
+        None,
+    ),
+    "base.decoder.layers.self_attention.query_norm.scale": (None, "layer"),
+    "base.decoder.layers.self_attention.key_norm.scale": (None, "layer"),
+    "base.decoder.layers.self_attention.out.kernel": (
+        "model",
+        "layer",
+        None,
+        None,
+    ),
+    "base.decoder.layers.post_self_attention_layer_norm.scale": (None, "layer"),
+    # --- Dense MLP (generic for scanned/unscanned) ---
+    "base.decoder.layers.mlp.wi_0.kernel": (None, "layer", "model"),
+    "base.decoder.layers.mlp.wi_1.kernel": (None, "layer", "model"),
+    "base.decoder.layers.mlp.wo.kernel": ("model", "layer", None),
+    # --- MoE (generic for scanned/unscanned) ---
+    "base.decoder.layers.moe_block.gate.kernel": (None, "layer", "model"),
+    "base.decoder.layers.moe_block.wi_0": ("expert", "layer", None, "model"),
+    "base.decoder.layers.moe_block.wi_1": ("expert", "layer", None, "model"),
+    "base.decoder.layers.moe_block.wo": ("expert", "layer", "model", None),
+    # --- DeepSeek Attention ---
+    "base.decoder.layers.self_attention.wq_a.kernel": (
+        None,
+        "layer",
+        "model",
+        None,
+    ),
+    "base.decoder.layers.self_attention.wq_b.kernel": (
+        None,
+        "layer",
+        "model",
+        None,
+    ),
+    "base.decoder.layers.self_attention.q_norm.scale": (None, "layer"),
+    "base.decoder.layers.self_attention.wkv_a.kernel": (
+        None,
+        "layer",
+        "model",
+        None,
+    ),
+    "base.decoder.layers.self_attention.wkv_b.kernel": (
+        None,
+        "layer",
+        "model",
+        None,
+    ),
+    "base.decoder.layers.self_attention.kv_norm.scale": (None, "layer"),
+    # --- DeepSeek MoE ---
+    "base.decoder.layers.moe_block.shared_experts.wi_0.kernel": (
+        None,
+        "layer",
+        "model",
+    ),
+    "base.decoder.layers.moe_block.shared_experts.wi_1.kernel": (
+        None,
+        "layer",
+        "model",
+    ),
+    "base.decoder.layers.moe_block.shared_experts.wo.kernel": (
+        "model",
+        "layer",
+        None,
+    ),
+    "base.decoder.layers.moe_block.gate.bias": (None, "layer", "model"),
+}
+
+
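Each tuple lists one entry per tensor dimension. Assuming these names are consumed downstream as mesh axis names for a `jax.sharding.PartitionSpec` (that consumer is outside this diff), a lookup would be used roughly like this:

```python
from jax.sharding import PartitionSpec

# Assumption: tuple entries are per-dimension mesh axis names (None = replicated).
axes = GENERAL_SHARDING_MAP["base.decoder.layers.self_attention.query.kernel"]
spec = PartitionSpec(*axes)  # PartitionSpec(None, 'layer', 'model', None)
```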
+DEEPSEEK_SHARDING_MAP = {
+    # --- Non-Layer Parameters ---
+    "base.token_embedder.embedding": ("model", None),
+    "base.decoder.decoder_norm.scale": (None,),
+    "base.decoder.logits_dense.kernel": (None, "model"),
+    # ==============================================================================
+    # DENSE LAYERS MAPPING
+    # ==============================================================================
+    "base.decoder.dense_layers.pre_self_attention_layer_norm.scale": (None, "layer"),
+    "base.decoder.dense_layers.post_self_attention_layer_norm.scale": (None, "layer"),
+    # --- Attention (MLA) ---
+    # Q projections (Down/Up)
+    "base.decoder.dense_layers.self_attention.wq_a.kernel": (None, "layer", "model", None),
+    "base.decoder.dense_layers.self_attention.wq_b.kernel": (None, "layer", "model", None),
+    # KV projections (Down/Up with MQA)
+    "base.decoder.dense_layers.self_attention.wkv_a.kernel": (None, "layer", "model", None),
+    "base.decoder.dense_layers.self_attention.wkv_b.kernel": (None, "layer", "model", None),
+    # Output projection
+    "base.decoder.dense_layers.self_attention.out.kernel": ("model", "layer", None, None),
+    # MLA Norms
+    "base.decoder.dense_layers.self_attention.kv_norm.scale": (None, "layer"),
+    "base.decoder.dense_layers.self_attention.q_norm.scale": (None, "layer"),
+    # --- Dense MLP ---
+    "base.decoder.dense_layers.mlp.wi_0.kernel": (None, "layer", "model"),
+    "base.decoder.dense_layers.mlp.wi_1.kernel": (None, "layer", "model"),
+    "base.decoder.dense_layers.mlp.wo.kernel": ("model", "layer", None),
+    # ==============================================================================
+    # MOE LAYERS MAPPING
+    # ==============================================================================
+    "base.decoder.moe_layers.pre_self_attention_layer_norm.scale": (None, "layer"),
+    "base.decoder.moe_layers.post_self_attention_layer_norm.scale": (None, "layer"),
+    # --- Attention (MLA + Decoupled RoPE) for MoE Layers ---
+    "base.decoder.moe_layers.self_attention.wq_a.kernel": (None, "layer", "model", None),
+    "base.decoder.moe_layers.self_attention.wq_b.kernel": (None, "layer", "model", None),
+    "base.decoder.moe_layers.self_attention.wkv_a.kernel": (None, "layer", "model", None),
+    "base.decoder.moe_layers.self_attention.wkv_b.kernel": (None, "layer", "model", None),
+    "base.decoder.moe_layers.self_attention.out.kernel": ("model", "layer", None, None),
+    "base.decoder.moe_layers.self_attention.kv_norm.scale": (None, "layer"),
+    "base.decoder.moe_layers.self_attention.q_norm.scale": (None, "layer"),
+    # --- DeepSeek MoE Blocks ---
+    # Shared Experts
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": (None, "layer", "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": (None, "layer", "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": ("model", "layer", None),
+    # Gating (Router)
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": (None, "layer", "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": (None, "layer", "model"),
+    # Routed Experts
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": ("expert", "layer", None, "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": ("expert", "layer", None, "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": ("expert", "layer", "model", None),
+}
+
+
+_SHARDING_KNOWLEDGE_MAP = {
+    "qwen3": GENERAL_SHARDING_MAP,
+    "llama3": GENERAL_SHARDING_MAP,
+    "deepseek3": DEEPSEEK_SHARDING_MAP,
+}
+
+_MAXTEXT_TO_VLLM_KEY_MAP = {
+    "qwen3": GENERAL_HF_KEYS_TO_VLLM_KEYS,
+    "llama3": GENERAL_HF_KEYS_TO_VLLM_KEYS,
+    "deepseek3": DEEPSEEK_HF_KEYS_TO_VLLM_KEYS,
+}
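Putting the dispatch tables together, the DeepSeek MoE path end to end on assumed inputs:

```python
# Assumed model name and keys; traces the full lookup chain for a routed-expert weight.
family = "deepseek3-671b".split("-")[0]     # "deepseek3"
to_vllm = _MAXTEXT_TO_VLLM_KEY_MAP[family]  # DEEPSEEK_HF_KEYS_TO_VLLM_KEYS
rules = _SHARDING_KNOWLEDGE_MAP[family]     # DEEPSEEK_SHARDING_MAP

generic_key = "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0"
hf_value = [["model.layers.4.mlp.experts.0.gate_proj.weight"]]  # scanned MoE
entry = (to_vllm(hf_value), rules[generic_key])
# -> ("layers.*.custom_module.kernel_gating_EDF", ("expert", "layer", None, "model"))
```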