import os
import math
import vllm
import torch
import lmdeploy.pytorch.distributed as dist

from vllm import _custom_ops as custom_ops
from flash_attn import flash_attn_varlen_func
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.attention.ops.prefix_prefill import context_attention_fwd

from dlinfer.vendor import vendor_ops_registry
from dlinfer.utils.registry import register_ops
from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple

import ixformer.inference.functions as ops
import ixformer.functions as ix_func

from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from ixformer.contrib.vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache

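# NOTE: the vLLM and flash_attn imports above are not called directly in this file;
# the kernels dispatched below come from ixformer (`ops` and `ix_func`).
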
__all__ = [
    "add_rms_norm",
    "apply_rotary_pos_emb",
    "prefill_attention",
    "fused_moe",
    "fill_kv_cache",
    "paged_decode_attention",
    "paged_prefill_attention",
    "rms_norm",
    "silu_and_mul",
    "moe_gating_topk_softmax",
    "linear",
    "weight_quant_matmul",
    "dynamic_quant",
    "linear_w8a8",
    "rms_norm_w8a8",
    "add_rms_norm_w8a8",
]


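# Fused residual add + RMSNorm via ixformer's residual_rms_norm, which is expected
# to return (normalized hidden_states, updated residual) per the op's contract.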
@register_ops(vendor_ops_registry)
def add_rms_norm(
    hidden_states: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
) -> Tuple[Tensor, Tensor]:
    return ix_func.residual_rms_norm(
        input=hidden_states,
        residual=residual,
        weight=weight,
        eps=epsilon,
        residual_alpha=1,
    )


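# Rotary position embedding. query/key are flattened to (1, num_tokens, heads * dim)
# as expected by the ixformer kernel; position_ids_1d simply indexes the per-token
# rows of the packed cos_sin_cache. cos/sin are assumed to hold duplicated halves,
# so only the first half of the last dimension is kept, and the trailing `True` is
# assumed to select the GPT-NeoX (non-interleaved) rotation style.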
@register_ops(vendor_ops_registry)
def apply_rotary_pos_emb(
    query: Tensor,
    key: Tensor,
    cos: Optional[Tensor],
    sin: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    query = query.contiguous().unsqueeze(0)
    key = key.contiguous().unsqueeze(0)
    position_ids_1d = torch.arange(0, query.size(1), device=query.device)
    query = query.flatten(-2, -1)
    key = key.flatten(-2, -1)
    cos = cos[..., : cos.shape[-1] // 2]
    sin = sin[..., : sin.shape[-1] // 2]
    cos_sin_cache = torch.cat((cos, sin), dim=-1)

    ops.vllm_rotary_embedding(
        position_ids_1d, query, key, cos_sin_cache.size(-1), cos_sin_cache, True
    )
    return query, key


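# Prefill (context) attention over packed variable-length sequences, backed by the
# ixformer flash-attn varlen kernel. During prefill there is no cached context, so
# the query cumulative lengths are reused for the keys and the attention is causal.
# softmax_scale defaults to 1/sqrt(head_dim) when not provided.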
@register_ops(vendor_ops_registry)
def prefill_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    max_q_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
) -> Tensor:
    if q_seq_len is None:
        q_seq_len = max_q_seq_len
    kv_seq_len = q_seq_len
    max_kv_seq_len = max_q_seq_len

    causal = True
    if softmax_scale is None:
        softmax_scale = float(1 / math.sqrt(key.size(-1)))
    _flash_attn_varlen_func(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=q_start_loc,
        cu_seqlens_k=q_start_loc,
        max_seqlen_q=max_q_seq_len,
        max_seqlen_k=max_kv_seq_len,
        softmax_scale=softmax_scale,
        causal=causal,
        out=attn_output,
    )

    return attn_output


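# Copy the freshly computed key/value tensors into the paged KV cache. kv_indices is
# squeezed to a 1-D slot mapping; the trailing "auto", 1.0, 1.0 arguments are assumed
# to mirror vLLM's reshape_and_cache_flash (kv_cache_dtype and unit k/v scales), i.e.
# the cache is written in the original, unquantized dtype.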
@register_ops(vendor_ops_registry)
def fill_kv_cache(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    kv_indices: Tensor,
    k_scales_zeros: Sequence[Optional[Tensor]],
    v_scales_zeros: Sequence[Optional[Tensor]],
    quant_bits: int,
) -> Tuple[Tensor, Tensor]:
    kv_indices = kv_indices.squeeze(-1)
    ops.reshape_and_cache_flash(
        key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
    )
    return key_cache, value_cache


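# Single-token decode attention over the paged KV cache. The kv-head count and block
# size are derived from the cache layout rather than taken from the arguments, the
# block table and sequence lengths are cast to int32 for the kernel, and ALiBi slopes
# are not supported. A fresh output tensor is allocated; the attn_output argument is
# not written in place.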
@register_ops(vendor_ops_registry)
def paged_decode_attention(
    query: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_table: Optional[Tensor],
    block_size: int,
    kv_seq_len: Tensor,
    max_kv_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    if alibi_slopes is not None:
        raise RuntimeError("paged_decode_attention does not support alibi_slopes yet")

    dim = query.size(-1)
    num_kv_heads = value_cache.size(1)
    block_size = value_cache.size(2)
    batch_size = block_table.size(0)

    if softmax_scale is None:
        softmax_scale = float(1 / math.sqrt(query.size(-1)))

    block_table = block_table.to(torch.int32)
    kv_seq_len = kv_seq_len.to(torch.int32)

    output = torch.empty_like(query)

    ix_func.vllm_paged_attention(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        softmax_scale,
        block_table,
        kv_seq_len.cpu(),
        kv_seq_len,
        block_size,
        max_kv_seq_len,
        None,
        False,
        need_view=False,
    )
    return output


@register_ops(vendor_ops_registry)
def paged_prefill_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_table: Tensor,
    block_size: int,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    kv_seq_len: Tensor,
    cu_seq_lens_kv: Tensor,
    max_q_seq_len: int,
    max_kv_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    raise NotImplementedError("Not implemented on ix.")


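# Plain RMSNorm. The input and weight are upcast to float32 for the kernel and the
# result is cast back to the original dtype.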
@register_ops(vendor_ops_registry)
def rms_norm(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
) -> Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    weight = weight.to(torch.float32)
    output = torch.empty_like(hidden_states)

    ops.rms_norm(hidden_states, weight, epsilon, output)

    return output.to(input_dtype)


@register_ops(vendor_ops_registry)
def moe_gating_topk_softmax(
    router_logits: Tensor, topk: int, renormalize: bool = False
) -> Tuple[Tensor, Tensor]:
    raise NotImplementedError("Not implemented on ix.")


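# Gated activation: the output has half the input's last dimension. Following the
# usual silu_and_mul convention, the result is assumed to be
# silu(x[..., :d]) * x[..., d:] with d = x.shape[-1] // 2.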
@register_ops(vendor_ops_registry)
def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)

    ops.silu_and_mul(x, out)
    return out


@register_ops(vendor_ops_registry)
def fused_moe(
    hidden_states: Tensor,
    gate_up_weights: Tensor,
    down_weights: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    top_k: int,
    renormalize: bool,
) -> Tensor:
    raise NotImplementedError("Not implemented on ix.")


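# Linear projection with an optional tensor-parallel all-reduce. When the
# DLINER_LINEAR_USE_NN_LAYOUT environment variable is set to "1", the weight is
# expected to be stored as (in_features, out_features) and a plain matmul is used;
# otherwise torch.nn.functional.linear expects the usual (out_features, in_features)
# layout.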
@register_ops(vendor_ops_registry)
def linear(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    all_reduce: Optional[bool],
    group: Optional[str],
) -> Tensor:
    if os.getenv("DLINER_LINEAR_USE_NN_LAYOUT", "0") == "1":
        out = torch.matmul(x, weight)
        if bias is not None:
            out += bias
    else:
        out = torch.nn.functional.linear(x, weight, bias)
    if all_reduce:
        dist.all_reduce(out)
    return out


# Only W4A16 weight quantization is currently covered and tested by this interface;
# it is not implemented for the ix backend yet.
@register_ops(vendor_ops_registry)
def weight_quant_matmul(
    x: Tensor,
    qweight: Tensor,
    scale: Tensor,
    offset: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    all_reduce: Optional[bool] = False,
    group_size: Optional[int] = 0,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def dynamic_quant(
    x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN"
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def linear_w8a8(
    a: Tensor,
    b: Tensor,
    rms_scale: float,
    linear_scale: float,
    out_dtype: torch.dtype,
    quant_dtype: torch.dtype = torch.int8,
    bias: Tensor = None,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def rms_norm_w8a8(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def add_rms_norm_w8a8(
    hidden_states: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
):
    raise NotImplementedError("Not implemented on ix.")