diff --git a/xllm/core/framework/model_context.h b/xllm/core/framework/model_context.h index 0da1b7008..4e6fbf29b 100644 --- a/xllm/core/framework/model_context.h +++ b/xllm/core/framework/model_context.h @@ -55,6 +55,15 @@ class ModelContext { return tensor_options_; } + void set_layer_id(int32_t layer_id) { layer_id_ = layer_id; } + const int32_t layer_id() const { + if (layer_id_ == -1) { + LOG(ERROR) << "layer_id is not set in ModelContext, layer_id_ = " + << layer_id_; + } + return layer_id_; + } + #if defined(USE_NPU) const atb::Context* get_atb_context() const { return context_; } #endif @@ -64,6 +73,7 @@ class ModelContext { } private: + int32_t layer_id_ = -1; ModelArgs model_args_; QuantArgs quant_args_; ParallelArgs parallel_args_; diff --git a/xllm/core/kernels/cuda/batch_decode.cpp b/xllm/core/kernels/cuda/batch_decode.cpp index e6aafce83..6f3d219ca 100644 --- a/xllm/core/kernels/cuda/batch_decode.cpp +++ b/xllm/core/kernels/cuda/batch_decode.cpp @@ -18,7 +18,9 @@ limitations under the License. namespace xllm::kernel::cuda { -void batch_decode(torch::Tensor float_workspace_buffer, +void batch_decode(const std::string& uri, + torch::Tensor plan_info, + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, torch::Tensor query, @@ -32,41 +34,6 @@ void batch_decode(torch::Tensor float_workspace_buffer, torch::Tensor output, std::optional& output_lse, bool enable_cuda_graph) { - std::string uri = get_batch_decode_uri(query.scalar_type(), - k_cache.scalar_type(), - output.scalar_type(), - paged_kv_indptr.scalar_type(), - query.size(-1), - v_cache.size(-1), - /*pos_encoding_mode=*/0, - /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/false); - - torch::Tensor paged_kv_indptr_host = paged_kv_indptr.to(torch::kCPU); - const int64_t batch_size = paged_kv_last_page_len.size(0); - - torch::Tensor empty_q_data = - torch::empty({0}, torch::TensorOptions().dtype(query.scalar_type())); - torch::Tensor empty_kv_data = - torch::empty({0}, torch::TensorOptions().dtype(k_cache.scalar_type())); - - auto plan_info = FunctionFactory::get_instance().decode_plan_func(uri).call( - float_workspace_buffer, - int_workspace_buffer, - page_locked_int_workspace_buffer, - paged_kv_indptr_host, - batch_size, - query.size(1), // num_qo_heads - k_cache.size(2), // num_kv_heads - k_cache.size(1), // block_size - enable_cuda_graph, - window_left, - /*logits_soft_cap=*/0.0, - query.size(-1), // head_dim_qk - v_cache.size(-1), // head_dim_vo - empty_q_data, - empty_kv_data); - FunctionFactory::get_instance().decode_run_func(uri).call( float_workspace_buffer, int_workspace_buffer, diff --git a/xllm/core/kernels/cuda/batch_prefill.cpp b/xllm/core/kernels/cuda/batch_prefill.cpp index edf568d59..62d81694e 100644 --- a/xllm/core/kernels/cuda/batch_prefill.cpp +++ b/xllm/core/kernels/cuda/batch_prefill.cpp @@ -18,7 +18,9 @@ limitations under the License. 
namespace xllm::kernel::cuda { -void batch_prefill(torch::Tensor float_workspace_buffer, +void batch_prefill(const std::string& uri, + torch::Tensor plan_info, + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, torch::Tensor query, @@ -35,43 +37,6 @@ void batch_prefill(torch::Tensor float_workspace_buffer, determine_attention_backend(/*pos_encoding_mode=*/0, /*use_fp16_qk_reduction=*/false, /*use_custom_mask=*/false); - - std::string uri = get_batch_prefill_uri(backend, - query.scalar_type(), - key.scalar_type(), - output.scalar_type(), - q_cu_seq_lens.scalar_type(), - query.size(-1), - value.size(-1), - /*pos_encoding_mode=*/0, - /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/false, - /*use_fp16_qk_reduction=*/false); - - torch::Tensor qo_indptr_host = q_cu_seq_lens.to(torch::kCPU); - torch::Tensor kv_cu_seq_lens_host = kv_cu_seq_lens.to(torch::kCPU); - torch::Tensor kv_len_arr_host = - kv_cu_seq_lens_host.slice(0, 1) - kv_cu_seq_lens_host.slice(0, 0, -1); - const int64_t total_num_rows = qo_indptr_host[-1].item(); - const int64_t batch_size = qo_indptr_host.size(0) - 1; - - auto plan_info = FunctionFactory::get_instance().prefill_plan_func(uri).call( - float_workspace_buffer, - int_workspace_buffer, - page_locked_int_workspace_buffer, - qo_indptr_host, - kv_cu_seq_lens_host, - kv_len_arr_host, - total_num_rows, - batch_size, - query.size(1), // num_qo_heads - key.size(1), // num_kv_heads - /*page_size=*/1, - enable_cuda_graph, - query.size(-1), // head_dim_qk - value.size(-1), // head_dim_vo - /*causal=*/true); - if (backend == "fa2") { FunctionFactory::get_instance().fa2_prefill_ragged_run_func(uri).call( float_workspace_buffer, diff --git a/xllm/core/kernels/cuda/cuda_ops_api.h b/xllm/core/kernels/cuda/cuda_ops_api.h index 7125fa518..137115200 100644 --- a/xllm/core/kernels/cuda/cuda_ops_api.h +++ b/xllm/core/kernels/cuda/cuda_ops_api.h @@ -45,7 +45,9 @@ void reshape_paged_cache( torch::Tensor key_cache, // [n_blocks, block_size, n_heads, head_dim] torch::Tensor value_cache); -void batch_prefill(torch::Tensor float_workspace_buffer, +void batch_prefill(const std::string& uri, + torch::Tensor plan_info, + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, torch::Tensor query, @@ -59,7 +61,9 @@ void batch_prefill(torch::Tensor float_workspace_buffer, std::optional& output_lse, bool enable_cuda_graph); -void batch_decode(torch::Tensor float_workspace_buffer, +void batch_decode(const std::string& uri, + torch::Tensor plan_info, + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, torch::Tensor query, diff --git a/xllm/core/kernels/cuda/utils.h b/xllm/core/kernels/cuda/utils.h index 5eed28051..8365028f4 100644 --- a/xllm/core/kernels/cuda/utils.h +++ b/xllm/core/kernels/cuda/utils.h @@ -15,6 +15,7 @@ limitations under the License. 
#pragma once +#include #include #include @@ -60,4 +61,4 @@ std::string get_batch_decode_uri(torch::ScalarType dtype_q, bool use_sliding_window, bool use_logits_soft_cap); -} // namespace xllm::kernel::cuda \ No newline at end of file +} // namespace xllm::kernel::cuda diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 9dfbf887f..a360ecb9e 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -153,7 +153,9 @@ void batch_prefill(AttentionParams& params) { params.scale, params.output); #elif defined(USE_CUDA) - cuda::batch_prefill(params.float_workspace_buffer, + cuda::batch_prefill(params.uri, + params.plan_info, + params.float_workspace_buffer, params.int_workspace_buffer, params.page_locked_int_workspace_buffer, params.query, @@ -225,7 +227,9 @@ void batch_decode(AttentionParams& params) { #elif defined(USE_CUDA) params.query = params.query.squeeze(1); params.output = params.output.squeeze(1); - cuda::batch_decode(params.float_workspace_buffer, + cuda::batch_decode(params.uri, + params.plan_info, + params.float_workspace_buffer, params.int_workspace_buffer, params.page_locked_int_workspace_buffer, params.query, diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h index 8acc137a1..13f199aa6 100644 --- a/xllm/core/kernels/param.h +++ b/xllm/core/kernels/param.h @@ -208,6 +208,8 @@ struct AttentionParams { torch::Tensor page_locked_int_workspace_buffer; bool enable_cuda_graph = false; + std::string uri; + torch::Tensor plan_info; // ========== Prefill-specific parameters ========== // Key tensor. Shape: [num_tokens, num_kv_heads, head_dim_qk] (packed) or diff --git a/xllm/core/layers/common/attention_metadata.cpp b/xllm/core/layers/common/attention_metadata.cpp index 95111d56b..d2d8254da 100644 --- a/xllm/core/layers/common/attention_metadata.cpp +++ b/xllm/core/layers/common/attention_metadata.cpp @@ -16,6 +16,9 @@ limitations under the License. #include "attention_metadata.h" #include "core/common/global_flags.h" +#include "core/layers/cuda/flashinfer_workspace.h" +#include "kernels/cuda/function_factory.h" +#include "kernels/cuda/utils.h" namespace xllm { namespace layer { @@ -53,4 +56,4 @@ AttentionMetadata AttentionMetadata::build(const ModelInputParams& params, } } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/common/attention_metadata.h b/xllm/core/layers/common/attention_metadata.h index fed4a785c..a226900da 100644 --- a/xllm/core/layers/common/attention_metadata.h +++ b/xllm/core/layers/common/attention_metadata.h @@ -53,4 +53,4 @@ struct AttentionMetadata { }; } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/common/qwen2_attention.cpp b/xllm/core/layers/common/qwen2_attention.cpp index 9ec0a5fad..d638ec361 100644 --- a/xllm/core/layers/common/qwen2_attention.cpp +++ b/xllm/core/layers/common/qwen2_attention.cpp @@ -103,7 +103,8 @@ Qwen2AttentionImpl::Qwen2AttentionImpl(const ModelContext& context) { // 5. 
Attention attn_ = register_module("attn", - Attention(num_heads_, + Attention(context.layer_id(), + num_heads_, head_dim_, scaling_, num_kv_heads_, diff --git a/xllm/core/layers/common/qwen2_decoder_layer.h b/xllm/core/layers/common/qwen2_decoder_layer.h index e66f8f681..8dc71b585 100644 --- a/xllm/core/layers/common/qwen2_decoder_layer.h +++ b/xllm/core/layers/common/qwen2_decoder_layer.h @@ -64,7 +64,5 @@ class Qwen2DecoderLayerImpl : public torch::nn::Module { ParallelArgs parallel_args_; }; -using Qwen3DecoderLayerImpl = Qwen2DecoderLayerImpl; - } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/qwen3_moe_decoder_layer.cpp b/xllm/core/layers/common/qwen3_moe_decoder_layer.cpp index 8f69a9318..2605d8c5d 100644 --- a/xllm/core/layers/common/qwen3_moe_decoder_layer.cpp +++ b/xllm/core/layers/common/qwen3_moe_decoder_layer.cpp @@ -20,8 +20,8 @@ limitations under the License. namespace xllm { namespace layer { -Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context, - int32_t layer_id) { +Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl( + const ModelContext& context) { const auto& model_args = context.get_model_args(); const auto& quant_args = context.get_quant_args(); const auto& parallel_args = context.get_parallel_args(); @@ -41,10 +41,11 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context, // Initialize mlp auto mlp_only_layers = model_args.mlp_only_layers(); - if ((std::count(mlp_only_layers.begin(), mlp_only_layers.end(), layer_id) == - 0) && + if ((std::count(mlp_only_layers.begin(), + mlp_only_layers.end(), + context.layer_id()) == 0) && model_args.num_experts() > 0 && - (layer_id + 1) % model_args.decoder_sparse_step() == 0) { + (context.layer_id() + 1) % model_args.decoder_sparse_step() == 0) { moe_mlp_ = register_module("mlp", FusedMoE(model_args.num_experts(), model_args.num_experts_per_tok(), diff --git a/xllm/core/layers/common/qwen3_moe_decoder_layer.h b/xllm/core/layers/common/qwen3_moe_decoder_layer.h index 0423df7e2..a96fa5ea0 100644 --- a/xllm/core/layers/common/qwen3_moe_decoder_layer.h +++ b/xllm/core/layers/common/qwen3_moe_decoder_layer.h @@ -34,8 +34,7 @@ namespace layer { class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { public: - explicit Qwen3MoeDecoderLayerImpl(const ModelContext& context, - int32_t layer_id); + explicit Qwen3MoeDecoderLayerImpl(const ModelContext& context); ~Qwen3MoeDecoderLayerImpl() {}; diff --git a/xllm/core/layers/config.h b/xllm/core/layers/config.h index 6e00883c7..89333331a 100644 --- a/xllm/core/layers/config.h +++ b/xllm/core/layers/config.h @@ -89,6 +89,11 @@ UNIFY_CLASS_NAME(Qwen2_5_VisionLayerImpl, Qwen2dot5VisionEncoderLayerImpl) #include "npu/npu_qwen3_decoder_layer_impl.h" #else #include "common/qwen2_decoder_layer.h" +namespace xllm { +namespace layer { +using Qwen3DecoderLayerImpl = Qwen2DecoderLayerImpl; +} +} // namespace xllm #endif #if defined(USE_NPU) diff --git a/xllm/core/layers/cuda/attention.cpp b/xllm/core/layers/cuda/attention.cpp index 98e3898b3..c221715c0 100644 --- a/xllm/core/layers/cuda/attention.cpp +++ b/xllm/core/layers/cuda/attention.cpp @@ -16,21 +16,29 @@ limitations under the License. 
#include "attention.h" #include "flashinfer_workspace.h" +#include "kernels/cuda/function_factory.h" +#include "kernels/cuda/utils.h" #include "kernels/ops_api.h" DECLARE_bool(enable_chunked_prefill); namespace xllm { namespace layer { -AttentionImpl::AttentionImpl(int num_heads, + +AttentionImpl::AttentionImpl(int layer_id, + int num_heads, int head_size, float scale, int num_kv_heads, int sliding_window) - : num_heads_(num_heads), + : layer_id_(layer_id), + num_heads_(num_heads), head_size_(head_size), scale_(scale), num_kv_heads_(num_kv_heads), - sliding_window_(sliding_window - 1) {} + sliding_window_(sliding_window - 1) { + CHECK(layer_id >= 0) << "layer_id passed to attention is invalid, layer_id = " + << layer_id; +} std::tuple> AttentionImpl::forward( const AttentionMetadata& attn_metadata, @@ -53,6 +61,23 @@ std::tuple> AttentionImpl::forward( torch::Tensor k_cache = kv_cache.get_k_cache(); torch::Tensor v_cache = kv_cache.get_v_cache(); + // maybe we need to update shared attn state before execute attention, + // currently we update flashinfer step_wise_attn_state_ at layer 0. + step_wise_attn_state_.update( + layer_id_, + attn_metadata, + query.scalar_type(), + key.scalar_type(), + output.scalar_type(), + head_size_, + head_size_, + num_heads_, + num_kv_heads_, + /*block_size*/ k_cache.size(1), + /*window_size_left*/ sliding_window_, + /*enable_cuda_graph*/ false, + /*causal*/ attn_metadata.is_prefill || attn_metadata.is_chunked_prefill); + xllm::kernel::ReshapePagedCacheParams reshape_paged_cache_params; reshape_paged_cache_params.key = key; reshape_paged_cache_params.value = value; @@ -79,6 +104,8 @@ std::tuple> AttentionImpl::forward( .get_page_locked_int_workspace_buffer(); attention_params.kv_cu_seq_lens = attn_metadata.kv_cu_seq_lens; attention_params.q_cu_seq_lens = attn_metadata.q_cu_seq_lens; + attention_params.uri = step_wise_attn_state_.uri; + attention_params.plan_info = step_wise_attn_state_.plan_info; // TODO: support chunked prefill CHECK(!attn_metadata.is_chunked_prefill) @@ -109,5 +136,114 @@ std::tuple> AttentionImpl::forward( return {output, output_lse}; } +void StepwiseAttentionState::update(int layer_id, + const AttentionMetadata& attn_meta, + c10::ScalarType query_dtype, + c10::ScalarType key_dtype, + c10::ScalarType output_dtype, + int head_dim_qk, + int head_dim_vo, + int num_qo_heads, + int num_kv_heads, + int block_size, + int window_size_left, + bool enable_cuda_graph, + bool causal) { + CHECK(layer_id != -1) << "Need to set layer_id to ModelContext or Attention."; + + // for flashinfer + // TODO: check if not flashinfer backend, we return. 
+ if (layer_id != 0) return; + + // prepare the flashinfer plan_info and uri once per step, during the first layer's forward + if (causal) { + std::string backend = kernel::cuda::determine_attention_backend( + /*pos_encoding_mode=*/0, + /*use_fp16_qk_reduction=*/false, + /*use_custom_mask=*/false); + uri = kernel::cuda::get_batch_prefill_uri( + backend, + query_dtype, + key_dtype, + output_dtype, + attn_meta.q_cu_seq_lens.scalar_type(), + head_dim_qk, + head_dim_vo, + /*pos_encoding_mode=*/0, + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, + /*use_fp16_qk_reduction=*/false); + + torch::Tensor qo_indptr_host = attn_meta.q_cu_seq_lens.to(torch::kCPU); + torch::Tensor kv_cu_seq_lens_host = + attn_meta.kv_cu_seq_lens.to(torch::kCPU); + torch::Tensor kv_len_arr_host = + kv_cu_seq_lens_host.slice(0, 1) - kv_cu_seq_lens_host.slice(0, 0, -1); + const int64_t total_num_rows = qo_indptr_host[-1].item<int64_t>(); + const int64_t batch_size = qo_indptr_host.size(0) - 1; + plan_info = + kernel::cuda::FunctionFactory::get_instance() + .prefill_plan_func(uri) + .call( + FlashinferWorkspace::get_instance() + .get_float_workspace_buffer(), + FlashinferWorkspace::get_instance().get_int_workspace_buffer(), + FlashinferWorkspace::get_instance() + .get_page_locked_int_workspace_buffer(), + qo_indptr_host, + kv_cu_seq_lens_host, + kv_len_arr_host, + total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + /*page_size=*/1, + enable_cuda_graph, + head_dim_qk, + head_dim_vo, + causal); + } else { + uri = kernel::cuda::get_batch_decode_uri( + query_dtype, + key_dtype, + output_dtype, + attn_meta.paged_kv_indptr.scalar_type(), + head_dim_qk, + head_dim_vo, + /*pos_encoding_mode=*/0, + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false); + + torch::Tensor paged_kv_indptr_host = + attn_meta.paged_kv_indptr.to(torch::kCPU); + const int64_t batch_size = attn_meta.paged_kv_last_page_len.size(0); + torch::Tensor empty_q_data = + torch::empty({0}, torch::TensorOptions().dtype(query_dtype)); + torch::Tensor empty_kv_data = + torch::empty({0}, torch::TensorOptions().dtype(key_dtype)); + plan_info = + kernel::cuda::FunctionFactory::get_instance() + .decode_plan_func(uri) + .call( + FlashinferWorkspace::get_instance() + .get_float_workspace_buffer(), + FlashinferWorkspace::get_instance().get_int_workspace_buffer(), + FlashinferWorkspace::get_instance() + .get_page_locked_int_workspace_buffer(), + paged_kv_indptr_host, + batch_size, + num_qo_heads, + num_kv_heads, + block_size, + enable_cuda_graph, + window_size_left, + /*logits_soft_cap=*/0.0, + head_dim_qk, + head_dim_vo, + empty_q_data, + empty_kv_data); + } +} + } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/cuda/attention.h b/xllm/core/layers/cuda/attention.h index 21a50029f..17fc6d14d 100644 --- a/xllm/core/layers/cuda/attention.h +++ b/xllm/core/layers/cuda/attention.h @@ -25,11 +25,34 @@ limitations under the License.
namespace xllm { namespace layer { + +// for flashinfer, maybe we need to refactor later +struct StepwiseAttentionState { + int layer_id = -1; + torch::Tensor plan_info; + std::string uri; + + void update(int layer_id, + const AttentionMetadata& attn_meta, + c10::ScalarType query_dtype, + c10::ScalarType key_dtype, + c10::ScalarType output_dtype, + int head_dim_qk, + int head_dim_vo, + int num_qo_heads, + int num_kv_heads, + int block_size, + int window_size_left, + bool enable_cuda_graph, + bool causal); +}; + class AttentionImpl : public torch::nn::Module { public: AttentionImpl() = default; - AttentionImpl(int num_heads, + AttentionImpl(int layer_id, + int num_heads, int head_size, float scale, int num_kv_heads, @@ -43,13 +66,17 @@ class AttentionImpl : public torch::nn::Module { KVCache& kv_cache); private: + int layer_id_; int num_heads_; int head_size_; float scale_; int num_kv_heads_; int sliding_window_; + + private: + inline static StepwiseAttentionState step_wise_attn_state_; }; TORCH_MODULE(Attention); } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/deepseek_v2_decoder_layer.h b/xllm/core/layers/deepseek_v2_decoder_layer.h index 39370d90e..6edff6613 100644 --- a/xllm/core/layers/deepseek_v2_decoder_layer.h +++ b/xllm/core/layers/deepseek_v2_decoder_layer.h @@ -28,9 +28,8 @@ class DeepseekV2DecoderLayer using torch::nn::ModuleHolder::ModuleHolder; using Impl __attribute__((__unused__)) = DeepseekV2DecoderLayerImpl; - DeepseekV2DecoderLayer(const ModelContext& context, const int32_t layer_id) - : ModuleHolder( - std::make_shared(context, layer_id)) {} + DeepseekV2DecoderLayer(const ModelContext& context) + : ModuleHolder(std::make_shared(context)) {} }; } // namespace layer diff --git a/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.cpp index da4fa2ea2..5c47426ce 100644 --- a/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.cpp @@ -19,8 +19,7 @@ namespace xllm { namespace layer { DeepseekV2DecoderLayerImpl::DeepseekV2DecoderLayerImpl( - const ModelContext& context, - int32_t layer_id) + const ModelContext& context) : parallel_args_(context.get_parallel_args()) { const auto& model_args = context.get_model_args(); const auto& quant_args = context.get_quant_args(); @@ -46,7 +45,7 @@ DeepseekV2DecoderLayerImpl::DeepseekV2DecoderLayerImpl( // Initialize mlp auto first_k_dense_replace = model_args.first_k_dense_replace(); - if (layer_id >= first_k_dense_replace) { + if (context.layer_id() >= first_k_dense_replace) { moe_mlp_ = register_module("mlp", FusedMoE(model_args.n_routed_experts(), model_args.num_experts_per_tok(), diff --git a/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.h index 08fd82116..e64c94d79 100644 --- a/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.h +++ b/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.h @@ -36,8 +36,7 @@ namespace layer { class DeepseekV2DecoderLayerImpl : public torch::nn::Module { public: - explicit DeepseekV2DecoderLayerImpl(const ModelContext& context, - int32_t layer_id); + explicit DeepseekV2DecoderLayerImpl(const ModelContext& context); ~DeepseekV2DecoderLayerImpl() {}; diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp index 8b54de329..9c500685e 100644 --- 
a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp @@ -130,11 +130,10 @@ enum DecoderLayerTensorId : int { static const uint64_t WEIGHT_COUNT_PER_LAYER = 84; DeepseekV2DecoderLayerImpl::DeepseekV2DecoderLayerImpl( - const ModelContext& context, - const int32_t layer_id) + const ModelContext& context) : BaseLayer(context), device_id_(context.get_tensor_options().device().index()), - layer_id_(layer_id), + layer_id_(context.layer_id()), num_speculative_tokens_( context.get_model_args().num_speculative_tokens()) { // compute sm_scale diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h index 012ef9c00..3260e2a45 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h @@ -106,8 +106,7 @@ class ExpertBuffer { class DeepseekV2DecoderLayerImpl : public BaseLayer { public: - explicit DeepseekV2DecoderLayerImpl(const ModelContext& context, - const int32_t layer_id); + explicit DeepseekV2DecoderLayerImpl(const ModelContext& context); ~DeepseekV2DecoderLayerImpl() {}; diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp index 919209541..7c0690013 100644 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp @@ -26,11 +26,10 @@ namespace layer { static const uint64_t WEIGHT_COUNT_PER_LAYER = 55; -Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context, - const int32_t layer_id) +Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context) : BaseLayer(context), device_id_(context.get_tensor_options().device().index()), - layer_id_(layer_id), + layer_id_(context.layer_id()), num_speculative_tokens_( context.get_model_args().num_speculative_tokens()) { auto model_args = context.get_model_args(); diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h index e73bad5fc..3b445c824 100644 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h @@ -39,8 +39,7 @@ namespace layer { class Qwen3MoeDecoderLayerImpl : public BaseLayer { public: - explicit Qwen3MoeDecoderLayerImpl(const ModelContext& context, - const int32_t layer_id); + explicit Qwen3MoeDecoderLayerImpl(const ModelContext& context); ~Qwen3MoeDecoderLayerImpl() {}; diff --git a/xllm/core/layers/qwen3_moe_decoder_layer.h b/xllm/core/layers/qwen3_moe_decoder_layer.h index 84dd9ac53..87649511c 100644 --- a/xllm/core/layers/qwen3_moe_decoder_layer.h +++ b/xllm/core/layers/qwen3_moe_decoder_layer.h @@ -26,9 +26,9 @@ class Qwen3MoeDecoderLayer using torch::nn::ModuleHolder::ModuleHolder; using Impl __attribute__((__unused__)) = Qwen3MoeDecoderLayerImpl; - Qwen3MoeDecoderLayer(const ModelContext& context, int32_t layer_id) + Qwen3MoeDecoderLayer(const ModelContext& context) : Qwen3MoeDecoderLayer( - std::make_shared(context, layer_id)) {} + std::make_shared(context)) {} }; } // namespace layer diff --git a/xllm/models/dit/autoencoder_kl.h b/xllm/models/dit/autoencoder_kl.h index e1e948b0e..48ff63e3a 100644 --- a/xllm/models/dit/autoencoder_kl.h +++ b/xllm/models/dit/autoencoder_kl.h @@ -704,7 +704,9 @@ class DownEncoderBlock2DImpl : public torch::nn::Module { // initialize resnet blocks for (int64_t i 
= 0; i < num_layers; ++i) { const int64_t current_in_channels = (i == 0) ? in_channels : out_channels; - auto block = ResnetBlock2D(context, + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = ResnetBlock2D(curr_context, current_in_channels, // in channels out_channels, "default"); diff --git a/xllm/models/dit/clip_text_model.h b/xllm/models/dit/clip_text_model.h index 4dcaa603e..29c89855e 100644 --- a/xllm/models/dit/clip_text_model.h +++ b/xllm/models/dit/clip_text_model.h @@ -496,7 +496,9 @@ class CLIPEncoderImpl : public torch::nn::Module { blocks_ = register_module("layers", torch::nn::ModuleList()); layers_.reserve(model_args.mm_num_hidden_layers()); for (int32_t i = 0; i < model_args.mm_num_hidden_layers(); i++) { - auto block = CLIPEncoderLayer(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = CLIPEncoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/dit/dit.h b/xllm/models/dit/dit.h index 570a4ec4f..245fbea05 100644 --- a/xllm/models/dit/dit.h +++ b/xllm/models/dit/dit.h @@ -1242,14 +1242,18 @@ class FluxTransformer2DModelImpl : public torch::nn::Module { // mm-dit block transformer_block_layers_.reserve(num_layers); for (int64_t i = 0; i < num_layers; ++i) { - auto block = FluxTransformerBlock(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = FluxTransformerBlock(curr_context); transformer_blocks_->push_back(block); transformer_block_layers_.push_back(block); } // single mm-dit block single_transformer_block_layers_.reserve(num_single_layers); for (int64_t i = 0; i < num_single_layers; ++i) { - auto block = FluxSingleTransformerBlock(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = FluxSingleTransformerBlock(curr_context); single_transformer_blocks_->push_back(block); single_transformer_block_layers_.push_back(block); } diff --git a/xllm/models/llm/deepseek_mtp.h b/xllm/models/llm/deepseek_mtp.h index 1e6a2ef7c..c57a50f25 100644 --- a/xllm/models/llm/deepseek_mtp.h +++ b/xllm/models/llm/deepseek_mtp.h @@ -31,8 +31,7 @@ namespace xllm { class DeepseekMultiTokenPredictorLayerImpl : public torch::nn::Module { public: - DeepseekMultiTokenPredictorLayerImpl(const ModelContext& context, - const int32_t layer_index) { + DeepseekMultiTokenPredictorLayerImpl(const ModelContext& context) { auto options = context.get_tensor_options(); auto model_args = context.get_model_args(); auto parallel_args = context.get_parallel_args(); @@ -48,8 +47,8 @@ class DeepseekMultiTokenPredictorLayerImpl : public torch::nn::Module { /*bias=*/false, /*QuantArgs=*/QuantArgs(), options)); - mtp_block_ = register_module( - "mtp_block", layer::DeepseekV2DecoderLayer(context, layer_index)); + mtp_block_ = + register_module("mtp_block", layer::DeepseekV2DecoderLayer(context)); } torch::Tensor forward(torch::Tensor embed, @@ -120,7 +119,9 @@ class DeepseekMTPModelImpl : public torch::nn::Module { // create mtp layers for (int32_t i = mtp_start_layer_idx_; i < mtp_end_layer_idx_; ++i) { - auto mtp_layer = DeepseekMultiTokenPredictorLayer(context, i); + auto curr_context = context; + curr_context.set_layer_id(i); + auto mtp_layer = DeepseekMultiTokenPredictorLayer(curr_context); mtp_layers_.push_back(mtp_layer); blocks_->push_back(mtp_layer); } diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h index 3fec6d6f2..53dfb5026 100644 --- a/xllm/models/llm/deepseek_v2.h +++ b/xllm/models/llm/deepseek_v2.h @@ 
-30,11 +30,10 @@ namespace xllm { class DeepseekV2DecoderLayerImpl : public torch::nn::Module { public: - DeepseekV2DecoderLayerImpl(const ModelContext& context, - const int32_t layer_index) { + DeepseekV2DecoderLayerImpl(const ModelContext& context) { // register submodules - decoder_layer_ = register_module( - "decoder_layer", layer::DeepseekV2DecoderLayer(context, layer_index)); + decoder_layer_ = register_module("decoder_layer", + layer::DeepseekV2DecoderLayer(context)); } torch::Tensor forward(torch::Tensor& x, @@ -85,7 +84,9 @@ class DeepseekV2ModelImpl : public torch::nn::Module { // create decoder layers for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = DeepseekV2DecoderLayer(context, i); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = DeepseekV2DecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/deepseek_v2.h b/xllm/models/llm/npu/deepseek_v2.h index cd08cdd1b..d276b46ac 100644 --- a/xllm/models/llm/npu/deepseek_v2.h +++ b/xllm/models/llm/npu/deepseek_v2.h @@ -46,10 +46,10 @@ using ISlice = torch::indexing::Slice; class DeepseekV2DecoderLayerImpl : public torch::nn::Module { public: - DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) { + DeepseekV2DecoderLayerImpl(const ModelContext& context) { // register submodules decoder_layer_ = register_module("decoder_layer", - layer::DeepseekV2DecoderLayer(context, i)); + layer::DeepseekV2DecoderLayer(context)); } torch::Tensor forward(torch::Tensor& x, @@ -121,7 +121,9 @@ class DeepseekV2ModelImpl : public torch::nn::Module { /*mask_value=*/mask_value); for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = DeepseekV2DecoderLayer(context, i); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = DeepseekV2DecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/deepseek_v2_mtp.h b/xllm/models/llm/npu/deepseek_v2_mtp.h index 8a7f3da56..f8ea8f62c 100644 --- a/xllm/models/llm/npu/deepseek_v2_mtp.h +++ b/xllm/models/llm/npu/deepseek_v2_mtp.h @@ -66,7 +66,9 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module { atb_pos_emb_ = layer::PosEmbedding(context); for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = DeepseekV2DecoderLayer(context, i); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = DeepseekV2DecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/glm4.h b/xllm/models/llm/npu/glm4.h index 6e12427d7..b683a6c2b 100644 --- a/xllm/models/llm/npu/glm4.h +++ b/xllm/models/llm/npu/glm4.h @@ -58,7 +58,9 @@ class Glm4ModelImpl : public LlmModelImplBase { /*mask_value=*/mask_value); for (int32_t i = 0; i < model_args.n_layers(); i++) { - auto block = Glm4DecoderLayer(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = Glm4DecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/glm4_moe.h b/xllm/models/llm/npu/glm4_moe.h index 244a98f2e..940c5a6e4 100644 --- a/xllm/models/llm/npu/glm4_moe.h +++ b/xllm/models/llm/npu/glm4_moe.h @@ -101,7 +101,9 @@ class Glm4MoeModelImpl : public torch::nn::Module { options.dtype().toScalarType(), /*mask_value=*/mask_value); for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = Glm4MoeDecoderLayer(context, i); + auto curr_context = context; +
curr_context.set_layer_id(i); + auto block = Glm4MoeDecoderLayer(curr_context, i); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/glm4_moe_mtp.h b/xllm/models/llm/npu/glm4_moe_mtp.h index f744f887d..173c6dbb5 100644 --- a/xllm/models/llm/npu/glm4_moe_mtp.h +++ b/xllm/models/llm/npu/glm4_moe_mtp.h @@ -56,7 +56,9 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module { /*mask_value=*/mask_value); for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = Glm4MoeDecoderLayer(context, i); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = Glm4MoeDecoderLayer(curr_context, i); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/llama.h b/xllm/models/llm/npu/llama.h index 5af3b1d08..ea9a2a7f1 100644 --- a/xllm/models/llm/npu/llama.h +++ b/xllm/models/llm/npu/llama.h @@ -133,7 +133,9 @@ class LlamaModelImpl : public torch::nn::Module { max_seq_len_ = 0; for (int32_t i = 0; i < model_args.n_layers(); i++) { - auto block = LlamaDecoderLayer(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = LlamaDecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/qwen2.h b/xllm/models/llm/npu/qwen2.h index 1d29a5a4c..4c16cb5ab 100644 --- a/xllm/models/llm/npu/qwen2.h +++ b/xllm/models/llm/npu/qwen2.h @@ -62,7 +62,9 @@ class QWen2ModelImpl : public LlmModelImplBase { /*mask_value=*/mask_value); for (int32_t i = 0; i < model_args.n_layers(); i++) { - auto block = QWen2DecoderLayer(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = QWen2DecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/qwen3.h b/xllm/models/llm/npu/qwen3.h index 3afcc6064..820afb095 100644 --- a/xllm/models/llm/npu/qwen3.h +++ b/xllm/models/llm/npu/qwen3.h @@ -61,7 +61,9 @@ class QWen3ModelImpl : public LlmModelImplBase { /*mask_value=*/mask_value); for (int32_t i = 0; i < model_args.n_layers(); i++) { - auto block = QWen3DecoderLayer(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = QWen3DecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/npu/qwen3_moe.h b/xllm/models/llm/npu/qwen3_moe.h index 6912b0f12..ffcc6dce2 100644 --- a/xllm/models/llm/npu/qwen3_moe.h +++ b/xllm/models/llm/npu/qwen3_moe.h @@ -32,10 +32,10 @@ using ISlice = torch::indexing::Slice; class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { public: - Qwen3MoeDecoderLayerImpl(const ModelContext& context, const int32_t i) { + Qwen3MoeDecoderLayerImpl(const ModelContext& context) { // register submodules - decoder_layer_ = register_module("decoder_layer", - layer::Qwen3MoeDecoderLayer(context, i)); + decoder_layer_ = + register_module("decoder_layer", layer::Qwen3MoeDecoderLayer(context)); } torch::Tensor forward(torch::Tensor x, @@ -143,7 +143,9 @@ class Qwen3MoeModelImpl : public torch::nn::Module { mapping_data_ = parallel_args.mapping_data(); for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = Qwen3MoeDecoderLayer(context, i); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = Qwen3MoeDecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/qwen2.h b/xllm/models/llm/qwen2.h index 1172f7465..b6fb845fa 100644 --- a/xllm/models/llm/qwen2.h +++ b/xllm/models/llm/qwen2.h @@ 
-58,7 +58,9 @@ class QWen2ModelImpl : public LlmModelImplBase { options); for (int32_t i = 0; i < model_args.n_layers(); i++) { - auto block = QWen2DecoderLayer(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = QWen2DecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h index 7be92a12e..2dd2634d1 100644 --- a/xllm/models/llm/qwen3.h +++ b/xllm/models/llm/qwen3.h @@ -54,7 +54,9 @@ class QWen3ModelImpl : public LlmModelImplBase { int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1; for (int32_t i = 0; i < model_args.n_layers(); i++) { - auto block = QWen3DecoderLayer(context); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = QWen3DecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index 066ebcf51..0cbb7f270 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -31,10 +31,10 @@ using ISlice = torch::indexing::Slice; class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { public: - Qwen3MoeDecoderLayerImpl(const ModelContext& context, const int32_t i) { + Qwen3MoeDecoderLayerImpl(const ModelContext& context) { // register submodules - decoder_layer_ = register_module("decoder_layer", - layer::Qwen3MoeDecoderLayer(context, i)); + decoder_layer_ = + register_module("decoder_layer", layer::Qwen3MoeDecoderLayer(context)); } torch::Tensor forward(torch::Tensor& x, @@ -121,7 +121,9 @@ class Qwen3MoeModelImpl : public torch::nn::Module { mapping_data_ = parallel_args.mapping_data(); for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = Qwen3MoeDecoderLayer(context, i); + auto curr_context = context; + curr_context.set_layer_id(i); + auto block = Qwen3MoeDecoderLayer(curr_context); layers_.push_back(block); blocks_->push_back(block); } diff --git a/xllm/models/vlm/npu/glm4v.h b/xllm/models/vlm/npu/glm4v.h index 6ab42ec7e..6470ce854 100644 --- a/xllm/models/vlm/npu/glm4v.h +++ b/xllm/models/vlm/npu/glm4v.h @@ -598,7 +598,9 @@ class Glm4VisionTransformerImpl : public torch::nn::Module { blocks_ = register_module("blocks", torch::nn::ModuleList()); for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { - auto block = Glm4_VisionBlock(context); + auto curr_context = context; + curr_context.set_layer_id(idx); + auto block = Glm4_VisionBlock(curr_context); blocks_->push_back(block); layers_.push_back(block); } diff --git a/xllm/models/vlm/npu/qwen2_5_vl.h b/xllm/models/vlm/npu/qwen2_5_vl.h index 4c7ca3c7b..57c55866f 100644 --- a/xllm/models/vlm/npu/qwen2_5_vl.h +++ b/xllm/models/vlm/npu/qwen2_5_vl.h @@ -398,7 +398,9 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { blocks_ = register_module("blocks", torch::nn::ModuleList()); for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { - auto block = Qwen2_5_VisionBlock(context); + auto curr_context = context; + curr_context.set_layer_id(idx); + auto block = Qwen2_5_VisionBlock(curr_context); blocks_->push_back(block); layers_.push_back(block); } diff --git a/xllm/models/vlm/npu/qwen3_vl.h b/xllm/models/vlm/npu/qwen3_vl.h index 42b7bd197..a43f22be5 100644 --- a/xllm/models/vlm/npu/qwen3_vl.h +++ b/xllm/models/vlm/npu/qwen3_vl.h @@ -349,12 +349,16 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module { merger_ = register_module("merger", Qwen3_VisionPatchMerger(context)); for (int32_t idx = 0; 
idx < model_args.mm_num_hidden_layers(); idx++) { - auto block = Qwen3_VisionBlock(context); + auto curr_context = context; + curr_context.set_layer_id(idx); + auto block = Qwen3_VisionBlock(curr_context); blocks_->push_back(block); layers_.push_back(block); } for (int32_t idx = 0; idx < deepstack_visual_indexes_.size(); idx++) { - auto merger = Qwen3_VisionPatchMerger(context, true); + auto curr_context = context; + curr_context.set_layer_id(idx); + auto merger = Qwen3_VisionPatchMerger(curr_context, true); deepstack_mergers_->push_back(merger); deepstack_merger_layers_.push_back(merger); } diff --git a/xllm/models/vlm/qwen2_5_vl.h b/xllm/models/vlm/qwen2_5_vl.h index 3556b7c13..38b18069f 100644 --- a/xllm/models/vlm/qwen2_5_vl.h +++ b/xllm/models/vlm/qwen2_5_vl.h @@ -378,7 +378,9 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { blocks_ = register_module("blocks", torch::nn::ModuleList()); for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { - auto block = Qwen2_5_VisionBlock(context); + auto curr_context = context; + curr_context.set_layer_id(idx); + auto block = Qwen2_5_VisionBlock(curr_context); blocks_->push_back(block); layers_.push_back(block); } diff --git a/xllm/models/vlm/qwen3_vl.h b/xllm/models/vlm/qwen3_vl.h index b95e6aec8..d93b62649 100755 --- a/xllm/models/vlm/qwen3_vl.h +++ b/xllm/models/vlm/qwen3_vl.h @@ -341,12 +341,16 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module { merger_ = register_module("merger", Qwen3_VisionPatchMerger(context)); for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { - auto block = Qwen3_VisionBlock(context); + auto curr_context = context; + curr_context.set_layer_id(idx); + auto block = Qwen3_VisionBlock(curr_context); blocks_->push_back(block); layers_.push_back(block); } for (int32_t idx = 0; idx < deepstack_visual_indexes_.size(); idx++) { - auto merger = Qwen3_VisionPatchMerger(context, true); + auto curr_context = context; + curr_context.set_layer_id(idx); + auto merger = Qwen3_VisionPatchMerger(curr_context, true); deepstack_mergers_->push_back(merger); deepstack_merger_layers_.push_back(merger); }