[Draft] Integrate TE/JAX GroupedGEMM for MoE. #2319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Changes from all commits
21119d7
28aac5b
2d5a83c
f44c417
c502141
@@ -1,4 +1,5 @@
 # Copyright 2023–2025 Google LLC
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -123,6 +124,29 @@ def random_routing(rng_key, gate_logits, num_experts_per_tok):
   return top_k_weights, top_k_indices
 
 
+# This is WAR to avoid no backward impl error from pmax.
+def pmax_no_backward(x, axis):
+  return _pmax_no_backward(x, axis)
+
+
+@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
+def _pmax_no_backward(x, axis):
+  x_max, _ = _pmax_no_backward_fwd_rule(x, axis)
+  return x_max
+
+
+def _pmax_no_backward_fwd_rule(x, axis):
+  x_max = jax.lax.pmax(x, axis)
+  return x_max, None
+
+
+def _pmax_no_backward_bwd_rule(axis, ctx, grad):
+  del axis, ctx
+  return grad,
+
+
+_pmax_no_backward.defvjp(_pmax_no_backward_fwd_rule, _pmax_no_backward_bwd_rule)
+
+
 class GateLogit(nnx.Module):
   """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing."""

Review thread on _pmax_no_backward:
  Reviewer: Wondering if you have tested the gradient? We have some examples in moe_test.py.
  Author: I have not run the unit tests, but we tested end-to-end deepseek3 model training; the loss curve matches the native XLA implementation. I will attach it to the description shortly.
@@ -1401,6 +1425,55 @@ def maybe_all_gather_kernel_weight_in_expert_parallelism(
     kernel = nn.with_logical_constraint(kernel, kernel_axes)
     return kernel
 
+  def run_te_grouped_gemm(self, inputs, kernels,
+                          *,
+                          in_specs, out_specs,
+                          w_fsdp_axis, w_fsdp_dim):
+    from transformer_engine.jax.dense import grouped_dense
+    from transformer_engine.jax.quantize import (
+        ScalingMode,
+        QuantizerFactory
+    )
+
+    @functools.partial(
+        shard_map.shard_map,
+        mesh=self.mesh,
+        in_specs=in_specs,
+        out_specs=out_specs,
+        check_rep=False,
+    )
+    def sharded_group_gemm(x, w):
+      group_size = x.shape[0]
+      x_reshaped = x.reshape(-1, x.shape[-1])
+      n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)
+
+      kernel_amax = pmax_no_backward(jnp.max(jnp.abs(w), axis=range(1, w.ndim)), w_fsdp_axis)
+      kernel_fsdp_info = (w_fsdp_axis, w_fsdp_dim)
+
+      if self.quant:
+        quantizer_set = QuantizerFactory.create_set(
+            scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2,
+            n_groups=group_size
+        )
+      else:
+        quantizer_set = None
+
+      output = grouped_dense(
+          x_reshaped, w, n_groups,
+          kernel_amax=kernel_amax,
+          quantizer_set=quantizer_set,
+          kernel_fsdp_info=kernel_fsdp_info,
+      )
+      output = output.reshape(*x.shape[:-1], -1)
+
+      return output
+
+    return sharded_group_gemm(inputs, kernels)
+
   def dense_matmul(
       self,
       inputs,
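For intuition, the grouped GEMM that run_te_grouped_gemm dispatches to should be, in the unquantized case, equivalent to one independent matmul per local expert group, with every group holding the same number of rows (which is exactly what the n_groups array above encodes). A plain-jnp reference sketch of that semantics (illustrative only, not the TE kernel):

import jax.numpy as jnp

def reference_grouped_gemm(x_reshaped, w, group_size):
  # x_reshaped: (group_size * rows_per_group, hidden); w: (group_size, hidden, mlp)
  rows_per_group = x_reshaped.shape[0] // group_size
  x_grouped = x_reshaped.reshape(group_size, rows_per_group, -1)
  # One matmul per expert group, then flatten back to (group_size * rows_per_group, mlp).
  out = jnp.einsum("grh,ghm->grm", x_grouped, w)
  return out.reshape(group_size * rows_per_group, -1)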
@@ -1554,12 +1627,25 @@ def dense_matmul(
         dispatch,
         dispatch_axis,
     )
 
+    dispatch_pspec = nn.logical_to_mesh_axes(dispatch_axis)
+    w0w1_kernel_psepc = nn.logical_to_mesh_axes(("exp", "embed_no_exp", "mlp"))
+    layer_w0w1_psepc = nn.logical_to_mesh_axes(mlp_axis)
     with jax.named_scope("wi_0"):
-      w0_kernel_axes = ("exp", None, "mlp")
-      w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
-      layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
-          mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
-      )
+      if self.config.te_grouped_gemm:
+        layer_w0 = self.run_te_grouped_gemm(dispatch, w0_kernel,
+                                            in_specs=(dispatch_pspec, w0w1_kernel_psepc,),
+                                            out_specs=layer_w0w1_psepc,
+                                            w_fsdp_axis="fsdp", w_fsdp_dim=1)
+      else:
+        # TODO(ranran): check if None should be replaced by "embed_no_exp" for performance.
+        w0_kernel_axes = ("exp", None, "mlp")
+        w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
+        layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
+            mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
+        )
 
       if self.config.mlp_bias:
         w0_bias = w0_bias[:, None, None, :]
         layer_w0 = layer_w0 + w0_bias

Review thread on the w0_kernel_axes line:
  Reviewer: I don't recall the reason we put "None" in the embedding dimension, but I think we are good to reuse it. Could you help add a comment: "TODO(ranran): check if None should be replaced by "embed_no_exp" for performance."?
  Author: Done.
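A note on the w_fsdp_axis="fsdp", w_fsdp_dim=1 arguments used here: w_fsdp_dim presumably names the kernel dimension sharded across the "fsdp" mesh axis (that is how the pair is forwarded to TE as kernel_fsdp_info), and for the w0/w1 kernels that is dim 1, the embed dimension of the ("exp", "embed_no_exp", "mlp") layout, while the wo kernel further down uses dim 2. Each device therefore holds only a slice of every expert's weights, and the amax needed for FP8 current-tensor scaling must cover the whole kernel, which is why run_te_grouped_gemm reduces the local per-expert amax with pmax_no_backward. A toy illustration of that reduction (made-up numbers, outside shard_map):

import jax.numpy as jnp

# Two FSDP shards of one expert's kernel and their local amaxes.
shard_a = jnp.array([0.5, -2.0])   # local amax = 2.0
shard_b = jnp.array([3.5, 1.0])    # local amax = 3.5
global_amax = jnp.maximum(jnp.max(jnp.abs(shard_a)), jnp.max(jnp.abs(shard_b)))  # 3.5, amax of the full kernel
# Inside shard_map this cross-shard max is jax.lax.pmax over "fsdp", wrapped by pmax_no_backward
# so the reduction does not break the backward pass.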
@@ -1572,14 +1658,23 @@ def dense_matmul(
       )
       layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
     with jax.named_scope("wi_1"):
-      w1_kernel_axes = ("exp", None, "mlp")
-      w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
-      layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
-          mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
-      )
+      if self.config.te_grouped_gemm:
+        layer_w1 = self.run_te_grouped_gemm(dispatch, w1_kernel,
+                                            in_specs=(dispatch_pspec, w0w1_kernel_psepc,),
+                                            out_specs=layer_w0w1_psepc,
+                                            w_fsdp_axis="fsdp", w_fsdp_dim=1)
+      else:
+        # TODO(ranran): check if None should be replaced by "embed_no_exp" for performance.
+        w1_kernel_axes = ("exp", None, "mlp")
+        w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
+        layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
+            mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
+        )
 
       if self.config.mlp_bias:
         w1_bias = w1_bias[:, None, None, :]
         layer_w1 = layer_w1 + w1_bias
 
       if self.config.activations_in_float32:
         layer_w1 = layer_w1.astype(jnp.float32)
       layer_w1 = nn.with_logical_constraint(
@@ -1589,17 +1684,26 @@ def dense_matmul(
       layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
     layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
     with jax.named_scope("wo"):
-      wo_kernel_axes = ("exp", "mlp", None)
-      wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
-      intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
-          mlp_down_einsum,
-          layer_multiply,
-          wo_kernel,
-          precision=matmul_precision,
-      )
+      if self.config.te_grouped_gemm:
+        intermediate_layer = self.run_te_grouped_gemm(layer_multiply, wo_kernel,
+                                                      in_specs=(layer_multiply_pspec, wo_kernel_psepc,),
+                                                      out_specs=intermediate_layer_psepc,
+                                                      w_fsdp_axis="fsdp", w_fsdp_dim=2)
+      else:
+        # TODO(ranran): check if None should be replaced by "embed_no_exp" for performance.
+        wo_kernel_axes = ("exp", "mlp", None)
+        wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
+        intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
+            mlp_down_einsum,
+            layer_multiply,
+            wo_kernel,
+            precision=matmul_precision,
+        )
 
       if self.config.mlp_bias:
         wo_bias = wo_bias[:, None, None, :]
         intermediate_layer = intermediate_layer + wo_bias
 
       if self.config.activations_in_float32:
         intermediate_layer = intermediate_layer.astype(jnp.float32)
       if self.config.model_call_mode != "inference":
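One more note on the quantized path: the CURRENT_TENSOR_SCALING mode requested from QuantizerFactory in run_te_grouped_gemm derives the FP8 scale from the amax of the tensor on the current step (rather than from a delayed-scaling history), which is why the globally reduced kernel_amax matters. A rough sketch of the idea (illustrative only; TE's internal bookkeeping differs in detail):

import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value, the fwd_dtype used above

def current_tensor_scale(amax):
  # Map the observed amax near the top of the FP8 range; guard against all-zero tensors.
  return jnp.where(amax > 0.0, E4M3_MAX / amax, 1.0)

def quantize_e4m3(x, scale):
  # Dequantize later by dividing by the same scale.
  return (x * scale).astype(jnp.float8_e4m3fn)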
Review thread on the copyright header:
  Reviewer: @kyle-meggs @ultrons - I have no idea what our copyright policies are; curious what your thoughts are.
  Reviewer: +1 on this copyright comment.