Changes from 3 commits
2 changes: 2 additions & 0 deletions src/MaxText/configs/base.yml
@@ -1,4 +1,5 @@
# Copyright 2023–2025 Google LLC
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES

Collaborator: @kyle-meggs @ultrons - I have no idea what our copyright policies are; curious what your thoughts are.

Collaborator: +1 on this copyright comment

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -164,6 +165,7 @@ num_experts: 1
num_experts_per_tok: 1
megablox: True
sparse_matmul: True
te_grouped_gemm: False
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss
use_random_routing: False # whether to use random routing for debug/test purpose
142 changes: 124 additions & 18 deletions src/MaxText/layers/moe.py
@@ -1,4 +1,5 @@
# Copyright 2023–2025 Google LLC
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -123,6 +124,29 @@ def random_routing(rng_key, gate_logits, num_experts_per_tok):
return top_k_weights, top_k_indices


# Workaround (WAR) to avoid the missing backward-implementation error raised for jax.lax.pmax.
def pmax_no_backward(x, axis):
return _pmax_no_backward(x, axis)


@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def _pmax_no_backward(x, axis):

Collaborator: Wondering if you have tested the gradient? We have some examples in moe_test.py

Author: I have not run the unit tests, but we tested end-to-end DeepSeek3 model training and the loss curve matches the native XLA implementation. Will attach it to the description shortly.

x_max, _ = _pmax_no_backward_fwd_rule(x, axis)
return x_max


def _pmax_no_backward_fwd_rule(x, axis):
x_max = jax.lax.pmax(x, axis)
return x_max, None


def _pmax_no_backward_bwd_rule(axis, ctx, grad):
del axis, ctx
return grad,

_pmax_no_backward.defvjp(_pmax_no_backward_fwd_rule, _pmax_no_backward_bwd_rule)
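
Not part of the diff: a minimal sketch of the kind of gradient check the reviewer asks about (names and shapes are illustrative, not taken from moe_test.py). It runs under a single-device jax.pmap so jax.lax.pmax sees a named axis; since the custom VJP passes the cotangent straight through, the gradient of a sum should be all ones.

import jax
import jax.numpy as jnp

def _loss(x):
    # pmax_no_backward is the helper defined above in moe.py.
    return jnp.sum(pmax_no_backward(x, "i"))

x = jnp.arange(4.0).reshape(1, 4)  # leading axis of size 1 is the axis mapped by pmap
grads = jax.pmap(jax.grad(_loss), axis_name="i")(x)
# Expected: an array of ones, because the backward rule returns the incoming cotangent unchanged.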


class GateLogit(nnx.Module):
"""A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing."""

@@ -1307,6 +1331,56 @@ def maybe_all_gather_kernel_weight_in_expert_parallelism(
kernel = nn.with_logical_constraint(kernel, kernel_axes)
return kernel

def run_te_grouped_gemm(self, inputs, kernels,
*,
in_specs, out_specs,
w_fsdp_axis, w_fsdp_dim):
from transformer_engine.jax.dense import grouped_dense
from transformer_engine.jax.quantize import (
ScalingMode,
QuantizerFactory
)


@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)
def sharded_group_gemm(x, w):

# One group per expert: flatten the per-expert activations to 2D and record an equal row count per group.
group_size = x.shape[0]
x_reshaped = x.reshape(-1, x.shape[-1])
n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)

# Per-group absolute max of the kernel, reduced across FSDP shards; pmax_no_backward keeps the backward pass a plain pass-through.
kernel_amax = pmax_no_backward(jnp.max(jnp.abs(w), axis=range(1, w.ndim)), w_fsdp_axis)
kernel_fsdp_info = (w_fsdp_axis, w_fsdp_dim)

if self.quant:
quantizer_set = QuantizerFactory.create_set(
scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=True,

Collaborator: Curious, what does this 2x2x mean?

Contributor: is_2x2x=True means outputting two copies of the quantized data: one used in the forward pass and one in the backward pass.

@mingxu1067 I think you should not specify is_2x2x here so that the program can decide based on the GPU arch, i.e. only doing 2x on Hopper: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/quantize/quantizer.py#L969

Author: Done!
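
Not part of the diff: with that suggestion applied, the quantizer construction would presumably drop the explicit flag and let TransformerEngine choose the layout per GPU arch, roughly:

quantizer_set = QuantizerFactory.create_set(
    scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
    fwd_dtype=jnp.float8_e4m3fn,
    bwd_dtype=jnp.float8_e5m2,
    n_groups=group_size,
)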

n_groups=group_size
)
else:
quantizer_set = None

output = grouped_dense(
x_reshaped, w, n_groups,
kernel_amax=kernel_amax,
quantizer_set=quantizer_set,
kernel_fsdp_info=kernel_fsdp_info,
)
output = output.reshape(*x.shape[:-1], -1)

return output

return sharded_group_gemm(inputs, kernels)

def dense_matmul(
self,
inputs,
@@ -1460,12 +1534,24 @@ def dense_matmul(
dispatch,
dispatch_axis,
)


dispatch_pspec = nn.logical_to_mesh_axes(dispatch_axis)
w0w1_kernel_pspec = nn.logical_to_mesh_axes(("exp", "embed_no_exp", "mlp"))
layer_w0w1_pspec = nn.logical_to_mesh_axes(mlp_axis)
with jax.named_scope("wi_0"):
w0_kernel_axes = ("exp", None, "mlp")
w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
)
if self.config.te_grouped_gemm:
layer_w0 = self.run_te_grouped_gemm(dispatch, w0_kernel,
in_specs=(dispatch_pspec, w0w1_kernel_pspec,),
out_specs=layer_w0w1_pspec,
w_fsdp_axis="fsdp", w_fsdp_dim=1)
else:
w0_kernel_axes = ("exp", None, "mlp")

Collaborator: I don't recall the reason we put "None" in the embedding dimension, but I think we are good to reuse it. Could you help add a comment: "TODO(ranran): check if None should be replaced by 'embed_no_exp' for performance."

Author: Done.
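
Not part of the diff: with that request applied, the added comment would presumably sit directly above the axes definition in this else branch, roughly:

# TODO(ranran): check if None should be replaced by "embed_no_exp" for performance.
w0_kernel_axes = ("exp", None, "mlp")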

w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
)

if self.config.mlp_bias:
w0_bias = w0_bias[:, None, None, :]
layer_w0 = layer_w0 + w0_bias
@@ -1478,14 +1564,22 @@
)
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
with jax.named_scope("wi_1"):
w1_kernel_axes = ("exp", None, "mlp")
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
)
if self.config.te_grouped_gemm:
layer_w1 = self.run_te_grouped_gemm(dispatch, w1_kernel,
in_specs=(dispatch_pspec, w0w1_kernel_pspec,),
out_specs=layer_w0w1_pspec,
w_fsdp_axis="fsdp", w_fsdp_dim=1)
else:
w1_kernel_axes = ("exp", None, "mlp")
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
)

if self.config.mlp_bias:
w1_bias = w1_bias[:, None, None, :]
layer_w1 = layer_w1 + w1_bias

if self.config.activations_in_float32:
layer_w1 = layer_w1.astype(jnp.float32)
layer_w1 = nn.with_logical_constraint(
@@ -1496,18 +1590,30 @@
# pylint: disable=protected-access
layer_w0_act = linears._convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype)

layer_multiply_pspec = layer_w0w1_pspec
wo_kernel_pspec = nn.logical_to_mesh_axes(("exp", "mlp", "embed_no_exp"))
intermediate_layer_pspec = nn.logical_to_mesh_axes(("activation_exp", "activation_batch_no_exp", None, "activation_embed"))
with jax.named_scope("wo"):
wo_kernel_axes = ("exp", "mlp", None)
wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
mlp_down_einsum,
layer_multiply,
wo_kernel,
precision=matmul_precision,
)
if self.config.te_grouped_gemm:
intermediate_layer = self.run_te_grouped_gemm(layer_multiply, wo_kernel,
in_specs=(layer_multiply_pspec, wo_kernel_pspec,),
out_specs=intermediate_layer_pspec,
w_fsdp_axis="fsdp", w_fsdp_dim=2)
else:
wo_kernel_axes = ("exp", "mlp", None)
wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
mlp_down_einsum,
layer_multiply,
wo_kernel,
precision=matmul_precision,
)

if self.config.mlp_bias:
wo_bias = wo_bias[:, None, None, :]
intermediate_layer = intermediate_layer + wo_bias

if self.config.activations_in_float32:
intermediate_layer = intermediate_layer.astype(jnp.float32)
if self.config.model_call_mode != "inference":