|
| 1 | +import os |
| 2 | +import math |
| 3 | +import torch |
| 4 | +import lmdeploy.pytorch.distributed as dist |
| 5 | + |
| 6 | +from dlinfer.vendor import vendor_ops_registry |
| 7 | +from dlinfer.utils.registry import register_ops |
| 8 | +from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple |
| 9 | + |
| 10 | +import ixformer.inference.functions as ops |
| 11 | +import ixformer.functions as ix_func |
| 12 | +from ixformer.contrib.vllm_flash_attn import ( |
| 13 | + flash_attn_varlen_func as _flash_attn_varlen_func, |
| 14 | +) |
| 15 | +from ixformer.contrib.vllm_flash_attn import ( |
| 16 | + flash_attn_with_kvcache as _flash_attn_with_kvcache, |
| 17 | +) |
| 18 | + |
| 19 | +__all__ = [ |
| 20 | + "add_rms_norm", |
| 21 | + "apply_rotary_pos_emb", |
| 22 | + "prefill_attention", |
| 23 | + "fused_moe", |
| 24 | + "fill_kv_cache", |
| 25 | + "paged_decode_attention", |
| 26 | + "paged_prefill_attention", |
| 27 | + "rms_norm", |
| 28 | + "silu_and_mul", |
| 29 | + "moe_gating_topk_softmax", |
| 30 | + "linear", |
| 31 | + "weight_quant_matmul", |
| 32 | + "dynamic_quant", |
| 33 | + "linear_w8a8", |
| 34 | + "rms_norm_w8a8", |
| 35 | + "add_rms_norm_w8a8", |
| 36 | +] |
| 37 | + |
| 38 | + |
| 39 | +@register_ops(vendor_ops_registry) |
| 40 | +def add_rms_norm( |
| 41 | + hidden_states: Tensor, |
| 42 | + residual: Tensor, |
| 43 | + weight: Tensor, |
| 44 | + epsilon: float, |
| 45 | +) -> Tuple[Tensor, Tensor]: |
| 46 | + return ix_func.residual_rms_norm( |
| 47 | + input=hidden_states, |
| 48 | + residual=residual, |
| 49 | + weight=weight, |
| 50 | + eps=epsilon, |
| 51 | + residual_alpha=1, |
| 52 | + ) |
| 53 | + |
| 54 | + |
| 55 | +@register_ops(vendor_ops_registry) |
| 56 | +def apply_rotary_pos_emb( |
| 57 | + query: Tensor, |
| 58 | + key: Tensor, |
| 59 | + cos: Optional[Tensor], |
| 60 | + sin: Optional[Tensor], |
| 61 | +) -> Tuple[Tensor, Tensor]: |
| 62 | + query = query.contiguous().unsqueeze(0) |
| 63 | + key = key.contiguous().unsqueeze(0) |
| 64 | + position_ids_1d = torch.arange(0, query.size(1), device=query.device) |
| 65 | + query = query.flatten(-2, -1) |
| 66 | + key = key.flatten(-2, -1) |
| 67 | + cos = cos[..., : cos.shape[-1] // 2] |
| 68 | + sin = sin[..., : sin.shape[-1] // 2 :] |
| 69 | + cos_sin_cache = torch.cat((cos, sin), dim=-1) |
| 70 | + |
| 71 | + ops.vllm_rotary_embedding( |
| 72 | + position_ids_1d, query, key, cos_sin_cache.size(-1), cos_sin_cache, True |
| 73 | + ) |
| 74 | + return query, key |
| 75 | + |
| 76 | + |
| 77 | +@register_ops(vendor_ops_registry) |
| 78 | +def prefill_attention( |
| 79 | + query: Tensor, |
| 80 | + key: Tensor, |
| 81 | + value: Tensor, |
| 82 | + q_start_loc: Tensor, |
| 83 | + q_seq_len: Tensor, |
| 84 | + max_q_seq_len: int, |
| 85 | + num_q_heads: int, |
| 86 | + num_kv_heads: int, |
| 87 | + attn_mask: Sequence[Optional[Tensor]], |
| 88 | + softmax_scale: Optional[float], |
| 89 | + alibi_slopes: Optional[Sequence[float]], |
| 90 | + attn_output: Optional[Tensor], |
| 91 | +) -> Tensor: |
| 92 | + |
| 93 | + if q_seq_len is None: |
| 94 | + q_seq_len = max_q_seq_len |
| 95 | + kv_seq_len = q_seq_len |
| 96 | + max_kv_seq_len = max_q_seq_len |
| 97 | + |
| 98 | + causal = True |
| 99 | + if softmax_scale is None: |
| 100 | + softmax_scale = float(1 / math.sqrt(key.size(-1))) |
| 101 | + _flash_attn_varlen_func( |
| 102 | + q=query, |
| 103 | + k=key, |
| 104 | + v=value, |
| 105 | + cu_seqlens_q=q_start_loc, |
| 106 | + cu_seqlens_k=q_start_loc, |
| 107 | + max_seqlen_q=max_q_seq_len, |
| 108 | + max_seqlen_k=max_kv_seq_len, |
| 109 | + softmax_scale=softmax_scale, |
| 110 | + causal=causal, |
| 111 | + out=attn_output, |
| 112 | + ) |
| 113 | + |
| 114 | + return attn_output |
| 115 | + |
| 116 | + |
| 117 | +@register_ops(vendor_ops_registry) |
| 118 | +def fill_kv_cache( |
| 119 | + key: Tensor, |
| 120 | + value: Tensor, |
| 121 | + key_cache: Tensor, |
| 122 | + value_cache: Tensor, |
| 123 | + kv_indices: Tensor, |
| 124 | + k_scales_zeros: Sequence[Optional[Tensor]], |
| 125 | + v_scales_zeros: Sequence[Optional[Tensor]], |
| 126 | + quant_bits: int, |
| 127 | +) -> Tuple[Tensor, Tensor]: |
| 128 | + kv_indices = kv_indices.squeeze(-1) |
| 129 | + ops.reshape_and_cache_flash( |
| 130 | + key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0 |
| 131 | + ) |
| 132 | + return key_cache, value_cache |
| 133 | + |
| 134 | + |
| 135 | +@register_ops(vendor_ops_registry) |
| 136 | +def paged_decode_attention( |
| 137 | + query: Tensor, |
| 138 | + key_cache: Tensor, |
| 139 | + value_cache: Tensor, |
| 140 | + block_table: Optional[Tensor], |
| 141 | + block_size: int, |
| 142 | + kv_seq_len: Tensor, |
| 143 | + max_kv_seq_len: int, |
| 144 | + num_q_heads: int, |
| 145 | + num_kv_heads: int, |
| 146 | + softmax_scale: Optional[float], |
| 147 | + alibi_slopes: Optional[Sequence[float]], |
| 148 | + attn_output: Optional[Tensor], |
| 149 | + kv_scales: Optional[Tensor], |
| 150 | + kv_zeros: Optional[Tensor], |
| 151 | + quant_bits: Optional[int], |
| 152 | +) -> Tensor: |
| 153 | + if alibi_slopes is not None: |
| 154 | + raise RuntimeError("paged_decode_attention does not support alibi_slopes yet") |
| 155 | + |
| 156 | + dim = query.size(-1) |
| 157 | + num_kv_heads = value_cache.size(1) |
| 158 | + block_size = value_cache.size(2) |
| 159 | + batch_size = block_table.size(0) |
| 160 | + |
| 161 | + if softmax_scale is None: |
| 162 | + softmax_scale = float(1 / math.sqrt(query.size(-1))) |
| 163 | + |
| 164 | + block_table = block_table.to(torch.int32) |
| 165 | + kv_seq_len = kv_seq_len.to(torch.int32) |
| 166 | + |
| 167 | + output = torch.empty_like(query) |
| 168 | + |
| 169 | + ix_func.vllm_paged_attention( |
| 170 | + output, |
| 171 | + query, |
| 172 | + key_cache, |
| 173 | + value_cache, |
| 174 | + num_kv_heads, |
| 175 | + softmax_scale, |
| 176 | + block_table, |
| 177 | + kv_seq_len.cpu(), |
| 178 | + kv_seq_len, |
| 179 | + block_size, |
| 180 | + max_kv_seq_len, |
| 181 | + None, |
| 182 | + False, |
| 183 | + need_view=False, |
| 184 | + ) |
| 185 | + return output |
| 186 | + |
| 187 | + |
| 188 | +@register_ops(vendor_ops_registry) |
| 189 | +def paged_prefill_attention( |
| 190 | + query: Tensor, |
| 191 | + key: Tensor, |
| 192 | + value: Tensor, |
| 193 | + key_cache: Tensor, |
| 194 | + value_cache: Tensor, |
| 195 | + block_table: Tensor, |
| 196 | + block_size: int, |
| 197 | + q_start_loc: Tensor, |
| 198 | + q_seq_len: Tensor, |
| 199 | + kv_seq_len: Tensor, |
| 200 | + cu_seq_lens_kv: Tensor, |
| 201 | + max_q_seq_len: int, |
| 202 | + max_kv_seq_len: int, |
| 203 | + num_q_heads: int, |
| 204 | + num_kv_heads: int, |
| 205 | + attn_mask: Sequence[Optional[Tensor]], |
| 206 | + softmax_scale: Optional[float], |
| 207 | + alibi_slopes: Optional[Sequence[float]], |
| 208 | + attn_output: Optional[Tensor], |
| 209 | + kv_scales: Optional[Tensor], |
| 210 | + kv_zeros: Optional[Tensor], |
| 211 | + quant_bits: Optional[int], |
| 212 | +) -> Tensor: |
| 213 | + raise NotImplementedError("Not implemented on ix.") |
| 214 | + |
| 215 | + |
| 216 | +@register_ops(vendor_ops_registry) |
| 217 | +def rms_norm( |
| 218 | + hidden_states: Tensor, |
| 219 | + weight: Tensor, |
| 220 | + epsilon: float, |
| 221 | +) -> Tensor: |
| 222 | + input_dtype = hidden_states.dtype |
| 223 | + hidden_states = hidden_states.to(torch.float32) |
| 224 | + weight = weight.to(torch.float32) |
| 225 | + output = torch.empty_like(hidden_states) |
| 226 | + |
| 227 | + ops.rms_norm(hidden_states, weight, epsilon, output) |
| 228 | + |
| 229 | + return output.to(input_dtype) |
| 230 | + |
| 231 | + |
| 232 | +@register_ops(vendor_ops_registry) |
| 233 | +def moe_gating_topk_softmax( |
| 234 | + router_logits: Tensor, topk: int, renormalize: bool = False |
| 235 | +) -> Tuple[Tensor, Tensor]: |
| 236 | + raise NotImplementedError("Not implemented on ix.") |
| 237 | + |
| 238 | + |
| 239 | +@register_ops(vendor_ops_registry) |
| 240 | +def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor: |
| 241 | + d = x.shape[-1] // 2 |
| 242 | + output_shape = x.shape[:-1] + (d,) |
| 243 | + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) |
| 244 | + |
| 245 | + ops.silu_and_mul(x, out) |
| 246 | + return out |
| 247 | + |
| 248 | + |
| 249 | +@register_ops(vendor_ops_registry) |
| 250 | +def fused_moe( |
| 251 | + hidden_states: Tensor, |
| 252 | + gate_up_weights: Tensor, |
| 253 | + down_weights: Tensor, |
| 254 | + topk_weights: Tensor, |
| 255 | + topk_ids: Tensor, |
| 256 | + top_k: int, |
| 257 | + renormalize: bool, |
| 258 | +) -> Tensor: |
| 259 | + raise NotImplementedError("Not implemented on ix.") |
| 260 | + |
| 261 | + |
| 262 | +@register_ops(vendor_ops_registry) |
| 263 | +def linear( |
| 264 | + x: Tensor, |
| 265 | + weight: Tensor, |
| 266 | + bias: Optional[Tensor], |
| 267 | + all_reduce: Optional[bool], |
| 268 | + group: Optional[str], |
| 269 | +) -> Tensor: |
| 270 | + if os.getenv("DLINER_LINEAR_USE_NN_LAYOUT", "0") == "1": |
| 271 | + out = torch.matmul(x, weight) |
| 272 | + if bias is not None: |
| 273 | + out += bias |
| 274 | + else: |
| 275 | + out = torch.nn.functional.linear(x, weight, bias) |
| 276 | + if all_reduce: |
| 277 | + dist.all_reduce(out) |
| 278 | + return out |
| 279 | + |
| 280 | + |
| 281 | +# Quantification of W4A16 is currently supported and tested. |
| 282 | +@register_ops(vendor_ops_registry) |
| 283 | +def weight_quant_matmul( |
| 284 | + x: Tensor, |
| 285 | + qweight: Tensor, |
| 286 | + scale: Tensor, |
| 287 | + offset: Optional[Tensor] = None, |
| 288 | + bias: Optional[Tensor] = None, |
| 289 | + all_reduce: Optional[bool] = False, |
| 290 | + group_size: Optional[int] = 0, |
| 291 | +): |
| 292 | + raise NotImplementedError("Not implemented on ix.") |
| 293 | + |
| 294 | + |
| 295 | +@register_ops(vendor_ops_registry) |
| 296 | +def dynamic_quant( |
| 297 | + x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN" |
| 298 | +): |
| 299 | + raise NotImplementedError("Not implemented on ix.") |
| 300 | + |
| 301 | + |
| 302 | +@register_ops(vendor_ops_registry) |
| 303 | +def linear_w8a8( |
| 304 | + a: Tensor, |
| 305 | + b: Tensor, |
| 306 | + rms_scale: float, |
| 307 | + linear_scale: float, |
| 308 | + out_dtype: torch.dtype, |
| 309 | + quant_dtype: torch.dtype = torch.int8, |
| 310 | + bias: Tensor = None, |
| 311 | +): |
| 312 | + raise NotImplementedError("Not implemented on ix.") |
| 313 | + |
| 314 | + |
| 315 | +@register_ops(vendor_ops_registry) |
| 316 | +def rms_norm_w8a8( |
| 317 | + hidden_states: Tensor, |
| 318 | + weight: Tensor, |
| 319 | + epsilon: float, |
| 320 | + quant_dtype: torch.dtype = torch.int8, |
| 321 | +): |
| 322 | + raise NotImplementedError("Not implemented on ix.") |
| 323 | + |
| 324 | + |
| 325 | +@register_ops(vendor_ops_registry) |
| 326 | +def add_rms_norm_w8a8( |
| 327 | + hidden_states: Tensor, |
| 328 | + residual: Tensor, |
| 329 | + weight: Tensor, |
| 330 | + epsilon: float, |
| 331 | + quant_dtype: torch.dtype = torch.int8, |
| 332 | +): |
| 333 | + raise NotImplementedError("Not implemented on ix.") |
0 commit comments