Commit 206ec64

Add DeepSeek vLLM weight fallback mappings and correct the existing DeepSeek mapping
PiperOrigin-RevId: 827657629
1 parent bfdb7ed · commit 206ec64

File tree: 4 files changed, +819 −40 lines


src/MaxText/integration/tunix/utils.py

Lines changed: 285 additions & 40 deletions
@@ -122,7 +122,8 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
     self.model_name = model_name
     self.config = config
     self.use_standalone_mappings = use_standalone_mappings
-    self._sharding_knowledge_map = _SHARDING_KNOWLEDGE_MAP
+    self._sharding_knowledge_map = _SHARDING_KNOWLEDGE_MAP[self.model_name.split("-")[0]]
+    self._maxtext_keys_to_vllm_keys = _MAXTEXT_TO_VLLM_KEY_MAP[self.model_name.split("-")[0]]

   def to_hf_mapping(self):
     """Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
@@ -154,44 +155,29 @@ def lora_to_hf_mappings(self):
     return None

   def _generalize_maxtext_key(self, maxtext_key):
-    """Generalizes the MaxText key to a common vLLM format."""
-    # 'params-decoder-layers_0-mlp-...' -> 'base.decoder.layers_0.mlp....'
+    """
+    Universal generalizer for Qwen3, DeepSeek, and Llama3.1 keys.
+    Converts raw MaxText keys to a standardized 'base' format for sharding maps.
+    """
+    # 1. Standardize separators and prefix
+    # 'params-decoder-...' -> 'base.decoder....'
+    # 'thinker.params-decoder-...' -> 'thinker.base.decoder....' (Qwen3-Omni)
     generic_key = maxtext_key.replace("params-", "base.").replace("-", ".")
-    # 'base.decoder.dense_layers.mlp....' -> 'base.decoder.layers.mlp....'
-    generic_key = re.sub(r"\.dense_layers\.", ".layers.", generic_key)
-    # 'base.decoder.moe_layers.mlp....' -> 'base.decoder.layers.mlp....'
-    generic_key = re.sub(r"\.moe_layers\.", ".layers.", generic_key)
-    # '...layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0' -> '...layers.moe_block.wi_0'
-    generic_key = re.sub(r"DeepSeekMoeBlock_0\.MoeBlock_0\.", "moe_block.", generic_key)
-    # Handle shared experts
-    generic_key = re.sub(
-        r"DeepSeekMoeBlock_0\.shared_experts\.",
-        "moe_block.shared_experts.",
-        generic_key,
-    )
-    # Keep original rule for other models
-    generic_key = re.sub(r"layers_(\d+)\.", "layers.", generic_key)
+
+    # 2. Normalize standard layer indexing (crucial for unscanned models)
+    # Llama/Qwen unscanned: 'base.decoder.layers_5.self_attention...' -> 'base.decoder.layers.self_attention...'
+    generic_key = re.sub(r"\.layers_\d+\.", ".layers.", generic_key)
+
+    # 3. Normalize DeepSeek-specific layer indexing (if unscanned support is added later)
+    # Preserves the distinction between 'dense_layers' and 'moe_layers'
+    generic_key = re.sub(r"\.dense_layers_\d+\.", ".dense_layers.", generic_key)
+    generic_key = re.sub(r"\.moe_layers_\d+\.", ".moe_layers.", generic_key)
+
     return generic_key

   def _generalize_hf_value(self, hf_value):
     """Extracts and generalizes the Hugging Face name from the hf_value."""
-    first_name = ""
-    if isinstance(hf_value, str):
-      first_name = hf_value
-    elif isinstance(hf_value, list):
-      if not hf_value:
-        return None
-      if isinstance(hf_value[0], list):
-        first_name = hf_value[0][0]  # Scanned MoE
-      else:
-        first_name = hf_value[0]  # Scanned Dense / Unscanned MoE
-    else:
-      raise TypeError(f"Unknown value type in map: {type(hf_value)}")
-
-    # Replace layer and expert indices with wildcards
-    wildcard_name = re.sub(r"layers\.(\d+)\.", "layers.*.", first_name)
-    wildcard_name = re.sub(r"experts\.(\d+)\.", "experts.*.", wildcard_name)
-    return wildcard_name
+    return self._maxtext_keys_to_vllm_keys(hf_value)

   def _correct_hf_wildcard_name(self, wildcard_name):
     """Corrects the generated Hugging Face wildcard name."""
@@ -229,13 +215,10 @@ def convert_hf_map_to_sharding_map(self, hf_mapping):
       generic_key = self._generalize_maxtext_key(maxtext_key)

       # 2. Generalize the Hugging Face (HF) value name
-      wildcard_name = self._generalize_hf_value(hf_value)
-      if wildcard_name is None:
+      corrected_value = self._generalize_hf_value(hf_value)
+      if corrected_value is None:
         continue

-      # 3. Correct the generated wildcard name
-      corrected_name = self._correct_hf_wildcard_name(wildcard_name)
-
       # 4. Look up the sharding tuple
       sharding_tuple = self._sharding_knowledge_map.get(generic_key)

@@ -246,6 +229,268 @@ def convert_hf_map_to_sharding_map(self, hf_mapping):
         print(f"Warning: No sharding rule found for key: {generic_key}")
         continue
       # 5. Assemble the final map entry
-      sharding_map[generic_key] = (corrected_name, sharding_tuple)
+      sharding_map[generic_key] = (corrected_value, sharding_tuple)

     return sharding_map
+
+
+def GENERAL_HF_KEYS_TO_VLLM_KEYS(hf_value):
+  """Converts a concrete HF key (or list of keys) to a vLLM template string."""
+  first_name = ""
+  if isinstance(hf_value, str):
+    first_name = hf_value
+  elif isinstance(hf_value, list):
+    if not hf_value:
+      return None
+    if isinstance(hf_value[0], list):
+      first_name = hf_value[0][0]  # Scanned MoE
+    else:
+      first_name = hf_value[0]  # Scanned Dense / Unscanned MoE
+  else:
+    raise TypeError(f"Unknown value type in map: {type(hf_value)}")
+
+  # Replace layer and expert indices with wildcards
+  wildcard_name = re.sub(r"layers\.(\d+)\.", "layers.*.", first_name)
+  wildcard_name = re.sub(r"experts\.(\d+)\.", "experts.*.", wildcard_name)
+
+  if "layernorm.weight" in wildcard_name or "_norm.weight" in wildcard_name:
+    # Fix all layer norms
+    wildcard_name = wildcard_name.replace(".weight", ".scale")
+  elif wildcard_name == "model.embed_tokens.weight":
+    wildcard_name = "model.embed.embedding"
+  elif wildcard_name == "lm_head.weight":
+    wildcard_name = "model.lm_head"
+  elif wildcard_name == "model.norm.weight":
+    wildcard_name = "model.norm.scale"
+  elif wildcard_name.endswith(".weight"):
+    # Fix all other weights (MLP, Attn)
+    wildcard_name = wildcard_name.replace(".weight", ".kernel")
+  return wildcard_name
+
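Behavior of this extracted helper on a few representative inputs (the sample HF keys are illustrative, not values from this commit):

# Scanned dense value: a list of per-layer names; only the first is sampled.
GENERAL_HF_KEYS_TO_VLLM_KEYS(
    ["model.layers.0.mlp.gate_proj.weight", "model.layers.1.mlp.gate_proj.weight"]
)
# -> 'model.layers.*.mlp.gate_proj.kernel'

# Layer norms become '.scale' rather than '.kernel'.
GENERAL_HF_KEYS_TO_VLLM_KEYS("model.layers.3.input_layernorm.weight")
# -> 'model.layers.*.input_layernorm.scale'

# Embeddings are special-cased.
GENERAL_HF_KEYS_TO_VLLM_KEYS("model.embed_tokens.weight")
# -> 'model.embed.embedding'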
+def DEEPSEEK_HF_KEYS_TO_VLLM_KEYS(hf_input):
+  """Converts a concrete HF key (or list of keys) to a vLLM template string.
+
+  Handles both single strings and lists of strings.
+  """
+  if not hf_input:
+    return None
+
+  # 1. Standardize input to a single representative sample string
+  if isinstance(hf_input, str):
+    sample_key = hf_input
+  elif isinstance(hf_input, list):
+    if not hf_input:
+      return None
+    if isinstance(hf_input[0], list):
+      sample_key = hf_input[0][0]  # Scanned MoE
+    else:
+      sample_key = hf_input[0]  # Scanned Dense / Unscanned MoE
+  else:
+    raise TypeError(f"Unknown value type in map: {type(hf_input)}")
+
+  # 2. Structural generalization: replace 'model.layers.{N}.' with 'layers.*.'
+  template = re.sub(r"^model\.layers\.\d+\.", "layers.*.", sample_key)
+
+  # 3. Leaf-node renaming (HF -> vLLM intermediate names)
+  leaf_replacements = [
+      # --- Globals ---
+      (r"^model\.norm\.weight$", "final_norm.scale"),
+      (r"^model\.embed_tokens\.weight$", "embedder.input_embedding_table_VD"),
+      (r"^lm_head\.weight$", "lm_head.input_embedding_table_DV"),
+      # --- MoE: Router (Gate) ---
+      # Specific to DeepSeek's 'mlp.gate'
+      (r"mlp\.gate\.weight$", "custom_module.router.kernel_DE"),
+      (r"mlp\.gate\.e_score_correction_bias$", "custom_module.router.bias_E"),
+      # --- MoE: Shared Experts ---
+      # Specific to DeepSeek's 'mlp.shared_experts'
+      (r"mlp\.shared_experts\.gate_proj\.weight$", "shared_experts.kernel_gating_DF"),
+      (r"mlp\.shared_experts\.up_proj\.weight$", "shared_experts.kernel_up_proj_DF"),
+      (r"mlp\.shared_experts\.down_proj\.weight$", "shared_experts.kernel_down_proj_FD"),
+      # --- MoE: Routed Experts (individual experts) ---
+      # Matches 'mlp.experts.0.gate_proj.weight' etc.; every expert index maps
+      # onto the single stacked EDF/EFD template for that projection:
+      (r"mlp\.experts\.\d+\.gate_proj\.weight$", "custom_module.kernel_gating_EDF"),
+      (r"mlp\.experts\.\d+\.up_proj\.weight$", "custom_module.kernel_up_proj_EDF"),
+      (r"mlp\.experts\.\d+\.down_proj\.weight$", "custom_module.kernel_down_proj_EFD"),
+      # --- Standard Dense MLP (fallback for non-MoE layers) ---
+      (r"mlp\.gate_proj\.weight$", "custom_module.kernel_gating_DF"),
+      (r"mlp\.up_proj\.weight$", "custom_module.kernel_up_proj_DF"),
+      (r"mlp\.down_proj\.weight$", "custom_module.kernel_down_proj_FD"),
+      # --- Attention & Norms (standard) ---
+      (r"input_layernorm\.weight$", "pre_attention_norm.scale"),
+      (r"post_attention_layernorm\.weight$", "pre_mlp_norm.scale"),
+      (r"self_attn\.q_a_layernorm\.weight$", "attn.q_rms_norm.scale"),
+      (r"self_attn\.kv_a_layernorm\.weight$", "attn.kv_rms_norm.scale"),
+      (r"self_attn\.q_proj\.weight$", "attn.kernel_q_proj"),
+      (r"self_attn\.q_a_proj\.weight$", "attn.kernel_q_down_proj_DA"),
+      (r"self_attn\.q_b_proj\.weight$", "attn.kernel_q_up_proj_ANH"),
+      (r"self_attn\.kv_a_proj_with_mqa\.weight$", "attn.kernel_kv_down_proj_DA"),
+      (r"self_attn\.kv_b_proj\.weight$", "attn.kernel_kv_up_proj_ANH"),
+      (r"self_attn\.o_proj\.weight$", "attn.kernel_o_proj_NHD"),
+  ]
+
+  for pattern, replacement in leaf_replacements:
+    template = re.sub(pattern, replacement, template)
+
+  return template
+
+
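A few illustrative calls (the concrete HF keys are examples of DeepSeek-V3-style checkpoint names, not values from this commit):

# MLA query down-projection: the layer index is wildcarded first.
DEEPSEEK_HF_KEYS_TO_VLLM_KEYS("model.layers.7.self_attn.q_a_proj.weight")
# -> 'layers.*.attn.kernel_q_down_proj_DA'

# Routed experts: any expert index collapses onto the stacked EDF kernel.
DEEPSEEK_HF_KEYS_TO_VLLM_KEYS("model.layers.10.mlp.experts.42.up_proj.weight")
# -> 'layers.*.custom_module.kernel_up_proj_EDF'

# Globals keep dedicated names.
DEEPSEEK_HF_KEYS_TO_VLLM_KEYS("lm_head.weight")
# -> 'lm_head.input_embedding_table_DV'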
+GENERAL_SHARDING_MAP = {
+    # Non-layer parameters
+    "base.token_embedder.embedding": ("model", None),
+    "base.decoder.decoder_norm.scale": (None,),
+    "base.decoder.logits_dense.kernel": (None, "model"),
+    # --- Attention (generic for scanned/unscanned) ---
+    "base.decoder.layers.pre_self_attention_layer_norm.scale": (None, "layer"),
+    "base.decoder.layers.self_attention.query.kernel": (None, "layer", "model", None),
+    "base.decoder.layers.self_attention.key.kernel": (None, "layer", "model", None),
+    "base.decoder.layers.self_attention.value.kernel": (None, "layer", "model", None),
+    "base.decoder.layers.self_attention.query_norm.scale": (None, "layer"),
+    "base.decoder.layers.self_attention.key_norm.scale": (None, "layer"),
+    "base.decoder.layers.self_attention.out.kernel": ("model", "layer", None, None),
+    "base.decoder.layers.post_self_attention_layer_norm.scale": (None, "layer"),
+    # --- Dense MLP (generic for scanned/unscanned) ---
+    "base.decoder.layers.mlp.wi_0.kernel": (None, "layer", "model"),
+    "base.decoder.layers.mlp.wi_1.kernel": (None, "layer", "model"),
+    "base.decoder.layers.mlp.wo.kernel": ("model", "layer", None),
+    # --- MoE (generic for scanned/unscanned) ---
+    "base.decoder.layers.moe_block.gate.kernel": (None, "layer", "model"),
+    "base.decoder.layers.moe_block.wi_0": ("expert", "layer", None, "model"),
+    "base.decoder.layers.moe_block.wi_1": ("expert", "layer", None, "model"),
+    "base.decoder.layers.moe_block.wo": ("expert", "layer", "model", None),
+    # --- DeepSeek Attention ---
+    "base.decoder.layers.self_attention.wq_a.kernel": (None, "layer", "model", None),
+    "base.decoder.layers.self_attention.wq_b.kernel": (None, "layer", "model", None),
+    "base.decoder.layers.self_attention.q_norm.scale": (None, "layer"),
+    "base.decoder.layers.self_attention.wkv_a.kernel": (None, "layer", "model", None),
+    "base.decoder.layers.self_attention.wkv_b.kernel": (None, "layer", "model", None),
+    "base.decoder.layers.self_attention.kv_norm.scale": (None, "layer"),
+    # --- DeepSeek MoE ---
+    "base.decoder.layers.moe_block.shared_experts.wi_0.kernel": (None, "layer", "model"),
+    "base.decoder.layers.moe_block.shared_experts.wi_1.kernel": (None, "layer", "model"),
+    "base.decoder.layers.moe_block.shared_experts.wo.kernel": ("model", "layer", None),
+    "base.decoder.layers.moe_block.gate.bias": (None, "layer", "model"),
+}
+
+
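In both sharding maps, each value carries one entry per tensor dimension: a mesh-axis name shards that dimension, None replicates it. Reading a tuple as a jax.sharding.PartitionSpec is a natural interpretation (shown below as an assumption; the commit itself only stores the tuples):

from jax.sharding import PartitionSpec

# Scanned MoE kernel wi_0 with dims (expert, layer, d_model, d_ff):
# experts sharded on the 'expert' axis, layers on 'layer',
# d_model replicated, d_ff split across 'model'.
spec = PartitionSpec("expert", "layer", None, "model")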
+DEEPSEEK_SHARDING_MAP = {
+    # --- Non-Layer Parameters ---
+    "base.token_embedder.embedding": ("model", None),
+    "base.decoder.decoder_norm.scale": (None,),
+    "base.decoder.logits_dense.kernel": (None, "model"),
+    # ==========================================================================
+    # DENSE LAYERS MAPPING
+    # ==========================================================================
+    "base.decoder.dense_layers.pre_self_attention_layer_norm.scale": (None, "layer"),
+    "base.decoder.dense_layers.post_self_attention_layer_norm.scale": (None, "layer"),
+    # --- Attention (MLA) ---
+    # Q projections (down/up)
+    "base.decoder.dense_layers.self_attention.wq_a.kernel": (None, "layer", "model", None),
+    "base.decoder.dense_layers.self_attention.wq_b.kernel": (None, "layer", "model", None),
+    # KV projections (down/up with MQA)
+    "base.decoder.dense_layers.self_attention.wkv_a.kernel": (None, "layer", "model", None),
+    "base.decoder.dense_layers.self_attention.wkv_b.kernel": (None, "layer", "model", None),
+    # Output projection
+    "base.decoder.dense_layers.self_attention.out.kernel": ("model", "layer", None, None),
+    # MLA norms
+    "base.decoder.dense_layers.self_attention.kv_norm.scale": (None, "layer"),
+    "base.decoder.dense_layers.self_attention.q_norm.scale": (None, "layer"),
+    # --- Dense MLP ---
+    "base.decoder.dense_layers.mlp.wi_0.kernel": (None, "layer", "model"),
+    "base.decoder.dense_layers.mlp.wi_1.kernel": (None, "layer", "model"),
+    "base.decoder.dense_layers.mlp.wo.kernel": ("model", "layer", None),
+    # ==========================================================================
+    # MOE LAYERS MAPPING
+    # ==========================================================================
+    "base.decoder.moe_layers.pre_self_attention_layer_norm.scale": (None, "layer"),
+    "base.decoder.moe_layers.post_self_attention_layer_norm.scale": (None, "layer"),
+    # --- Attention (MLA + decoupled RoPE) for MoE layers ---
+    "base.decoder.moe_layers.self_attention.wq_a.kernel": (None, "layer", "model", None),
+    "base.decoder.moe_layers.self_attention.wq_b.kernel": (None, "layer", "model", None),
+    "base.decoder.moe_layers.self_attention.wkv_a.kernel": (None, "layer", "model", None),
+    "base.decoder.moe_layers.self_attention.wkv_b.kernel": (None, "layer", "model", None),
+    "base.decoder.moe_layers.self_attention.out.kernel": ("model", "layer", None, None),
+    "base.decoder.moe_layers.self_attention.kv_norm.scale": (None, "layer"),
+    "base.decoder.moe_layers.self_attention.q_norm.scale": (None, "layer"),
+    # --- DeepSeek MoE blocks ---
+    # Shared experts
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": (None, "layer", "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": (None, "layer", "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": ("model", "layer", None),
+    # Gating (router)
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": (None, "layer", "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": (None, "layer", "model"),
+    # Routed experts
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": ("expert", "layer", None, "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": ("expert", "layer", None, "model"),
+    "base.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": ("expert", "layer", "model", None),
+}
+
+
+_SHARDING_KNOWLEDGE_MAP = {
+    "qwen3": GENERAL_SHARDING_MAP,
+    "llama3": GENERAL_SHARDING_MAP,
+    "deepseek3": DEEPSEEK_SHARDING_MAP,
+}
+
+_MAXTEXT_TO_VLLM_KEY_MAP = {
+    "qwen3": GENERAL_HF_KEYS_TO_VLLM_KEYS,
+    "llama3": GENERAL_HF_KEYS_TO_VLLM_KEYS,
+    "deepseek3": DEEPSEEK_HF_KEYS_TO_VLLM_KEYS,
+}
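Both registries are keyed by model family, i.e. everything in model_name before the first '-'. A small sketch of the lookup that __init__ now performs (the model name 'deepseek3-671b' is a hypothetical example):

model_name = "deepseek3-671b"        # hypothetical model_name value
family = model_name.split("-")[0]    # -> "deepseek3"

sharding_map = _SHARDING_KNOWLEDGE_MAP[family]   # DEEPSEEK_SHARDING_MAP
to_vllm_key = _MAXTEXT_TO_VLLM_KEY_MAP[family]   # DEEPSEEK_HF_KEYS_TO_VLLM_KEYS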

src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@
 model name. This allows for easy extension to support new models.
 """

+from maxtext.src.maxtext.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
 from maxtext.src.maxtext.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
 from maxtext.src.maxtext.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING

@@ -31,6 +32,8 @@ def __getattr__(self, name):
       return LLAMA3_VLLM_MAPPING
     elif name.startswith("qwen3"):
       return QWEN3_VLLM_MAPPING
+    elif name.startswith("deepseek3"):
+      return DEEPSEEK_VLLM_MAPPING
     else:
       raise ValueError(f"{name} vLLM weight mapping not found.")
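The mappings are resolved by model-name prefix via __getattr__; a minimal standalone sketch of the same dispatch pattern, including the new DeepSeek branch (the _MappingRegistry class and string results below are illustrative, not names from the repo):

class _MappingRegistry:
  """Illustrative stand-in for the object hosting __getattr__ above."""

  def __getattr__(self, name):
    # Prefix dispatch mirrors the committed logic.
    if name.startswith("llama3"):
      return "LLAMA3_VLLM_MAPPING"
    elif name.startswith("qwen3"):
      return "QWEN3_VLLM_MAPPING"
    elif name.startswith("deepseek3"):
      return "DEEPSEEK_VLLM_MAPPING"
    raise ValueError(f"{name} vLLM weight mapping not found.")

print(getattr(_MappingRegistry(), "deepseek3-671b"))  # DEEPSEEK_VLLM_MAPPING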
