Commit 56d7818

[ix] support ix device

1 parent 64eed26

File tree

9 files changed: +345, -1 lines

CMakeLists.txt
dlinfer/vendor/ix/CMakeLists.txt
dlinfer/vendor/ix/__init__.py
dlinfer/vendor/ix/ix_ops.py
requirements/ix/build.txt
requirements/ix/full.txt
requirements/ix/runtime.txt
requirements/ix/torch.txt
setup.py


CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ endif()
 set(DEVICE "" CACHE STRING "device string, default empty string")
 string(TOLOWER "${DEVICE}" DEVICE)
 
-list(APPEND SUPPORTED_DEVICE "ascend" "maca" "camb")
+list(APPEND SUPPORTED_DEVICE "ascend" "maca" "camb" "ix")
 
 if(NOT DEVICE)
     message(FATAL_ERROR "Please specify variable DEVICE of dlinfer!")

dlinfer/vendor/ix/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Empty install target for ix device
install(TARGETS)

dlinfer/vendor/ix/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .ix_ops import *

device_str = "cuda"
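Note on the lines above (an illustration, not part of the commit): device_str is "cuda" because the ix backend drives its GPUs through torch's CUDA device type, which is also why setup.py below maps "ix" to "CUDA". A minimal sketch of what that means for callers, assuming an ix-enabled torch build where the GPU is visible as a CUDA device:

import torch
from dlinfer.vendor.ix import device_str  # resolves to "cuda"

# Tensors for the ix backend are allocated as ordinary CUDA tensors.
x = torch.empty(4, 8, device=device_str)
assert x.device.type == "cuda"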

dlinfer/vendor/ix/ix_ops.py

Lines changed: 326 additions & 0 deletions
@@ -0,0 +1,326 @@
import os
import math
import vllm
import torch
import lmdeploy.pytorch.distributed as dist

from vllm import _custom_ops as custom_ops
from flash_attn import flash_attn_varlen_func
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.attention.ops.prefix_prefill import context_attention_fwd

from dlinfer.vendor import vendor_ops_registry
from dlinfer.utils.registry import register_ops
from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple

import ixformer.inference.functions as ops
import ixformer.functions as ix_func

from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from ixformer.contrib.vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache

__all__ = [
    "add_rms_norm",
    "apply_rotary_pos_emb",
    "prefill_attention",
    "fused_moe",
    "fill_kv_cache",
    "paged_decode_attention",
    "paged_prefill_attention",
    "rms_norm",
    "silu_and_mul",
    "moe_gating_topk_softmax",
    "linear",
    "weight_quant_matmul",
    "dynamic_quant",
    "linear_w8a8",
    "rms_norm_w8a8",
    "add_rms_norm_w8a8",
]


@register_ops(vendor_ops_registry)
def add_rms_norm(
    hidden_states: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
) -> Tuple[Tensor, Tensor]:
    return ix_func.residual_rms_norm(input=hidden_states, residual=residual, weight=weight, eps=epsilon, residual_alpha=1)


@register_ops(vendor_ops_registry)
def apply_rotary_pos_emb(
    query: Tensor,
    key: Tensor,
    cos: Optional[Tensor],
    sin: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    query = query.contiguous().unsqueeze(0)
    key = key.contiguous().unsqueeze(0)
    position_ids_1d = torch.arange(0, query.size(1), device=query.device)
    query = query.flatten(-2, -1)
    key = key.flatten(-2, -1)
    cos = cos[..., : cos.shape[-1] // 2]
    sin = sin[..., : sin.shape[-1] // 2]
    cos_sin_cache = torch.cat((cos, sin), dim=-1)

    ops.vllm_rotary_embedding(
        position_ids_1d, query, key, cos_sin_cache.size(-1), cos_sin_cache, True
    )
    return query, key

@register_ops(vendor_ops_registry)
def prefill_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    max_q_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
) -> Tensor:

    if q_seq_len is None:
        q_seq_len = max_q_seq_len
    kv_seq_len = q_seq_len
    max_kv_seq_len = max_q_seq_len

    causal = True
    if softmax_scale is None:
        softmax_scale = float(1 / math.sqrt(key.size(-1)))
    _flash_attn_varlen_func(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=q_start_loc,
        cu_seqlens_k=q_start_loc,
        max_seqlen_q=max_q_seq_len,
        max_seqlen_k=max_kv_seq_len,
        softmax_scale=softmax_scale,
        causal=causal,
        out=attn_output,
    )

    return attn_output


@register_ops(vendor_ops_registry)
def fill_kv_cache(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    kv_indices: Tensor,
    k_scales_zeros: Sequence[Optional[Tensor]],
    v_scales_zeros: Sequence[Optional[Tensor]],
    quant_bits: int,
) -> Tuple[Tensor, Tensor]:
    kv_indices = kv_indices.squeeze(-1)
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0)
    return key_cache, value_cache


@register_ops(vendor_ops_registry)
def paged_decode_attention(
    query: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_table: Optional[Tensor],
    block_size: int,
    kv_seq_len: Tensor,
    max_kv_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    if alibi_slopes is not None:
        raise RuntimeError("paged_decode_attention does not support alibi_slopes yet")

    dim = query.size(-1)
    num_kv_heads = value_cache.size(1)
    block_size = value_cache.size(2)
    batch_size = block_table.size(0)

    if softmax_scale is None:
        softmax_scale = float(1 / math.sqrt(query.size(-1)))

    block_table = block_table.to(torch.int32)
    kv_seq_len = kv_seq_len.to(torch.int32)

    output = torch.empty_like(query)

    ix_func.vllm_paged_attention(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        softmax_scale,
        block_table,
        kv_seq_len.cpu(),
        kv_seq_len,
        block_size,
        max_kv_seq_len,
        None,
        False,
        need_view=False,
    )
    return output

@register_ops(vendor_ops_registry)
def paged_prefill_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_table: Tensor,
    block_size: int,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    kv_seq_len: Tensor,
    cu_seq_lens_kv: Tensor,
    max_q_seq_len: int,
    max_kv_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def rms_norm(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
) -> Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    weight = weight.to(torch.float32)
    output = torch.empty_like(hidden_states)

    ops.rms_norm(hidden_states, weight, epsilon, output)

    return output.to(input_dtype)


@register_ops(vendor_ops_registry)
def moe_gating_topk_softmax(
    router_logits: Tensor, topk: int, renormalize: bool = False
) -> Tuple[Tensor, Tensor]:
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)

    ops.silu_and_mul(x, out)
    return out


@register_ops(vendor_ops_registry)
def fused_moe(
    hidden_states: Tensor,
    gate_up_weights: Tensor,
    down_weights: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    top_k: int,
    renormalize: bool,
) -> Tensor:
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def linear(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    all_reduce: Optional[bool],
    group: Optional[str],
) -> Tensor:
    if os.getenv("DLINER_LINEAR_USE_NN_LAYOUT", "0") == "1":
        out = torch.matmul(x, weight)
        if bias is not None:
            out += bias
    else:
        out = torch.nn.functional.linear(x, weight, bias)
    if all_reduce:
        dist.all_reduce(out)
    return out


# W4A16 weight quantization is not implemented on the ix backend yet.
@register_ops(vendor_ops_registry)
def weight_quant_matmul(
    x: Tensor,
    qweight: Tensor,
    scale: Tensor,
    offset: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    all_reduce: Optional[bool] = False,
    group_size: Optional[int] = 0,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def dynamic_quant(
    x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN"
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def linear_w8a8(
    a: Tensor,
    b: Tensor,
    rms_scale: float,
    linear_scale: float,
    out_dtype: torch.dtype,
    quant_dtype: torch.dtype = torch.int8,
    bias: Tensor = None,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def rms_norm_w8a8(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def add_rms_norm_w8a8(
    hidden_states: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
):
    raise NotImplementedError("Not implemented on ix.")
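A quick way to sanity-check the implemented kernels above is to call them directly. The sketch below is hypothetical: it assumes an ix GPU exposed as a CUDA device, the ixformer package installed, and that register_ops returns the decorated function unchanged.

import torch
from dlinfer.vendor.ix.ix_ops import rms_norm, silu_and_mul

hidden = torch.randn(4, 128, dtype=torch.float16, device="cuda")
weight = torch.ones(128, dtype=torch.float16, device="cuda")

normed = rms_norm(hidden, weight, 1e-6)    # computed in fp32, cast back to fp16
gate_up = torch.randn(4, 256, dtype=torch.float16, device="cuda")
activated = silu_and_mul(gate_up)          # last dim halves: (4, 256) -> (4, 128)

The stubs (paged_prefill_attention, fused_moe, moe_gating_topk_softmax, and the quantized ops) simply raise NotImplementedError, so code paths that need them are not usable on ix yet.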

requirements/ix/build.txt

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
ninja
setuptools
wheel
scikit-build
cmake>=3.18
-r torch.txt

requirements/ix/full.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
-r build.txt
-r runtime.txt

requirements/ix/runtime.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
transformers
-r torch.txt

requirements/ix/torch.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
torch
torchvision

setup.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
     "ascend": "PrivateUse1",
     "maca": "CUDA",
     "camb": "PrivateUse1",
+    "ix": "CUDA",
 }
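The new entry ties the ix device to torch's CUDA dispatch key, consistent with device_str = "cuda" in dlinfer/vendor/ix/__init__.py. A minimal sketch (hypothetical helper, not taken from setup.py) of how such a table is typically consulted at build time:

# Hypothetical lookup mirroring the table in setup.py; dlinfer's real build
# logic may differ.
DISPATCH_KEY = {
    "ascend": "PrivateUse1",
    "maca": "CUDA",
    "camb": "PrivateUse1",
    "ix": "CUDA",
}

def dispatch_key_for(device: str) -> str:
    return DISPATCH_KEY[device.lower()]

assert dispatch_key_for("ix") == "CUDA"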
