Commit 66ac814

fix deepseek tp sharding error
1 parent 025a0f6 commit 66ac814

File tree

1 file changed: +11 -1 lines changed


src/MaxText/layers/moe.py

Lines changed: 11 additions & 1 deletion
@@ -29,7 +29,7 @@
 import jax.numpy as jnp
 import numpy as np
 
-from MaxText import common_types as ctypes
+from MaxText import common_types as ctypes, EP_AS_CONTEXT
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText.kernels import megablox as mblx
@@ -1833,8 +1833,18 @@ def routed_moe(self):
     return self.MoeBlock_0
 
   def __call__(self, inputs: jax.Array) -> jax.Array:
+    batch_logical_axes = (
+        "activation_batch_no_exp" if self.config.expert_shard_attention_option == EP_AS_CONTEXT
+        else "activation_batch"
+    )
+    seq_logical_axes = (
+        "activation_length" if self.config.expert_shard_attention_option == EP_AS_CONTEXT
+        else "activation_length_no_exp"
+    )
     routed_experts, _ = self.routed_moe(inputs)
+    routed_experts = nn.with_logical_constraint(routed_experts, (batch_logical_axes, seq_logical_axes, "activation_embed"))
     shared_experts = self.shared_experts(inputs)
+    shared_experts = nn.with_logical_constraint(shared_experts, (batch_logical_axes, seq_logical_axes, "activation_embed"))
     return routed_experts + shared_experts
 
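For reference, here is a minimal, self-contained sketch of the mechanism the fix relies on: Flax's nn.with_logical_constraint resolves logical activation axis names (such as "activation_batch" or "activation_embed") to physical mesh axes through a table of rules and then applies jax.lax.with_sharding_constraint. The mesh shape, the rules table, and the constrain helper below are illustrative assumptions, not MaxText's actual sharding configuration, which derives its axis rules from the training config.

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh

# Hypothetical 1-D mesh over whatever devices are available
# (a single CPU device also works; the axis then has size 1).
mesh = Mesh(np.array(jax.devices()), axis_names=("tensor",))

# Hypothetical logical-to-mesh rules: shard the embedding axis over "tensor"
# and leave the batch/sequence axes unsharded (None).
rules = (
    ("activation_batch", None),
    ("activation_batch_no_exp", None),
    ("activation_length", None),
    ("activation_length_no_exp", None),
    ("activation_embed", "tensor"),
)

@jax.jit
def constrain(x):
  # Resolve the logical axis names to mesh axes via `rules`, then constrain
  # the array to the corresponding sharding on `mesh`.
  return nn.with_logical_constraint(
      x,
      ("activation_batch", "activation_length", "activation_embed"),
      rules=rules,
      mesh=mesh,
  )

x = jnp.ones((4, 8, 64))  # (batch, sequence, embed)
y = constrain(x)
print(y.shape, y.sharding)  # on a multi-device mesh, embed maps to "tensor"

In the diff itself, the switch between the plain and the *_no_exp axis names appears to track whether the expert-parallel mesh axis is folded into the sequence (context) dimension or the batch dimension, so that the routed and shared expert outputs carry matching shardings before they are summed.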