from typing import Optional, Tuple

import torch
+import torch.nn.functional as F
+import torch_npu

from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -37,17 +39,18 @@ def rope_forward_oot(
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
-    is_neox_style_override: Optional[bool] = None
+    is_neox_style_override: Optional[bool] = None,
+    is_qwen_torchair: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
-    if get_ascend_config().torchair_graph_config.enabled:
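+    # Graph mode normally falls back to forward_native; the Qwen torchair
+    # path opts out and continues with the NPU implementation below.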
+    if get_ascend_config(
+    ).torchair_graph_config.enabled and not is_qwen_torchair:
        return self.forward_native(
            positions,
            query,
            key,
            offsets,
        )

-    import torch_npu
    query_shape, key_shape = query.shape, key.shape
    if self.cos_sin_cache.device != query.device:
        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
@@ -246,6 +249,98 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.register_buffer("sin_cached", sin_cached, persistent=False)


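+# Build full position-indexed cos/sin tables for the Qwen torchair path; they
+# are looked up with F.embedding (bound to self.embed) at forward time.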
+def __set_cos_sin_cache(self, seq_len, device, dtype):
+    inv_freq = 1.0 / (self.base**(torch.arange(
+        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
+                                  (1 / self.rotary_dim)))
+    self.register_buffer("inv_freq", inv_freq)
+
+    # Index over seq_len (not self.max_position_embeddings) so rope_forward
+    # can rebuild a longer table when max_seq_len exceeds the cached range.
+    t = torch.arange(seq_len,
+                     device=self.inv_freq.device,
+                     dtype=torch.float32)
+    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+    emb = torch.cat((freqs, freqs), dim=-1)
+    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
+    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
+    self.embed = F.embedding
+
+
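+# Replacement __init__ for RotaryEmbedding: identical construction to upstream
+# vLLM, plus eager cos/sin table setup on NPU when torchair graph mode is on.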
+def qwen_rope_init_func(
+    self,
+    head_size: int,
+    rotary_dim: int,
+    max_position_embeddings: int,
+    base: float,
+    is_neox_style: bool,
+    dtype: torch.dtype,
+) -> None:
+    super(RotaryEmbedding, self).__init__()
+    self.head_size = head_size
+    self.rotary_dim = rotary_dim
+    self.max_position_embeddings = max_position_embeddings
+    self.base = base
+    self.is_neox_style = is_neox_style
+    self.dtype = dtype
+
+    cache = self._compute_cos_sin_cache()
+    cache = cache.to(dtype)
+    self.cos_sin_cache: torch.Tensor  # type: ignore[misc]
+    self.register_buffer("cos_sin_cache", cache, persistent=False)
+    if get_ascend_config().torchair_graph_config.enabled:
+        __set_cos_sin_cache(self,
+                            seq_len=max_position_embeddings,
+                            device="npu",
+                            dtype=dtype)
+
+
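+# Replacement forward_oot: routes to rope_forward_oot except on the Qwen
+# torchair decode path, which uses the fused NPU rotary kernel below.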
+def rope_forward(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    is_neox_style_override: Optional[bool] = None,
+    max_seq_len: Optional[int] = None,
+    is_prefill: Optional[bool] = True,
+    is_qwen_torchair: Optional[bool] = False,
+):
+    if (not get_ascend_config().torchair_graph_config.enabled
+            or not is_qwen_torchair or is_prefill):
+        return rope_forward_oot(self, positions, query, key, offsets,
+                                is_neox_style_override,
+                                is_qwen_torchair)  # type: ignore
+
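+    # Rebuild the cos/sin tables if the runtime max sequence length exceeds
+    # the range cached at init time.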
+    if max_seq_len is not None and torch.gt(max_seq_len,
+                                            self.max_position_embeddings):
+        __set_cos_sin_cache(self,
+                            seq_len=max_seq_len,
+                            device=query.device,
+                            dtype=torch.float32)
+
+    # bsnd/bnsd
+    if positions is not None:
+        cos = self.embed(positions, self.cos)
+        sin = self.embed(positions, self.sin)
+        self.cos_embed = cos
+        self.sin_embed = sin
+    else:
+        cos = self.cos_embed
+        sin = self.sin_embed
+
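+    # Expose the head dimension on q/k and broadcast cos/sin so the shapes
+    # line up with what npu_apply_rotary_pos_emb expects.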
+    query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
+    key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
+
+    cos = cos.unsqueeze(-2).unsqueeze(-2)
+    sin = sin.unsqueeze(-2).unsqueeze(-2)
+
+    query = query.unsqueeze(1)
+    key = key.unsqueeze(1)
+
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+    return q_embed.flatten(-2), k_embed.flatten(-2)
+
+
def deepseek_rope_init_func(
    self,
    head_size: int,
@@ -283,7 +378,8 @@ def deepseek_rope_init_func(
                       device="npu")


-RotaryEmbedding.forward_oot = rope_forward_oot
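+# Patch vLLM's RotaryEmbedding so all instances use the Qwen-torchair-aware
+# __init__ and out-of-tree forward defined above.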
+RotaryEmbedding.__init__ = qwen_rope_init_func
+RotaryEmbedding.forward_oot = rope_forward

# Note: we adopt the native huggingface deepseek rope initialization code from
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for