 from typing import Optional, Tuple
 
 import torch
+import torch.nn.functional as F
+import torch_npu
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
@@ -37,17 +39,18 @@ def rope_forward_oot(
     query: torch.Tensor,
     key: torch.Tensor,
     offsets: Optional[torch.Tensor] = None,
-    is_neox_style_override: Optional[bool] = None
+    is_neox_style_override: Optional[bool] = None,
+    is_qwen_torchair: Optional[bool] = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if get_ascend_config().torchair_graph_config.enabled:
+    if get_ascend_config(
+    ).torchair_graph_config.enabled and not is_qwen_torchair:
         return self.forward_native(
             positions,
             query,
             key,
             offsets,
         )
 
-    import torch_npu
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)
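The reworked guard only falls back to forward_native for non-Qwen models under torchair graph mode. A minimal standalone sketch of the routing this hunk introduces (pure Python; cfg_enabled stands in for get_ascend_config().torchair_graph_config.enabled):

def choose_rope_path(cfg_enabled: bool, is_qwen_torchair: bool) -> str:
    # Mirrors the updated condition: graph mode plus a non-Qwen model
    # falls back to forward_native; everything else continues to the
    # torch_npu rotary kernel below.
    if cfg_enabled and not is_qwen_torchair:
        return "forward_native"
    return "npu_kernel"

assert choose_rope_path(True, False) == "forward_native"
assert choose_rope_path(True, True) == "npu_kernel"    # new Qwen torchair path
assert choose_rope_path(False, True) == "npu_kernel"   # eager mode, unchanged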
@@ -246,6 +249,92 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
     self.register_buffer("sin_cached", sin_cached, persistent=False)
 
 
+def __set_cos_sin_cache(self, seq_len, device, dtype):
+    inv_freq = 1.0 / (self.base**(torch.arange(
+        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
+                                  (1 / self.rotary_dim)))
+    self.register_buffer("inv_freq", inv_freq)
+
+    t = torch.arange(self.max_position_embeddings,
+                     device=self.inv_freq.device,
+                     dtype=torch.float32)
+    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+    emb = torch.cat((freqs, freqs), dim=-1)
+    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
+    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
+    self.embed = F.embedding
+
+
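The added __set_cos_sin_cache precomputes full-length cos/sin tables once and binds F.embedding so the forward pass can gather per-token rows by position. A self-contained sketch of the same math, assuming the standard RoPE frequencies and illustrative sizes:

import torch
import torch.nn.functional as F

def build_cos_sin(base: float, rotary_dim: int, max_pos: int,
                  dtype: torch.dtype = torch.float32):
    # inv_freq[i] = base^(-2i/rotary_dim): the standard RoPE inverse
    # frequencies, identical to base**(arange(0, d, 2) * (1/d)) above.
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)         # (max_pos, rotary_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (max_pos, rotary_dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)

cos, sin = build_cos_sin(base=10000.0, rotary_dim=128, max_pos=4096)
positions = torch.tensor([0, 3, 7])
# F.embedding is a plain row lookup: each token gets the cos/sin row for
# its position, shape (3, 128) here.
cos_rows = F.embedding(positions, cos)
sin_rows = F.embedding(positions, sin)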
+_original_re_init = RotaryEmbedding.__init__
+
+
+def qwen_rope_init_func(
+    self,
+    head_size: int,
+    rotary_dim: int,
+    max_position_embeddings: int,
+    base: float,
+    is_neox_style: bool,
+    dtype: torch.dtype,
+) -> None:
+    _original_re_init(self, head_size, rotary_dim, max_position_embeddings,
+                      base, is_neox_style, dtype)
+    if get_ascend_config().torchair_graph_config.enabled:
+        __set_cos_sin_cache(self,
+                            seq_len=max_position_embeddings,
+                            device="npu",
+                            dtype=dtype)
+
+
+def rope_forward(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    is_neox_style_override: Optional[bool] = None,
+    max_seq_len: Optional[int] = None,
+    is_prefill: Optional[bool] = True,
+    is_qwen_torchair: Optional[bool] = False,
+):
+    if get_ascend_config().torchair_graph_config.enabled \
+            and is_qwen_torchair and not is_prefill:
+        if max_seq_len is not None and torch.gt(max_seq_len,
+                                                self.max_position_embeddings):
+            __set_cos_sin_cache(self,
+                                seq_len=max_seq_len,
+                                device=query.device,
+                                dtype=torch.float32)
+
+        # bsnd/bnsd
+        if positions is not None:
+            cos = self.embed(positions, self.cos)
+            sin = self.embed(positions, self.sin)
+            self.cos_embed = cos
+            self.sin_embed = sin
+        else:
+            cos = self.cos_embed
+            sin = self.sin_embed
+
+        query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
+        key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
+
+        cos = cos.unsqueeze(-2).unsqueeze(-2)
+        sin = sin.unsqueeze(-2).unsqueeze(-2)
+
+        query = query.unsqueeze(1)
+        key = key.unsqueeze(1)
+
+        q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
+            query, key, cos, sin)
+        return q_embed.flatten(-2), k_embed.flatten(-2)
+    else:
+        return rope_forward_oot(self, positions, query, key, offsets,
+                                is_neox_style_override,
+                                is_qwen_torchair)  # type: ignore
+
+
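torch_npu.npu_apply_rotary_pos_emb consumes the 4-D broadcast layout built above: cos/sin expanded with two trailing unsqueezes, query/key reshaped to (..., heads, head_size). For intuition only, a plain-PyTorch reference of the neox-style rotation such a fused kernel corresponds to; this is the textbook formula, not a claim of bit-exact parity with the NPU op:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style rotation: swap the two halves of the last dim,
    # negating the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_ref(q: torch.Tensor, k: torch.Tensor,
                     cos: torch.Tensor, sin: torch.Tensor):
    # q, k: (..., num_heads, head_size); cos/sin broadcast across heads.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin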
 def deepseek_rope_init_func(
     self,
     head_size: int,
@@ -283,7 +372,8 @@ def deepseek_rope_init_func(
                        device="npu")
 
 
-RotaryEmbedding.forward_oot = rope_forward_oot
+RotaryEmbedding.__init__ = qwen_rope_init_func
+RotaryEmbedding.forward_oot = rope_forward
 
 # Note: we adopt the native huggingface deepseek rope initialization code from
 # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
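The final hunk installs the new behaviour by reassigning attributes on the class, so every existing RotaryEmbedding construction site picks up the extra cache setup without touching call sites. The pattern in isolation (toy class; names are illustrative, not from the PR):

class Rope:
    def __init__(self, dim: int):
        self.dim = dim

_original_init = Rope.__init__

def patched_init(self, dim: int) -> None:
    _original_init(self, dim)  # run the upstream initializer first
    self.extra_cache = "cos/sin tables would be built here"

Rope.__init__ = patched_init
assert Rope(64).extra_cache is not None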