1 change: 1 addition & 0 deletions xllm/core/layers/CMakeLists.txt
@@ -44,6 +44,7 @@ cc_library(
llama_decoder_layer.h
common/multi_head_attention.h
qwen2_decoder_layer.h
qwen2_vision_encode_layer.h
qwen2dot5_vision_encode_layer.h
qwen3_vision_encode_layer.h
qwen3_decoder_layer.h
11 changes: 11 additions & 0 deletions xllm/core/layers/common/qwen2_5_vision_layer.cpp
@@ -76,6 +76,17 @@ torch::Tensor Qwen2_5_VisionLayerImpl::forward(
return output;
}

Qwen2_VisionLayerImpl::Qwen2_VisionLayerImpl(const ModelContext& context)
: Qwen2_5_VisionLayerImpl(context, true) {}

void Qwen2_VisionLayerImpl::load_state_dict(const StateDict& state_dict) {
attention_->load_state_dict(state_dict.get_dict_with_prefix("attn."));
mlp_->load_state_dict(
state_dict.get_dict_with_prefix("mlp."), {"fc1."}, "fc2.");
norm1_->load_state_dict(state_dict.get_dict_with_prefix("norm1."));
norm2_->load_state_dict(state_dict.get_dict_with_prefix("norm2."));
}

Qwen3_VisionLayerImpl::Qwen3_VisionLayerImpl(const ModelContext& context)
: Qwen2_5_VisionLayerImpl(context, true) {}

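Note: the new load_state_dict scopes the checkpoint by prefix before delegating to each submodule. A toy stand-in for the assumed semantics of StateDict::get_dict_with_prefix (the real StateDict lives elsewhere in xllm; the map-based signature here is illustrative only, not the repo's API):

#include <map>
#include <string>

// Hypothetical sketch: return the entries under `prefix` with that prefix
// stripped, so a submodule can look up "weight"/"bias" without knowing
// where it sits inside the layer.
template <typename V>
std::map<std::string, V> get_dict_with_prefix(
    const std::map<std::string, V>& state_dict, const std::string& prefix) {
  std::map<std::string, V> scoped;
  for (const auto& [key, value] : state_dict) {
    if (key.rfind(prefix, 0) == 0) {  // key starts with prefix
      scoped.emplace(key.substr(prefix.size()), value);
    }
  }
  return scoped;
}
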
6 changes: 6 additions & 0 deletions xllm/core/layers/common/qwen2_5_vision_layer.h
@@ -56,6 +56,12 @@ class Qwen2_5_VisionLayerImpl : public torch::nn::Module {
RMSNorm norm2_{nullptr};
};

class Qwen2_VisionLayerImpl : public Qwen2_5_VisionLayerImpl {
public:
Qwen2_VisionLayerImpl(const ModelContext& context);
void load_state_dict(const StateDict& state_dict);
};

class Qwen3_VisionLayerImpl : public Qwen2_5_VisionLayerImpl {
public:
Qwen3_VisionLayerImpl(const ModelContext& context);
7 changes: 7 additions & 0 deletions xllm/core/layers/config.h
@@ -78,6 +78,13 @@ REGISTER_NOT_IMPLEMENTED_CLASS(LlamaDecoderLayerImpl);
#include "common/qwen2_decoder_layer.h"
#endif

#if defined(USE_NPU)
#include "npu/npu_qwen2_vision_encoder_layer_impl.h"
#else
#include "common/qwen2_5_vision_layer.h"
UNIFY_CLASS_NAME(Qwen2_VisionLayerImpl, Qwen2VisionEncoderLayerImpl)
#endif

#if defined(USE_NPU)
#include "npu/npu_qwen2dot5_vision_encoder_layer_impl.h"
#else
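Note: this is the pattern that lets model code instantiate a single Qwen2VisionEncoderLayerImpl on either backend: under USE_NPU the name comes from the NPU header, otherwise UNIFY_CLASS_NAME binds it to the common implementation. A minimal sketch of what such an alias macro could look like, assuming it is a plain type alias (the actual definition lives elsewhere in config.h):

// Hypothetical definition: alias the backend-specific implementation to the
// unified name referenced by model code.
#define UNIFY_CLASS_NAME(Impl, Unified) using Unified = Impl;
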
4 changes: 4 additions & 0 deletions xllm/core/layers/npu/CMakeLists.txt
@@ -9,6 +9,7 @@ cc_library(
npu_word_embedding_impl.h
npu_pos_embedding_impl.h
npu_lm_head_impl.h
npu_qwen2_vision_encoder_layer_impl.h
npu_qwen2dot5_vision_encoder_layer_impl.h
npu_qwen3_vision_encoder_layer_impl.h
npu_qwen3_moe_decoder_layer_impl.h
@@ -38,6 +39,7 @@ cc_library(
loader/deepseek_v2_decoder_loader.h
loader/glm4_moe_decoder_loader.h
loader/llama_decoder_loader.h
loader/qwen2_vision_encoder_loader.h
loader/qwen2dot5_vision_encoder_loader.h
loader/qwen3_vision_encoder_loader.h
loader/rms_norm_loader.h
@@ -50,6 +52,7 @@ cc_library(
npu_word_embedding_impl.cpp
npu_pos_embedding_impl.cpp
npu_lm_head_impl.cpp
npu_qwen2_vision_encoder_layer_impl.cpp
npu_qwen2dot5_vision_encoder_layer_impl.cpp
npu_qwen3_vision_encoder_layer_impl.cpp
npu_qwen3_moe_decoder_layer_impl.cpp
@@ -79,6 +82,7 @@ cc_library(
loader/deepseek_v2_decoder_loader.cpp
loader/glm4_moe_decoder_loader.cpp
loader/llama_decoder_loader.cpp
loader/qwen2_vision_encoder_loader.cpp
loader/qwen2dot5_vision_encoder_loader.cpp
loader/qwen3_vision_encoder_loader.cpp
loader/rms_norm_loader.cpp
160 changes: 160 additions & 0 deletions xllm/core/layers/npu/loader/qwen2_vision_encoder_loader.cpp
@@ -0,0 +1,160 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TORCH_HIGHER_THAN_PTA6
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#else
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
#endif

#include <torch_npu/csrc/libs/init_npu.h>

#include "qwen2_vision_encoder_loader.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/core/npu/NPUException.h"

namespace xllm {
namespace layer {

enum VisionEncoderLayerTensorId : int {
IN_INPUT_NORM_WEIGHT = 0,
IN_INPUT_NORM_BIAS,
IN_POST_NORM_WEIGHT,
IN_POST_NORM_BIAS,
IN_QKV_WEIGHT,
IN_QKV_BIAS,
IN_WATTENTION_OUT_WEIGHT,
IN_WATTENTION_OUT_BIAS,
IN_LINEAR_FC1_WEIGHT,
IN_LINEAR_FC1_BIAS,
IN_LINEAR_FC2_WEIGHT,
IN_LINEAR_FC2_BIAS,
IN_VISION_Q_WEIGHT,
IN_VISION_Q_BIAS,
IN_VISION_K_WEIGHT,
IN_VISION_K_BIAS,
IN_VISION_V_WEIGHT,
IN_VISION_V_BIAS
};

static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
{IN_INPUT_NORM_WEIGHT, "norm1.weight"},
{IN_INPUT_NORM_BIAS, "norm1.bias"},
{IN_POST_NORM_WEIGHT, "norm2.weight"},
{IN_POST_NORM_BIAS, "norm2.bias"},
{IN_QKV_WEIGHT, "attn.qkv.weight"},
{IN_QKV_BIAS, "attn.qkv.bias"},
{IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"},
{IN_WATTENTION_OUT_BIAS, "attn.proj.bias"},
{IN_LINEAR_FC1_WEIGHT, "mlp.fc1.weight"},
{IN_LINEAR_FC1_BIAS, "mlp.fc1.bias"},
{IN_LINEAR_FC2_WEIGHT, "mlp.fc2.weight"},
{IN_LINEAR_FC2_BIAS, "mlp.fc2.bias"}};

// {weight tensor id, shard dim}
static std::map<int, int> WEIGHT_SHARD = {
{IN_WATTENTION_OUT_WEIGHT, 1},
{IN_LINEAR_FC1_WEIGHT, 0},
{IN_LINEAR_FC1_BIAS, 0},
{IN_LINEAR_FC2_WEIGHT, 1},
};

Qwen2VisionEncoderLoader::Qwen2VisionEncoderLoader(uint64_t weight_count,
const ModelContext& context)
: BaseLoader(weight_count, context) {
auto model_args = context.get_model_args();
auto parallel_args = context.get_parallel_args();
auto options = context.get_tensor_options();
encode_param_rank = parallel_args.rank();
encode_param_worldSize = parallel_args.world_size();
at_weight_tensors_.resize(weight_count);
dtype_ = torch::typeMetaToScalarType(options.dtype());
device_id_ = options.device().index();
at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_);
for (uint64_t i = 0; i < weight_count; ++i) {
at_weight_tensors_[i] = torch::zeros({1}).to(options);
}
}

void Qwen2VisionEncoderLoader::load_state_dict(const StateDict& state_dict) {
for (const auto& [index, name] : WEIGHT_MAPPING) {
if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) {
set_weight(state_dict, name, index, WEIGHT_SHARD[index]);
} else {
set_weight(state_dict, name, index);
}
}
}

void Qwen2VisionEncoderLoader::verify_loaded_weights() const {
for (const auto& [index, name] : WEIGHT_MAPPING) {
CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1}))
<< "weight is not loaded for " << name;
}
}

void Qwen2VisionEncoderLoader::merge_loaded_weights() {
// split the packed qkv weight into per-rank q/k/v shards
get_weights_col_packed_qkv();
if (encode_param_worldSize > 1) {
// merge qkv weight
auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT],
at_weight_tensors_[IN_VISION_K_WEIGHT],
at_weight_tensors_[IN_VISION_V_WEIGHT]},
0);
at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight;
at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_);
at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_);
at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_);

// merge qkv bias
auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS],
at_weight_tensors_[IN_VISION_K_BIAS],
at_weight_tensors_[IN_VISION_V_BIAS]},
0);
at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias;
at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_);
at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_);
at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_);
}
}

// split the packed qkv weight and bias across tensor-parallel ranks
void Qwen2VisionEncoderLoader::get_weights_col_packed_qkv() {
int rank = encode_param_rank;
int worldSize = encode_param_worldSize;
// split qkv weight
qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0);
qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0);
// weight
at_weight_tensors_[IN_VISION_Q_WEIGHT] =
(qkv_weight[0].chunk(worldSize, 0))[rank];
at_weight_tensors_[IN_VISION_K_WEIGHT] =
(qkv_weight[1].chunk(worldSize, 0))[rank];
at_weight_tensors_[IN_VISION_V_WEIGHT] =
(qkv_weight[2].chunk(worldSize, 0))[rank];
// bias
at_weight_tensors_[IN_VISION_Q_BIAS] =
(qkv_bias[0].chunk(worldSize, 0))[rank];
at_weight_tensors_[IN_VISION_K_BIAS] =
(qkv_bias[1].chunk(worldSize, 0))[rank];
at_weight_tensors_[IN_VISION_V_BIAS] =
(qkv_bias[2].chunk(worldSize, 0))[rank];
}

} // namespace layer
} // namespace xllm
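
Note: to see the shard arithmetic of get_weights_col_packed_qkv and merge_loaded_weights in isolation: the checkpoint packs q, k and v along dim 0, each projection is chunked row-wise across ranks, and the local shards are re-packed so the fused QKV matmul still works per rank. A standalone libtorch sketch (CPU tensors; the hidden size and world size are example values, not read from the model config):

#include <torch/torch.h>

#include <iostream>

int main() {
  const int64_t hidden = 1280;  // example vision hidden size
  const int64_t world_size = 2;
  const int64_t rank = 0;

  // Checkpoint layout: q, k, v packed along dim 0 -> [3 * hidden, hidden].
  auto qkv_w = torch::randn({3 * hidden, hidden});

  // 1. Unpack into q / k / v, each [hidden, hidden].
  auto qkv = torch::chunk(qkv_w, 3, /*dim=*/0);

  // 2. Shard each projection row-wise: [hidden / world_size, hidden].
  auto q = qkv[0].chunk(world_size, 0)[rank];
  auto k = qkv[1].chunk(world_size, 0)[rank];
  auto v = qkv[2].chunk(world_size, 0)[rank];

  // 3. Re-pack the local shards so one fused matmul serves this rank.
  auto local_qkv = torch::cat({q, k, v}, 0);
  std::cout << local_qkv.sizes() << std::endl;  // [3 * hidden / world_size, hidden]
  return 0;
}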
49 changes: 49 additions & 0 deletions xllm/core/layers/npu/loader/qwen2_vision_encoder_loader.h
@@ -0,0 +1,49 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <map>
#include <vector>

#include "base_loader.h"

namespace xllm {
namespace layer {

class Qwen2VisionEncoderLoader : public BaseLoader {
public:
Qwen2VisionEncoderLoader(uint64_t weight_count, const ModelContext& context);

void load_state_dict(const StateDict& state_dict) override;
void verify_loaded_weights() const override;
void merge_loaded_weights() override;

private:
void get_weights_col_packed_qkv();

protected:
std::string model_name_;
at::Tensor cu_seqlen_;
at::Tensor at_placeholder_;
std::vector<torch::Tensor> qkv_weight;
std::vector<torch::Tensor> qkv_bias;
int device_id_;
int encode_param_rank;
int encode_param_worldSize;
};

} // namespace layer
} // namespace xllm