/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#ifdef TORCH_HIGHER_THAN_PTA6
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#else
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
#endif

#include <torch_npu/csrc/libs/init_npu.h>

#include <map>
#include <string>
#include <utility>
#include <vector>

#include "qwen2_vision_encoder_loader.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
namespace xllm {
namespace layer {

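// Slot indices of the weight tensors that make up one Qwen2 vision encoder
// block; used to index into at_weight_tensors_. The IN_VISION_* entries hold
// the per-rank Q/K/V slices produced when the packed QKV tensors are split
// for tensor parallelism.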
enum VisionEncoderLayerTensorId : int {
  IN_INPUT_NORM_WEIGHT = 0,
  IN_INPUT_NORM_BIAS,
  IN_POST_NORM_WEIGHT,
  IN_POST_NORM_BIAS,
  IN_QKV_WEIGHT,
  IN_QKV_BIAS,
  IN_WATTENTION_OUT_WEIGHT,
  IN_WATTENTION_OUT_BIAS,
  IN_LINEAR_FC1_WEIGHT,
  IN_LINEAR_FC1_BIAS,
  IN_LINEAR_FC2_WEIGHT,
  IN_LINEAR_FC2_BIAS,
  IN_VISION_Q_WEIGHT,
  IN_VISION_Q_BIAS,
  IN_VISION_K_WEIGHT,
  IN_VISION_K_BIAS,
  IN_VISION_V_WEIGHT,
  IN_VISION_V_BIAS
};

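// Maps each tensor slot to its weight-name suffix in the checkpoint state
// dict, relative to a single vision block (typically something like
// "visual.blocks.<i>." in Qwen2-VL checkpoints).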
static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
    {IN_INPUT_NORM_WEIGHT, "norm1.weight"},
    {IN_INPUT_NORM_BIAS, "norm1.bias"},
    {IN_POST_NORM_WEIGHT, "norm2.weight"},
    {IN_POST_NORM_BIAS, "norm2.bias"},
    {IN_QKV_WEIGHT, "attn.qkv.weight"},
    {IN_QKV_BIAS, "attn.qkv.bias"},
    {IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"},
    {IN_WATTENTION_OUT_BIAS, "attn.proj.bias"},
    {IN_LINEAR_FC1_WEIGHT, "mlp.fc1.weight"},
    {IN_LINEAR_FC1_BIAS, "mlp.fc1.bias"},
    {IN_LINEAR_FC2_WEIGHT, "mlp.fc2.weight"},
    {IN_LINEAR_FC2_BIAS, "mlp.fc2.bias"}};

// Weights sharded across tensor-parallel ranks: {tensor id, split dim}.
static std::map<int, int> WEIGHT_SHARD = {
    {IN_WATTENTION_OUT_WEIGHT, 1},
    {IN_LINEAR_FC1_WEIGHT, 0},
    {IN_LINEAR_FC1_BIAS, 0},
    {IN_LINEAR_FC2_WEIGHT, 1},
};

Qwen2VisionEncoderLoader::Qwen2VisionEncoderLoader(uint64_t weight_count,
                                                   const ModelContext& context)
    : BaseLoader(weight_count, context) {
  auto model_args = context.get_model_args();
  auto parallel_args = context.get_parallel_args();
  auto options = context.get_tensor_options();
  encode_param_rank = parallel_args.rank();
  encode_param_worldSize = parallel_args.world_size();
  at_weight_tensors_.resize(weight_count);
  dtype_ = torch::typeMetaToScalarType(options.dtype());
  device_id_ = options.device().index();
  at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_);
  // Fill every slot with a 1-element placeholder so that
  // verify_loaded_weights() can detect weights that were never loaded.
  for (uint64_t i = 0; i < weight_count; ++i) {
    at_weight_tensors_[i] = torch::zeros({1}).to(options);
  }
}

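// Copies each mapped weight from the state dict into its slot, sharding the
// tensors listed in WEIGHT_SHARD along their tensor-parallel dimension.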
void Qwen2VisionEncoderLoader::load_state_dict(const StateDict& state_dict) {
  for (const auto& [index, name] : WEIGHT_MAPPING) {
    if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) {
      set_weight(state_dict, name, index, WEIGHT_SHARD[index]);
    } else {
      set_weight(state_dict, name, index);
    }
  }
}

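// A weight still shaped {1} is the untouched placeholder from the
// constructor, i.e. it was never overwritten by load_state_dict().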
void Qwen2VisionEncoderLoader::verify_loaded_weights() const {
  for (const auto& [index, name] : WEIGHT_MAPPING) {
    CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1}))
        << "weight is not loaded for " << name;
  }
}

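// Rebuilds the packed QKV weight/bias for this rank: the fused tensors are
// first split into per-rank Q/K/V slices and, under tensor parallelism,
// re-concatenated so downstream ops see a single packed QKV tensor again.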
void Qwen2VisionEncoderLoader::merge_loaded_weights() {
  // split the packed qkv weight when tp is enabled
  get_weights_col_packed_qkv();
  if (encode_param_worldSize > 1) {
    // merge qkv weight
    auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT],
                                      at_weight_tensors_[IN_VISION_K_WEIGHT],
                                      at_weight_tensors_[IN_VISION_V_WEIGHT]},
                                     0);
    at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight;
    at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_);
    at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_);
    at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_);

    // merge qkv bias
    auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS],
                                    at_weight_tensors_[IN_VISION_K_BIAS],
                                    at_weight_tensors_[IN_VISION_V_BIAS]},
                                   0);
    at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias;
    at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_);
    at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_);
    at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_);
  }
}

// Splits the fused QKV weight and bias into this rank's Q/K/V slices for
// tensor parallelism.
void Qwen2VisionEncoderLoader::get_weights_col_packed_qkv() {
  int rank = encode_param_rank;
  int worldSize = encode_param_worldSize;
  // split the fused qkv tensors into q/k/v along dim 0
  qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0);
  qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0);
  // weight: keep only this rank's shard of q, k and v
  at_weight_tensors_[IN_VISION_Q_WEIGHT] =
      (qkv_weight[0].chunk(worldSize, 0))[rank];
  at_weight_tensors_[IN_VISION_K_WEIGHT] =
      (qkv_weight[1].chunk(worldSize, 0))[rank];
  at_weight_tensors_[IN_VISION_V_WEIGHT] =
      (qkv_weight[2].chunk(worldSize, 0))[rank];
  // bias: keep only this rank's shard of q, k and v
  at_weight_tensors_[IN_VISION_Q_BIAS] =
      (qkv_bias[0].chunk(worldSize, 0))[rank];
  at_weight_tensors_[IN_VISION_K_BIAS] =
      (qkv_bias[1].chunk(worldSize, 0))[rank];
  at_weight_tensors_[IN_VISION_V_BIAS] =
      (qkv_bias[2].chunk(worldSize, 0))[rank];
}

}  // namespace layer
}  // namespace xllm