1 change: 1 addition & 0 deletions xllm/core/framework/model/model_args.h
@@ -116,6 +116,7 @@ struct ModelArgs {
PROPERTY(bool, norm_topk_prob) = false;
PROPERTY(int32_t, n_group) = 0;
PROPERTY(int32_t, topk_group) = 0;
PROPERTY(std::string, scoring_func);
// deepseek v2/v3 MLA
PROPERTY(int32_t, qk_nope_head_dim) = 0;
PROPERTY(int32_t, qk_rope_head_dim) = 0;
4 changes: 4 additions & 0 deletions xllm/core/layers/common/CMakeLists.txt
@@ -18,6 +18,8 @@ cc_library(
word_embedding_impl.h
layer_utils.h
indexer.h
deepseek_v2_attention.h
deepseek_v2_decoder_layer.h
SRCS
flashinfer_workspace.cpp
deepseek_v2_attention.cpp
@@ -33,6 +35,8 @@ cc_library(
word_embedding_impl.cpp
layer_utils.cpp
indexer.cpp
deepseek_v2_attention.cpp
deepseek_v2_decoder_layer.cpp
DEPS
"-Wl,--whole-archive"
"-Wl,--no-whole-archive"
1 change: 1 addition & 0 deletions xllm/core/layers/common/deepseek_v2_attention.cpp
@@ -20,6 +20,7 @@ limitations under the License.
#include <tuple>

#include "kernels/ops_api.h"

namespace xllm {
namespace layer {

136 changes: 136 additions & 0 deletions xllm/core/layers/common/deepseek_v2_decoder_layer.cpp
@@ -0,0 +1,136 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "deepseek_v2_decoder_layer.h"

namespace xllm {
namespace layer {

DeepseekV2DecoderImpl::DeepseekV2DecoderImpl(const ModelContext& context,
int32_t layer_id)
: parallel_args_(context.get_parallel_args()) {
const auto& model_args = context.get_model_args();
const auto& quant_args = context.get_quant_args();
const auto& options = context.get_tensor_options();

// Get rank and world_size from parallel_args_
rank_ = parallel_args_.rank();
world_size_ = parallel_args_.world_size();

// Initialize attention layers
attention_ = register_module(
"self_attn",
DeepseekV2Attention(model_args, quant_args, parallel_args_, options));

// Initialize norm layers
input_norm_ = register_module(
"input_layernorm",
RmsNorm(model_args.hidden_size(), model_args.rms_norm_eps(), options));

post_norm_ = register_module(
"post_attention_layernorm",
RmsNorm(model_args.hidden_size(), model_args.rms_norm_eps(), options));

// Initialize mlp
auto first_k_dense_replace = model_args.first_k_dense_replace();
if (layer_id >= first_k_dense_replace) {
moe_mlp_ = register_module("mlp",
FusedMoE(model_args.n_routed_experts(),
model_args.num_experts_per_tok(),
model_args.n_group(),
model_args.topk_group(),
model_args.routed_scaling_factor(),
model_args.hidden_size(),
model_args.moe_intermediate_size(),
model_args.n_shared_experts(),
/*is_gated=*/true,
/*has_score_bias=*/false,
/*has_bias=*/false,
/*skip_bias_add=*/false,
model_args.norm_topk_prob(),
model_args.hidden_act(),
model_args.scoring_func(),
model_args.topk_method(),
quant_args,
parallel_args_,
options));
} else {
mlp_ = register_module("mlp",
DenseMLP(model_args.hidden_size(),
model_args.intermediate_size(),
/*is_gated=*/true,
/*has_bias=*/false,
model_args.hidden_act(),
/*enable_result_reduction=*/true,
quant_args,
parallel_args_,
options));
}
}

void DeepseekV2DecoderImpl::load_state_dict(const StateDict& state_dict) {
attention_->load_state_dict(state_dict.get_dict_with_prefix("self_attn."));
input_norm_->load_state_dict(
state_dict.get_dict_with_prefix("input_layernorm."));
post_norm_->load_state_dict(
state_dict.get_dict_with_prefix("post_attention_layernorm."));
if (moe_mlp_) {
moe_mlp_->load_state_dict(state_dict.get_dict_with_prefix("mlp."));
} else {
mlp_->load_state_dict(state_dict.get_dict_with_prefix("mlp."));
}
}

torch::Tensor DeepseekV2DecoderImpl::forward(
torch::Tensor& x,
torch::Tensor& positions,
const AttentionMetadata& attn_metadata,
KVCache& kv_cache,
const ModelInputParams& input_params) {
// Input norm
torch::Tensor residual = x;
x = input_norm_(x);

// Attention
x = attention_->forward(positions, x, attn_metadata, kv_cache);

// All-reduce across the tensor-parallel group here, to avoid an
// implicit communication inside the DeepSeek attention layer.
if (world_size_ > 1) {
x = xllm::parallel_state::reduce(x, parallel_args_.tp_group_);
}

// Add the residual before the post-attention norm
x = x + residual;

// Post-attention norm
residual = x;
x = post_norm_(x);

// MLP forward
if (moe_mlp_) {
x = moe_mlp_(x, input_params);
} else {
x = mlp_(x);
}

// Add the residual after the MLP/MoE block
x = x + residual;

return x;
}

} // namespace layer
} // namespace xllm
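Why the explicit reduce comes before the residual add in `forward`: under tensor parallelism the attention output is a rank-local partial sum, so adding `residual` before the all-reduce would count it once per rank. Below is a minimal standalone sketch of that ordering constraint; it uses plain C++ with a simulated two-rank all-reduce, and the `all_reduce` helper is illustrative only, not an xllm API.

```cpp
#include <cassert>
#include <iostream>
#include <vector>

// Simulated all-reduce: element-wise sum of every rank's local buffer.
std::vector<double> all_reduce(const std::vector<std::vector<double>>& per_rank) {
  std::vector<double> out(per_rank[0].size(), 0.0);
  for (const auto& local : per_rank)
    for (size_t i = 0; i < local.size(); ++i) out[i] += local[i];
  return out;
}

int main() {
  const int world_size = 2;
  // Rank-local partial attention outputs (unreduced row-parallel output).
  std::vector<std::vector<double>> partial = {{1.0, 2.0}, {3.0, 4.0}};
  std::vector<double> residual = {10.0, 20.0};

  // Correct order: reduce first, then add the residual exactly once.
  auto reduced = all_reduce(partial);
  std::vector<double> correct(reduced.size());
  for (size_t i = 0; i < correct.size(); ++i) correct[i] = reduced[i] + residual[i];

  // Wrong order: add the residual on every rank, then reduce.
  auto premature = partial;
  for (auto& local : premature)
    for (size_t i = 0; i < local.size(); ++i) local[i] += residual[i];
  auto wrong = all_reduce(premature);

  // The wrong order overshoots by (world_size - 1) * residual.
  for (size_t i = 0; i < correct.size(); ++i)
    assert(wrong[i] - correct[i] == (world_size - 1) * residual[i]);
  std::cout << "reduce-then-add: " << correct[0] << ", " << correct[1] << "\n";
  return 0;
}
```

This is the same reasoning behind doing `x = x + residual;` only after `xllm::parallel_state::reduce` in the decoder's forward pass above.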
65 changes: 65 additions & 0 deletions xllm/core/layers/common/deepseek_v2_decoder_layer.h
@@ -0,0 +1,65 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/torch.h>

#include "attention.h"
#include "deepseek_v2_attention.h"
#include "dense_mlp.h"
#include "framework/kv_cache/kv_cache.h"
#include "framework/model/model_args.h"
#include "framework/model/model_input_params.h"
#include "framework/model_context.h"
#include "framework/parallel_state/parallel_args.h"
#include "framework/parallel_state/parallel_state.h"
#include "framework/quant_args.h"
#include "framework/state_dict/state_dict.h"
#include "fused_moe.h"
#include "layers/rms_norm.h"

namespace xllm {
namespace layer {

class DeepseekV2DecoderImpl : public torch::nn::Module {
public:
explicit DeepseekV2DecoderImpl(const ModelContext& context, int32_t layer_id);

~DeepseekV2DecoderImpl() = default;

void load_state_dict(const StateDict& state_dict);

torch::Tensor forward(torch::Tensor& x,
torch::Tensor& positions,
const AttentionMetadata& attn_metadata,
KVCache& kv_cache,
const ModelInputParams& input_params);

private:
// parallel args
int64_t rank_;
int64_t world_size_;
ParallelArgs parallel_args_;

DeepseekV2Attention attention_{nullptr};
DenseMLP mlp_{nullptr};
FusedMoE moe_mlp_{nullptr};
RmsNorm input_norm_{nullptr};
RmsNorm post_norm_{nullptr};
};

} // namespace layer
} // namespace xllm
3 changes: 2 additions & 1 deletion xllm/core/layers/common/dense_mlp.cpp
@@ -27,6 +27,7 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
bool is_gated,
bool has_bias,
const std::string& hidden_act,
bool enable_result_reduction,
const QuantArgs& quant_args,
const ParallelArgs& parallel_args,
const torch::TensorOptions& options)
@@ -73,7 +74,7 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
hidden_size,
/*bias=*/has_bias,
/*input_is_parallelized=*/true,
/*if_reduce_results=*/true,
enable_result_reduction,
quant_args,
parallel_args,
options,
1 change: 1 addition & 0 deletions xllm/core/layers/common/dense_mlp.h
@@ -34,6 +34,7 @@ class DenseMLPImpl : public torch::nn::Module {
bool is_gated,
bool has_bias,
const std::string& hidden_act,
bool enable_result_reduction,
const QuantArgs& quant_args,
const ParallelArgs& parallel_args,
const torch::TensorOptions& options);
8 changes: 8 additions & 0 deletions xllm/core/layers/common/fused_moe.cpp
@@ -95,13 +95,21 @@ FusedMoEImpl::FusedMoEImpl(int64_t num_experts,
"gate_proj",
ReplicatedLinear(hidden_size, num_experts, false, quant_args, options));
if (n_shared_experts_ > 0) {
/*
The shared experts are usually implemented with a RowParallelLinear
layer, whose output would normally be reduced across the tensor-parallel
group. When only tensor parallelism is applied, immediately reducing the
shared_experts output isn't necessary; instead, we perform the reduction
once at the end of the MoE operation.
*/
shared_experts_ =
register_module("shared_experts",
DenseMLP(hidden_size,
intermediate_size * n_shared_experts_,
is_gated_,
false,
hidden_act_,
/*enable_result_reduction=*/false,
quant_args,
parallel_args,
options));
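The comment above defers the shared-experts reduction to the end of the MoE block, which is valid because all-reduce is linear: reducing the per-rank sum of the shared and routed partials once gives the same result as reducing each separately, while saving one collective. Here is a standalone sketch of that identity in plain C++; the `all_reduce` helper is a stand-in for the tensor-parallel reduce, not an xllm API.

```cpp
#include <cassert>
#include <iostream>
#include <vector>

// Stand-in for the tensor-parallel all-reduce: sum of rank-local buffers.
std::vector<double> all_reduce(const std::vector<std::vector<double>>& per_rank) {
  std::vector<double> out(per_rank[0].size(), 0.0);
  for (const auto& local : per_rank)
    for (size_t i = 0; i < local.size(); ++i) out[i] += local[i];
  return out;
}

int main() {
  // Unreduced per-rank partial outputs of the two sub-modules (two ranks).
  std::vector<std::vector<double>> shared = {{1.0, 2.0}, {3.0, 4.0}};
  std::vector<std::vector<double>> routed = {{0.5, 1.5}, {2.5, 3.5}};

  // Eager: one reduction per sub-module, then add the results.
  auto shared_reduced = all_reduce(shared);
  auto routed_reduced = all_reduce(routed);

  // Deferred: add the partials on each rank, reduce once at the end.
  std::vector<std::vector<double>> combined(shared.size());
  for (size_t r = 0; r < shared.size(); ++r) {
    combined[r].resize(shared[r].size());
    for (size_t i = 0; i < shared[r].size(); ++i)
      combined[r][i] = shared[r][i] + routed[r][i];
  }
  auto deferred = all_reduce(combined);

  // Linearity of the reduction: identical result, one collective saved.
  for (size_t i = 0; i < deferred.size(); ++i)
    assert(deferred[i] == shared_reduced[i] + routed_reduced[i]);
  std::cout << "deferred reduce matches: " << deferred[0] << ", " << deferred[1] << "\n";
  return 0;
}
```

This is why `DenseMLP` gains the `enable_result_reduction` flag: the shared experts pass `false` and let `FusedMoE` reduce once for the whole block.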
6 changes: 3 additions & 3 deletions xllm/core/layers/common/linear.cpp
@@ -272,13 +272,13 @@ RowParallelLinearImpl::RowParallelLinearImpl(
int64_t out_features,
bool bias,
bool input_is_parallelized,
bool if_reduce_results,
bool enable_result_reduction,
const QuantArgs& quant_args,
const ParallelArgs& parallel_args,
const torch::TensorOptions& options,
const FusedLinearExtraArgs& linear_extra_args)
: input_is_parallelized_(input_is_parallelized),
if_reduce_results_(if_reduce_results),
enable_result_reduction_(enable_result_reduction),
quant_args_(quant_args),
parallel_args_(parallel_args),
linear_extra_args_(linear_extra_args) {
@@ -377,7 +377,7 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) {
matmul_params.bias = bias;
output = xllm::kernel::matmul(matmul_params);
}
if (if_reduce_results_ && world_size_ > 1) {
if (enable_result_reduction_ && world_size_ > 1) {
output = xllm::parallel_state::reduce(output, parallel_args_.tp_group_);
}
return output;
4 changes: 2 additions & 2 deletions xllm/core/layers/common/linear.h
@@ -160,7 +160,7 @@ class RowParallelLinearImpl : public torch::nn::Module {
int64_t out_features,
bool bias,
bool input_is_parallelized,
bool if_reduce_results,
bool enable_result_reduction,
const QuantArgs& quant_args,
const ParallelArgs& parallel_args,
const torch::TensorOptions& options,
@@ -192,7 +192,7 @@
bool input_is_parallelized_;

// whether to reduce the results
bool if_reduce_results_;
bool enable_result_reduction_;

// parallel args
ParallelArgs parallel_args_;
2 changes: 1 addition & 1 deletion xllm/core/layers/common/qwen2_attention.cpp
@@ -70,7 +70,7 @@ Qwen2AttentionImpl::Qwen2AttentionImpl(const ModelContext& context) {
args.hidden_size(),
/*bias=*/false,
/*input_is_parallelized=*/true,
/*if_reduce_results=*/true,
/*enable_result_reduction=*/true,
quant_args,
parallel_args,
options));
1 change: 1 addition & 0 deletions xllm/core/layers/common/qwen2_decoder_layer.cpp
@@ -48,6 +48,7 @@ Qwen2DecoderImpl::Qwen2DecoderImpl(const ModelContext& context)
true,
false,
model_args.hidden_act(),
/*enable_result_reduction=*/true,
quant_args,
parallel_args,
options));
1 change: 1 addition & 0 deletions xllm/core/layers/common/qwen3_moe_decoder_layer.cpp
@@ -72,6 +72,7 @@ Qwen3MoeDecoderImpl::Qwen3MoeDecoderImpl(const ModelContext& context,
true,
false,
model_args.hidden_act(),
/*enable_result_reduction=*/true,
quant_args,
parallel_args,
options));
2 changes: 2 additions & 0 deletions xllm/core/layers/common/tests/CMakeLists.txt
@@ -9,11 +9,13 @@ cc_test(
fused_moe_tests.cpp
indexer_tests.cpp
mla_tests.cpp
deepseek_v2_decoder_layer_tests.cpp
tests_utils.cpp
DEPS
:common_layers
:parallel_state
:model
:model_context
:state_dict
glog::glog
torch