
Commit ae16a2c

feat: support Qwen2-VL & GME-Qwen2-VL model on npu device.
1 parent 8a00ad7 commit ae16a2c

13 files changed · +2176 −28 lines

xllm/core/layers/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ cc_library(
   llama_decoder_layer.h
   common/multi_head_attention.h
   qwen2_decoder_layer.h
+  qwen2_vision_encode_layer.h
   qwen2dot5_vision_encode_layer.h
   qwen3_vision_encode_layer.h
   qwen3_decoder_layer.h

xllm/core/layers/npu/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,7 @@ cc_library(
   npu_word_embedding_impl.h
   npu_pos_embedding_impl.h
   npu_lm_head_impl.h
+  npu_qwen2_vision_encoder_layer_impl.h
   npu_qwen2dot5_vision_encoder_layer_impl.h
   npu_qwen3_vision_encoder_layer_impl.h
   npu_qwen3_moe_decoder_layer_impl.h
@@ -38,6 +39,7 @@ cc_library(
   loader/deepseek_v2_decoder_loader.h
   loader/glm4_moe_decoder_loader.h
   loader/llama_decoder_loader.h
+  loader/qwen2_vision_encoder_loader.h
   loader/qwen2dot5_vision_encoder_loader.h
   loader/qwen3_vision_encoder_loader.h
   loader/rms_norm_loader.h
@@ -50,6 +52,7 @@ cc_library(
   npu_word_embedding_impl.cpp
   npu_pos_embedding_impl.cpp
   npu_lm_head_impl.cpp
+  npu_qwen2_vision_encoder_layer_impl.cpp
   npu_qwen2dot5_vision_encoder_layer_impl.cpp
   npu_qwen3_vision_encoder_layer_impl.cpp
   npu_qwen3_moe_decoder_layer_impl.cpp
@@ -79,6 +82,7 @@ cc_library(
   loader/deepseek_v2_decoder_loader.cpp
   loader/glm4_moe_decoder_loader.cpp
   loader/llama_decoder_loader.cpp
+  loader/qwen2_vision_encoder_loader.cpp
   loader/qwen2dot5_vision_encoder_loader.cpp
   loader/qwen3_vision_encoder_loader.cpp
   loader/rms_norm_loader.cpp
xllm/core/layers/npu/loader/qwen2_vision_encoder_loader.cpp

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#ifdef TORCH_HIGHER_THAN_PTA6
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#else
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
#endif

#include <torch_npu/csrc/libs/init_npu.h>

#include "qwen2_vision_encoder_loader.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/core/npu/NPUException.h"

namespace xllm {
namespace layer {

enum VisionEncoderLayerTensorId : int {
  IN_INPUT_NORM_WEIGHT = 0,
  IN_INPUT_NORM_BIAS,
  IN_POST_NORM_WEIGHT,
  IN_POST_NORM_BIAS,
  IN_QKV_WEIGHT,
  IN_QKV_BIAS,
  IN_WATTENTION_OUT_WEIGHT,
  IN_WATTENTION_OUT_BIAS,
  IN_LINEAR_FC1_WEIGHT,
  IN_LINEAR_FC1_BIAS,
  IN_LINEAR_FC2_WEIGHT,
  IN_LINEAR_FC2_BIAS,
  IN_VISION_Q_WEIGHT,
  IN_VISION_Q_BIAS,
  IN_VISION_K_WEIGHT,
  IN_VISION_K_BIAS,
  IN_VISION_V_WEIGHT,
  IN_VISION_V_BIAS
};

// Maps each weight tensor id to its name in the checkpoint state dict.
static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
    {IN_INPUT_NORM_WEIGHT, "norm1.weight"},
    {IN_INPUT_NORM_BIAS, "norm1.bias"},
    {IN_POST_NORM_WEIGHT, "norm2.weight"},
    {IN_POST_NORM_BIAS, "norm2.bias"},
    {IN_QKV_WEIGHT, "attn.qkv.weight"},
    {IN_QKV_BIAS, "attn.qkv.bias"},
    {IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"},
    {IN_WATTENTION_OUT_BIAS, "attn.proj.bias"},
    {IN_LINEAR_FC1_WEIGHT, "mlp.fc1.weight"},
    {IN_LINEAR_FC1_BIAS, "mlp.fc1.bias"},
    {IN_LINEAR_FC2_WEIGHT, "mlp.fc2.weight"},
    {IN_LINEAR_FC2_BIAS, "mlp.fc2.bias"}};

// {weight tensor id, dimension to shard along for tensor parallelism}
static std::map<int, int> WEIGHT_SHARD = {
    {IN_WATTENTION_OUT_WEIGHT, 1},
    {IN_LINEAR_FC1_WEIGHT, 0},
    {IN_LINEAR_FC1_BIAS, 0},
    {IN_LINEAR_FC2_WEIGHT, 1},
};

Qwen2VisionEncoderLoader::Qwen2VisionEncoderLoader(uint64_t weight_count,
                                                   const ModelContext& context)
    : BaseLoader(weight_count, context) {
  auto model_args = context.get_model_args();
  auto parallel_args = context.get_parallel_args();
  auto options = context.get_tensor_options();
  encode_param_rank = parallel_args.rank();
  encode_param_worldSize = parallel_args.world_size();
  at_weight_tensors_.resize(weight_count);
  dtype_ = torch::typeMetaToScalarType(options.dtype());
  device_id_ = options.device().index();
  at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_);
  for (int i = 0; i < weight_count; ++i) {
    at_weight_tensors_[i] = torch::zeros({1}).to(options);
  }
}

void Qwen2VisionEncoderLoader::load_state_dict(const StateDict& state_dict) {
  for (const auto& [index, name] : WEIGHT_MAPPING) {
    if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) {
      set_weight(state_dict, name, index, WEIGHT_SHARD[index]);
    } else {
      set_weight(state_dict, name, index);
    }
  }
}

void Qwen2VisionEncoderLoader::verify_loaded_weights() const {
  for (const auto& [index, name] : WEIGHT_MAPPING) {
    CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1}))
        << "weight is not loaded for " << name;
  }
}

void Qwen2VisionEncoderLoader::merge_loaded_weights() {
  // Split the packed QKV weight when tensor parallelism is enabled.
  get_weights_col_packed_qkv();
  if (encode_param_worldSize > 1) {
    // Merge this rank's Q/K/V weight slices back into one packed QKV weight.
    auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT],
                                      at_weight_tensors_[IN_VISION_K_WEIGHT],
                                      at_weight_tensors_[IN_VISION_V_WEIGHT]},
                                     0);
    at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight;
    at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_);
    at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_);
    at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_);

    // Merge this rank's Q/K/V bias slices back into one packed QKV bias.
    auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS],
                                    at_weight_tensors_[IN_VISION_K_BIAS],
                                    at_weight_tensors_[IN_VISION_V_BIAS]},
                                   0);
    at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias;
    at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_);
    at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_);
    at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_);
  }
}

// Split the column-packed QKV weight and bias across tensor-parallel ranks.
void Qwen2VisionEncoderLoader::get_weights_col_packed_qkv() {
  int rank = encode_param_rank;
  int worldSize = encode_param_worldSize;
  // Unpack the QKV tensors into Q, K and V.
  qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0);
  qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0);
  // weight
  at_weight_tensors_[IN_VISION_Q_WEIGHT] =
      (qkv_weight[0].chunk(worldSize, 0))[rank];
  at_weight_tensors_[IN_VISION_K_WEIGHT] =
      (qkv_weight[1].chunk(worldSize, 0))[rank];
  at_weight_tensors_[IN_VISION_V_WEIGHT] =
      (qkv_weight[2].chunk(worldSize, 0))[rank];
  // bias
  at_weight_tensors_[IN_VISION_Q_BIAS] =
      (qkv_bias[0].chunk(worldSize, 0))[rank];
  at_weight_tensors_[IN_VISION_K_BIAS] =
      (qkv_bias[1].chunk(worldSize, 0))[rank];
  at_weight_tensors_[IN_VISION_V_BIAS] =
      (qkv_bias[2].chunk(worldSize, 0))[rank];
}

}  // namespace layer
}  // namespace xllm
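
For reference, the tensor-parallel handling of the packed attn.qkv tensors in this loader (unpacked in get_weights_col_packed_qkv(), re-packed per rank in merge_loaded_weights()) can be reproduced with plain libtorch. The sketch below is not part of the commit; the hidden size, world size, and rank are hypothetical, and it runs on CPU only.

#include <torch/torch.h>

#include <iostream>

int main() {
  const int64_t hidden = 8;  // hypothetical vision hidden size
  const int world_size = 2;  // hypothetical tensor-parallel degree
  const int rank = 0;        // this worker's rank

  // Packed QKV weight as stored in the checkpoint: [3 * hidden, hidden].
  auto qkv_weight = torch::randn({3 * hidden, hidden});

  // 1) Unpack into Q, K and V: three equal chunks along dim 0.
  auto qkv = torch::chunk(qkv_weight, 3, /*dim=*/0);

  // 2) Each rank keeps only its slice of Q, K and V (again along dim 0).
  auto q_local = qkv[0].chunk(world_size, 0)[rank];
  auto k_local = qkv[1].chunk(world_size, 0)[rank];
  auto v_local = qkv[2].chunk(world_size, 0)[rank];

  // 3) Re-pack the local slices so the encoder layer still receives one
  //    contiguous QKV tensor of shape [3 * hidden / world_size, hidden].
  auto qkv_local = torch::cat({q_local, k_local, v_local}, 0);
  std::cout << qkv_local.sizes() << std::endl;  // prints [12, 8] here
  return 0;
}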
xllm/core/layers/npu/loader/qwen2_vision_encoder_loader.h

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <map>
#include <vector>

#include "base_loader.h"

namespace xllm {
namespace layer {

class Qwen2VisionEncoderLoader : public BaseLoader {
 public:
  Qwen2VisionEncoderLoader(uint64_t weight_count, const ModelContext& context);

  void load_state_dict(const StateDict& state_dict) override;
  void verify_loaded_weights() const override;
  void merge_loaded_weights() override;

 private:
  void get_weights_col_packed_qkv();

 protected:
  std::string model_name_;
  at::Tensor cu_seqlen_;
  at::Tensor at_placeholder_;
  std::vector<torch::Tensor> qkv_weight;
  std::vector<torch::Tensor> qkv_bias;
  int device_id_;
  int encode_param_rank;
  int encode_param_worldSize;
};

}  // namespace layer
}  // namespace xllm
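
The WEIGHT_SHARD map in qwen2_vision_encoder_loader.cpp records which dimension each weight is sliced along for tensor parallelism: dim 0 for mlp.fc1 (each rank keeps a slice of the output features) and dim 1 for attn.proj and mlp.fc2 (each rank keeps a slice of the input features). The standalone sketch below illustrates that interpretation; it is not part of the commit, and the world size, rank, and layer sizes are hypothetical.

#include <torch/torch.h>

#include <iostream>

int main() {
  const int world_size = 2, rank = 0;    // hypothetical TP setup
  const int64_t hidden = 8, inter = 32;  // hypothetical layer sizes

  auto fc1_w = torch::randn({inter, hidden});    // mlp.fc1.weight  -> shard dim 0
  auto fc2_w = torch::randn({hidden, inter});    // mlp.fc2.weight  -> shard dim 1
  auto proj_w = torch::randn({hidden, hidden});  // attn.proj.weight -> shard dim 1

  // dim 0: output features are divided across ranks (column parallel).
  auto fc1_local = fc1_w.chunk(world_size, 0)[rank];    // [16, 8]
  // dim 1: input features are divided across ranks (row parallel).
  auto fc2_local = fc2_w.chunk(world_size, 1)[rank];    // [8, 16]
  auto proj_local = proj_w.chunk(world_size, 1)[rank];  // [8, 4]

  std::cout << fc1_local.sizes() << " " << fc2_local.sizes() << " "
            << proj_local.sizes() << std::endl;
  return 0;
}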
