
Commit fef2f1f

[https://nvbugs/5449155][fix] Fix DeepSeek R1 weight loading for TP16 (#6913)
Signed-off-by: Aurelien Chartier <[email protected]>
1 parent 790a105 commit fef2f1f

File tree

1 file changed (+30, -10 lines)


tensorrt_llm/_torch/modules/linear.py

Lines changed: 30 additions & 10 deletions
@@ -604,9 +604,12 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)

         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
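
The vanilla path now normalizes the modelopt fp8_pb_wo block scale to 2-D before tensor-parallel sharding and no longer squeezes the resulting shard. Below is a minimal sketch of the shape handling, assuming a hypothetical [out_blocks, 1, in_blocks, 1] scale layout and a plain dim-0 chunk as a stand-in for load_weight_shard (the real sharding logic in TensorRT-LLM is more involved); block counts are made up for illustration.

    import torch

    # Hypothetical block counts, purely for illustration.
    full_weight_scale = torch.ones(16, 1, 56, 1)  # fp8_pb_wo scale with 2 extra singleton dims

    # The fix: drop the singleton dimensions before sharding.
    if full_weight_scale.dim() == 4:
        full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)  # -> [16, 56]

    # Simplified stand-in for load_weight_shard with column-parallel TP16.
    tp_size, tp_rank = 16, 0
    weight_scale = torch.chunk(full_weight_scale, tp_size, dim=0)[tp_rank]  # -> [1, 56]

    # The old unconditional .squeeze() would have collapsed a TP16 shard with a
    # single block row to a 1-D [56] tensor, losing the block-scale layout.
    print(weight_scale.shape)  # torch.Size([1, 56])
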
@@ -619,13 +622,23 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         fused_weight = torch.cat((q_weight, k_weight, v_weight))

         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))

         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
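
The fused QKV path applies the same per-input normalization and drops the trailing .squeeze() after concatenation, so the fused block scale keeps a 2-D [total_out_blocks, in_blocks] layout. The same pattern is used for the fused gate/up path in the next hunk. A small illustration with made-up per-rank shard shapes (the actual K/V shard sizes under TP16 depend on the model's head layout):

    import torch

    # Made-up per-rank shard shapes after tensor-parallel sharding.
    q_scale = torch.ones(8, 56)
    k_scale = torch.ones(1, 56)  # small K/V shards may contribute a single block row
    v_scale = torch.ones(1, 56)

    # Concatenate along the output-block dimension; no squeeze is needed, and an
    # unconditional .squeeze() could drop a legitimate size-1 dimension
    # (e.g. if there were only one input block column).
    fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
    print(fused_fp8_block_scale.shape)  # torch.Size([10, 56])
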
@@ -637,11 +650,18 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         fused_weight = torch.cat((gate_weight, up_weight))

         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)
