Skip to content

Commit 9f38236

Browse files
committed
[ix] support ix device
1 parent 64eed26 commit 9f38236

File tree

9 files changed

+352
-1
lines changed

9 files changed

+352
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ endif()
1616
set(DEVICE "" CACHE STRING "device string, default empty string")
1717
string(TOLOWER "${DEVICE}" DEVICE)
1818

19-
list(APPEND SUPPORTED_DEVICE "ascend" "maca" "camb")
19+
list(APPEND SUPPORTED_DEVICE "ascend" "maca" "camb" "ix")
2020

2121
if(NOT DEVICE)
2222
message(FATAL_ERROR "Please specify variable DEVICE of dlinfer!")

dlinfer/vendor/ix/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Empty install target for ix device
2+
install(TARGETS)

dlinfer/vendor/ix/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .ix_ops import *
2+
3+
# Device-type string exported by the ix vendor package.
# NOTE(review): presumably ix devices are addressed through torch's "cuda"
# device type (setup.py maps "ix" to the "CUDA" dispatch key) — confirm.
device_str = "cuda"

dlinfer/vendor/ix/ix_ops.py

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
import os
2+
import math
3+
import torch
4+
import lmdeploy.pytorch.distributed as dist
5+
6+
from dlinfer.vendor import vendor_ops_registry
7+
from dlinfer.utils.registry import register_ops
8+
from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple
9+
10+
import ixformer.inference.functions as ops
11+
import ixformer.functions as ix_func
12+
from ixformer.contrib.vllm_flash_attn import (
13+
flash_attn_varlen_func as _flash_attn_varlen_func,
14+
)
15+
from ixformer.contrib.vllm_flash_attn import (
16+
flash_attn_with_kvcache as _flash_attn_with_kvcache,
17+
)
18+
19+
# Public op names exported by this module; every entry is defined below
# and registered through ``register_ops``.
__all__ = [
    # Normalization / rotary / attention / MoE entry points.
    "add_rms_norm",
    "apply_rotary_pos_emb",
    "prefill_attention",
    "fused_moe",
    "fill_kv_cache",
    "paged_decode_attention",
    "paged_prefill_attention",
    "rms_norm",
    "silu_and_mul",
    "moe_gating_topk_softmax",
    # Dense and quantized linear ops.
    "linear",
    "weight_quant_matmul",
    "dynamic_quant",
    "linear_w8a8",
    "rms_norm_w8a8",
    "add_rms_norm_w8a8",
]
37+
38+
39+
@register_ops(vendor_ops_registry)
def add_rms_norm(
    hidden_states: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
) -> Tuple[Tensor, Tensor]:
    """Delegate residual RMS normalization to ixformer.

    All arguments are forwarded to ``ix_func.residual_rms_norm`` with
    ``residual_alpha`` fixed to 1; the kernel's return value is passed
    through unchanged.

    Args:
        hidden_states: Forwarded as ``input``.
        residual: Forwarded as ``residual``.
        weight: Forwarded as ``weight``.
        epsilon: Forwarded as ``eps``.

    Returns:
        Whatever ``ix_func.residual_rms_norm`` returns (annotated as a pair
        of tensors — presumably the normalized output and the updated
        residual; confirm against ixformer).
    """
    kernel_kwargs = {
        "input": hidden_states,
        "residual": residual,
        "weight": weight,
        "eps": epsilon,
        "residual_alpha": 1,
    }
    return ix_func.residual_rms_norm(**kernel_kwargs)
53+
54+
55+
@register_ops(vendor_ops_registry)
def apply_rotary_pos_emb(
    query: Tensor,
    key: Tensor,
    cos: Optional[Tensor],
    sin: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    """Apply rotary position embedding through ``ops.vllm_rotary_embedding``.

    ``query`` and ``key`` are made contiguous, given a leading axis of size
    one, and their last two axes are merged before the kernel call. The
    first half of the last axis of ``cos`` and ``sin`` is concatenated into
    the cache tensor handed to the kernel.

    Args:
        query: Query tensor.
        key: Key tensor.
        cos: Cosine table; indexed unconditionally, so ``None`` is not
            actually supported despite the ``Optional`` annotation.
        sin: Sine table; same remark as ``cos``.

    Returns:
        The reshaped ``(query, key)`` tensors that were passed to the kernel
        (presumably updated in place by it — confirm against ixformer).
    """
    batched_q = query.contiguous().unsqueeze(0)
    batched_k = key.contiguous().unsqueeze(0)

    # One position id per entry along axis 1 of the batched query.
    positions = torch.arange(0, batched_q.size(1), device=batched_q.device)

    flat_q = batched_q.flatten(-2, -1)
    flat_k = batched_k.flatten(-2, -1)

    # Only the first half of the last axis of each table is used.
    half_cos = cos[..., : cos.shape[-1] // 2]
    half_sin = sin[..., : sin.shape[-1] // 2]
    cos_sin_cache = torch.cat((half_cos, half_sin), dim=-1)
    rotary_dim = cos_sin_cache.size(-1)

    # NOTE(review): the trailing True is presumably the neox-style flag.
    ops.vllm_rotary_embedding(
        positions, flat_q, flat_k, rotary_dim, cos_sin_cache, True
    )
    return flat_q, flat_k
75+
76+
77+
@register_ops(vendor_ops_registry)
def prefill_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    max_q_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
) -> Tensor:
    """Run causal variable-length flash attention for the prefill stage.

    Keys/values share the query layout: ``q_start_loc`` is used as both
    ``cu_seqlens_q`` and ``cu_seqlens_k`` and ``max_q_seq_len`` as both
    maximum lengths.

    Args:
        query: Query tensor.
        key: Key tensor; its last dimension sets the default scale.
        value: Value tensor.
        q_start_loc: Passed as the cumulative sequence offsets for both
            queries and keys.
        q_seq_len: Unused by this implementation.
        max_q_seq_len: Maximum sequence length (queries and keys).
        num_q_heads: Unused by this implementation.
        num_kv_heads: Unused by this implementation.
        attn_mask: Unused; attention is always causal.
        softmax_scale: Scale applied before softmax; defaults to
            ``1 / sqrt(key.size(-1))`` when ``None``.
        alibi_slopes: Unused by this implementation.
            NOTE(review): silently ignored here, whereas
            ``paged_decode_attention`` rejects it — confirm intent.
        attn_output: Optional pre-allocated buffer passed as ``out``.

    Returns:
        ``attn_output`` when a buffer was supplied, otherwise the tensor
        returned by the kernel.
    """
    if softmax_scale is None:
        softmax_scale = float(1 / math.sqrt(key.size(-1)))

    kernel_out = _flash_attn_varlen_func(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=q_start_loc,
        cu_seqlens_k=q_start_loc,
        max_seqlen_q=max_q_seq_len,
        max_seqlen_k=max_q_seq_len,
        softmax_scale=softmax_scale,
        causal=True,
        out=attn_output,
    )

    # attn_output is Optional: without a caller-provided buffer the result
    # only exists as the kernel's return value, so hand that back instead
    # of returning None.
    if attn_output is None:
        return kernel_out
    return attn_output
115+
116+
117+
@register_ops(vendor_ops_registry)
def fill_kv_cache(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    kv_indices: Tensor,
    k_scales_zeros: Sequence[Optional[Tensor]],
    v_scales_zeros: Sequence[Optional[Tensor]],
    quant_bits: int,
) -> Tuple[Tensor, Tensor]:
    """Write ``key``/``value`` into the caches via ``reshape_and_cache_flash``.

    Args:
        key: New key states.
        value: New value states.
        key_cache: Key cache handed to the kernel and returned.
        value_cache: Value cache handed to the kernel and returned.
        kv_indices: Cache slot indices; the trailing axis is squeezed away
            before the kernel call.
        k_scales_zeros: Unused by this implementation.
        v_scales_zeros: Unused by this implementation.
        quant_bits: Unused by this implementation.

    Returns:
        The ``(key_cache, value_cache)`` pair that was passed in.
    """
    slot_mapping = kv_indices.squeeze(-1)
    # "auto", 1.0, 1.0 are forwarded as-is (presumably the kv-cache dtype
    # and k/v scales — confirm against ixformer).
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        "auto",
        1.0,
        1.0,
    )
    return key_cache, value_cache
133+
134+
135+
@register_ops(vendor_ops_registry)
def paged_decode_attention(
    query: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_table: Optional[Tensor],
    block_size: int,
    kv_seq_len: Tensor,
    max_kv_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    """Run paged attention for the decode stage via ixformer.

    Args:
        query: Query tensor; the output has the same shape and dtype.
        key_cache: Paged key cache.
        value_cache: Paged value cache; axes 1 and 2 supply the kv-head
            count and block size actually used.
        block_table: Block table, cast to int32. Must not be ``None``
            despite the ``Optional`` annotation.
        block_size: Ignored; overridden by ``value_cache.size(2)``.
        kv_seq_len: Per-sequence kv lengths, cast to int32.
        max_kv_seq_len: Maximum kv length forwarded to the kernel.
        num_q_heads: Unused by this implementation.
        num_kv_heads: Ignored; overridden by ``value_cache.size(1)``.
        softmax_scale: Defaults to ``1 / sqrt(query.size(-1))`` when ``None``.
        alibi_slopes: Must be ``None``.
        attn_output: Unused; a fresh output tensor is always allocated.
        kv_scales: Unused by this implementation.
        kv_zeros: Unused by this implementation.
        quant_bits: Unused by this implementation.

    Returns:
        A newly allocated tensor shaped like ``query``, passed to the kernel
        as its output buffer.

    Raises:
        RuntimeError: If ``alibi_slopes`` is provided.
    """
    if alibi_slopes is not None:
        raise RuntimeError("paged_decode_attention does not support alibi_slopes yet")

    # The cache layout is authoritative for these two values.
    num_kv_heads = value_cache.size(1)
    block_size = value_cache.size(2)

    if softmax_scale is None:
        softmax_scale = float(1 / math.sqrt(query.size(-1)))

    block_table = block_table.to(torch.int32)
    kv_seq_len = kv_seq_len.to(torch.int32)

    output = torch.empty_like(query)

    # The kernel takes the sequence lengths twice: a host copy and the
    # original tensor.
    ix_func.vllm_paged_attention(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        softmax_scale,
        block_table,
        kv_seq_len.cpu(),
        kv_seq_len,
        block_size,
        max_kv_seq_len,
        None,
        False,
        need_view=False,
    )
    return output
186+
187+
188+
@register_ops(vendor_ops_registry)
def paged_prefill_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_table: Tensor,
    block_size: int,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    kv_seq_len: Tensor,
    cu_seq_lens_kv: Tensor,
    max_q_seq_len: int,
    max_kv_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    """Placeholder for paged prefill attention; not available on ix.

    Raises:
        NotImplementedError: Always.
    """
    message = "Not implemented on ix."
    raise NotImplementedError(message)
214+
215+
216+
@register_ops(vendor_ops_registry)
def rms_norm(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
) -> Tensor:
    """Apply RMS normalization with the ixformer kernel in float32.

    Input and weight are cast to float32 for the kernel call, and the
    result is cast back to the input's original dtype.

    Args:
        hidden_states: Tensor to normalize.
        weight: Normalization weight.
        epsilon: Epsilon forwarded to the kernel.

    Returns:
        The kernel output converted to ``hidden_states``' original dtype.
    """
    original_dtype = hidden_states.dtype
    states_fp32 = hidden_states.to(torch.float32)
    weight_fp32 = weight.to(torch.float32)

    # The kernel writes into a pre-allocated float32 buffer.
    normed = torch.empty_like(states_fp32)
    ops.rms_norm(states_fp32, weight_fp32, epsilon, normed)

    return normed.to(original_dtype)
230+
231+
232+
@register_ops(vendor_ops_registry)
def moe_gating_topk_softmax(
    router_logits: Tensor,
    topk: int,
    renormalize: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Placeholder for MoE top-k softmax gating; not available on ix.

    Raises:
        NotImplementedError: Always.
    """
    message = "Not implemented on ix."
    raise NotImplementedError(message)
237+
238+
239+
@register_ops(vendor_ops_registry)
def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
    """Run the ixformer ``silu_and_mul`` kernel on ``x``.

    Args:
        x: Input tensor; the output's last axis is half the size of
            ``x``'s last axis.
        dim: Unused by this implementation (the output shape is always
            derived from the last axis).

    Returns:
        A new tensor with ``x``'s dtype and device, filled by the kernel.
    """
    half_width = x.shape[-1] // 2
    result_shape = (*x.shape[:-1], half_width)
    result = torch.empty(result_shape, dtype=x.dtype, device=x.device)

    ops.silu_and_mul(x, result)
    return result
247+
248+
249+
@register_ops(vendor_ops_registry)
def fused_moe(
    hidden_states: Tensor,
    gate_up_weights: Tensor,
    down_weights: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    top_k: int,
    renormalize: bool,
) -> Tensor:
    """Placeholder for the fused MoE op; not available on ix.

    Raises:
        NotImplementedError: Always.
    """
    message = "Not implemented on ix."
    raise NotImplementedError(message)
260+
261+
262+
@register_ops(vendor_ops_registry)
def linear(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    all_reduce: Optional[bool],
    group: Optional[str],
) -> Tensor:
    """Compute a linear projection, optionally followed by an all-reduce.

    When the environment variable ``DLINER_LINEAR_USE_NN_LAYOUT`` equals
    ``"1"``, the product is computed as ``x @ weight`` and the bias is added
    in place; otherwise ``torch.nn.functional.linear`` is used.

    Args:
        x: Input activations.
        weight: Weight matrix (layout depends on the environment switch).
        bias: Optional bias.
        all_reduce: When truthy, ``dist.all_reduce`` is applied to the result.
        group: Unused by this implementation.

    Returns:
        The projected (and possibly all-reduced) tensor.
    """
    # The variable name (including its spelling) is a runtime key; keep as-is.
    nn_layout = os.getenv("DLINER_LINEAR_USE_NN_LAYOUT", "0") == "1"

    if not nn_layout:
        out = torch.nn.functional.linear(x, weight, bias)
    else:
        out = torch.matmul(x, weight)
        if bias is not None:
            out += bias

    if all_reduce:
        dist.all_reduce(out)
    return out
279+
280+
281+
# Weight-only quantized matmul (e.g. W4A16) is not implemented on ix yet.
282+
@register_ops(vendor_ops_registry)
def weight_quant_matmul(
    x: Tensor,
    qweight: Tensor,
    scale: Tensor,
    offset: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    all_reduce: Optional[bool] = False,
    group_size: Optional[int] = 0,
):
    """Placeholder for weight-quantized matmul; not available on ix.

    Raises:
        NotImplementedError: Always.
    """
    message = "Not implemented on ix."
    raise NotImplementedError(message)
293+
294+
295+
@register_ops(vendor_ops_registry)
def dynamic_quant(
    x: Tensor,
    quant_dtype: torch.dtype,
    quant_granularity: str = "PER_TOKEN",
):
    """Placeholder for dynamic quantization; not available on ix.

    Raises:
        NotImplementedError: Always.
    """
    message = "Not implemented on ix."
    raise NotImplementedError(message)
300+
301+
302+
@register_ops(vendor_ops_registry)
def linear_w8a8(
    a: Tensor,
    b: Tensor,
    rms_scale: float,
    linear_scale: float,
    out_dtype: torch.dtype,
    quant_dtype: torch.dtype = torch.int8,
    bias: Tensor = None,
):
    """Placeholder for the W8A8 linear op; not available on ix.

    Raises:
        NotImplementedError: Always.
    """
    message = "Not implemented on ix."
    raise NotImplementedError(message)
313+
314+
315+
@register_ops(vendor_ops_registry)
def rms_norm_w8a8(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
):
    """Placeholder for the W8A8 RMS norm op; not available on ix.

    Raises:
        NotImplementedError: Always.
    """
    message = "Not implemented on ix."
    raise NotImplementedError(message)
323+
324+
325+
@register_ops(vendor_ops_registry)
def add_rms_norm_w8a8(
    hidden_states: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
):
    """Placeholder for the W8A8 residual RMS norm op; not available on ix.

    Raises:
        NotImplementedError: Always.
    """
    message = "Not implemented on ix."
    raise NotImplementedError(message)

requirements/ix/build.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
ninja
2+
setuptools
3+
wheel
4+
scikit-build
5+
cmake>=3.18
6+
-r torch.txt

requirements/ix/full.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
-r build.txt
2+
-r runtime.txt

requirements/ix/runtime.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
transformers
2+
-r torch.txt

requirements/ix/torch.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch
2+
torchvision

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"ascend": "PrivateUse1",
1212
"maca": "CUDA",
1313
"camb": "PrivateUse1",
14+
"ix": "CUDA",
1415
}
1516

1617

0 commit comments

Comments
 (0)