2 changes: 2 additions & 0 deletions src/MaxText/configs/base.yml
@@ -1,4 +1,5 @@
# Copyright 2023–2025 Google LLC
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
Collaborator: @kyle-meggs @ultrons - I have no idea what our copyright policies are, curious what your thoughts are

Collaborator: +1 on this copyright comment

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -164,6 +165,7 @@ num_experts: 1
num_experts_per_tok: 1
megablox: True
sparse_matmul: True
te_grouped_gemm: False # whether to route MoE expert matmuls through Transformer Engine's grouped GEMM (grouped_dense)
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss
use_random_routing: False # whether to use random routing for debug/test purpose
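
A quick way to see the new switch next to the existing MoE settings is to load the config directly; the file path and the direct PyYAML loading below are just for illustration and are not taken from this PR.

# Sketch only: inspect the new flag with PyYAML (path assumed relative to the repo root).
import yaml

with open("src/MaxText/configs/base.yml") as f:
  cfg = yaml.safe_load(f)

# Per this diff, the Transformer Engine grouped-GEMM path is off by default and
# sits alongside the existing MoE matmul switches.
print(cfg["megablox"], cfg["sparse_matmul"], cfg["te_grouped_gemm"])  # True True False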
140 changes: 122 additions & 18 deletions src/MaxText/layers/moe.py
@@ -1,4 +1,5 @@
# Copyright 2023–2025 Google LLC
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -123,6 +124,29 @@ def random_routing(rng_key, gate_logits, num_experts_per_tok):
return top_k_weights, top_k_indices


# This is a workaround (WAR) to avoid the "no backward implementation" error from jax.lax.pmax.
def pmax_no_backward(x, axis):
return _pmax_no_backward(x, axis)


@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def _pmax_no_backward(x, axis):
Collaborator: Wondering if you have tested the gradient? We have some examples in moe_test.py.

Author: I have not run the unit tests, but we tested end-to-end deepseek3 model training; the loss curve matches the native XLA implementation. Will attach it to the description shortly.

x_max, _ = _pmax_no_backward_fwd_rule(x, axis)
return x_max


def _pmax_no_backward_fwd_rule(x, axis):
x_max = jax.lax.pmax(x, axis)
return x_max, None


def _pmax_no_backward_bwd_rule(axis, ctx, grad):
del axis, ctx
return grad,

_pmax_no_backward.defvjp(_pmax_no_backward_fwd_rule, _pmax_no_backward_bwd_rule)
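
For the review thread above, a gradient check in the spirit of moe_test.py could look like the following sketch. The single-device mesh, test name, and import path are assumptions for illustration, not part of this PR; the point is that the forward still takes the cross-shard max while the backward passes the cotangent through unchanged, so the gradient of a plain sum is all ones.

# Hypothetical gradient-check sketch (not from this PR).
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

from MaxText.layers.moe import pmax_no_backward  # import path assumed


def test_pmax_no_backward_grad():
  # Bind the "fsdp" axis name with a minimal one-device mesh.
  mesh = Mesh(np.array(jax.devices()[:1]), ("fsdp",))

  @functools.partial(
      shard_map.shard_map,
      mesh=mesh,
      in_specs=P("fsdp"),
      out_specs=P("fsdp"),
      check_rep=False,
  )
  def forward(x):
    return pmax_no_backward(x, "fsdp")

  x = jnp.arange(4.0)
  grads = jax.grad(lambda x: jnp.sum(forward(x)))(x)
  # The custom VJP is a pass-through, so the cotangent of the sum (all ones)
  # should come back unchanged.
  np.testing.assert_allclose(grads, np.ones(x.shape))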


class GateLogit(nnx.Module):
"""A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing."""

@@ -1401,6 +1425,55 @@ def maybe_all_gather_kernel_weight_in_expert_parallelism(
kernel = nn.with_logical_constraint(kernel, kernel_axes)
return kernel

def run_te_grouped_gemm(self, inputs, kernels,
*,
in_specs, out_specs,
w_fsdp_axis, w_fsdp_dim):
from transformer_engine.jax.dense import grouped_dense
from transformer_engine.jax.quantize import (
ScalingMode,
QuantizerFactory
)


@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)
def sharded_group_gemm(x, w):

group_size = x.shape[0]
x_reshaped = x.reshape(-1, x.shape[-1])
n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)

kernel_amax = pmax_no_backward(jnp.max(jnp.abs(w), axis=range(1, w.ndim)), w_fsdp_axis)
kernel_fsdp_info = (w_fsdp_axis, w_fsdp_dim)

if self.quant:
quantizer_set = QuantizerFactory.create_set(
scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
n_groups=group_size
)
else:
quantizer_set = None

output = grouped_dense(
x_reshaped, w, n_groups,
kernel_amax=kernel_amax,
quantizer_set=quantizer_set,
kernel_fsdp_info=kernel_fsdp_info,
)
output = output.reshape(*x.shape[:-1], -1)

return output

return sharded_group_gemm(inputs, kernels)
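
A small standalone illustration of the grouping convention used in sharded_group_gemm above (all dimensions below are invented for this sketch): the leading dimension of the local input indexes the expert groups, and n_groups is a uniform split of the flattened rows, reflecting that in this dense dispatch path every group holds the same number of row slots.

# Illustrative sketch only; shapes are made up and not taken from the model config.
import jax.numpy as jnp

num_groups, rows_per_group, model_dim = 4, 8, 16
x = jnp.zeros((num_groups, rows_per_group, model_dim))  # local per-expert rows
x_reshaped = x.reshape(-1, x.shape[-1])                 # (num_groups * rows_per_group, model_dim)
n_groups = jnp.full(num_groups, x_reshaped.shape[0] // num_groups)
print(n_groups)  # [8 8 8 8] -> rows consumed by each group's GEMM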

def dense_matmul(
self,
inputs,
@@ -1554,12 +1627,25 @@ def dense_matmul(
dispatch,
dispatch_axis,
)


dispatch_pspec = nn.logical_to_mesh_axes(dispatch_axis)
w0w1_kernel_psepc = nn.logical_to_mesh_axes(("exp", "embed_no_exp", "mlp"))
layer_w0w1_psepc = nn.logical_to_mesh_axes(mlp_axis)
with jax.named_scope("wi_0"):
w0_kernel_axes = ("exp", None, "mlp")
w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
)
if self.config.te_grouped_gemm:
layer_w0 = self.run_te_grouped_gemm(dispatch, w0_kernel,
in_specs=(dispatch_pspec, w0w1_kernel_psepc,),
out_specs=layer_w0w1_psepc,
w_fsdp_axis="fsdp", w_fsdp_dim=1)
else:
# TODO(ranran): check if None should be replaced by "embed_no_exp" for performance.
w0_kernel_axes = ("exp", None, "mlp")
Collaborator: I don't recall the reason we put "None" in the embedding dimension, but I think we are good to reuse. Could you help add a comment: "TODO(ranran): check if None should be replaced by 'embed_no_exp' for performance."?

Author: Done.

w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
)

if self.config.mlp_bias:
w0_bias = w0_bias[:, None, None, :]
layer_w0 = layer_w0 + w0_bias
@@ -1572,14 +1658,23 @@ def dense_matmul(
)
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
with jax.named_scope("wi_1"):
w1_kernel_axes = ("exp", None, "mlp")
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
)
if self.config.te_grouped_gemm:
layer_w1 = self.run_te_grouped_gemm(dispatch, w1_kernel,
in_specs=(dispatch_pspec, w0w1_kernel_psepc,),
out_specs=layer_w0w1_psepc,
w_fsdp_axis="fsdp", w_fsdp_dim=1)
else:
# TODO(ranran): check if None should be replaced by "embed_no_exp" for performance.
w1_kernel_axes = ("exp", None, "mlp")
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
)

if self.config.mlp_bias:
w1_bias = w1_bias[:, None, None, :]
layer_w1 = layer_w1 + w1_bias

if self.config.activations_in_float32:
layer_w1 = layer_w1.astype(jnp.float32)
layer_w1 = nn.with_logical_constraint(
@@ -1589,17 +1684,26 @@ def dense_matmul(
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
with jax.named_scope("wo"):
wo_kernel_axes = ("exp", "mlp", None)
wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
mlp_down_einsum,
layer_multiply,
wo_kernel,
precision=matmul_precision,
)
if self.config.te_grouped_gemm:
intermediate_layer = self.run_te_grouped_gemm(layer_multiply, wo_kernel,
in_specs=(layer_multiply_pspec, wo_kernel_psepc,),
out_specs=intermediate_layer_psepc,
w_fsdp_axis="fsdp", w_fsdp_dim=2)
else:
# TODO(ranran): check if None should be replaced by "embed_no_exp" for performance.
wo_kernel_axes = ("exp", "mlp", None)
wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
mlp_down_einsum,
layer_multiply,
wo_kernel,
precision=matmul_precision,
)

if self.config.mlp_bias:
wo_bias = wo_bias[:, None, None, :]
intermediate_layer = intermediate_layer + wo_bias

if self.config.activations_in_float32:
intermediate_layer = intermediate_layer.astype(jnp.float32)
if self.config.model_call_mode != "inference":