diff --git a/dlinfer/vendor/ascend/pytorch_patch.py b/dlinfer/vendor/ascend/pytorch_patch.py
index e5dc164a..9c9cd9eb 100644
--- a/dlinfer/vendor/ascend/pytorch_patch.py
+++ b/dlinfer/vendor/ascend/pytorch_patch.py
@@ -1,4 +1,7 @@
 # Copyright (c) 2024, DeepLink. All rights reserved.
+import os
+import pathlib
+from functools import lru_cache
 from packaging import version
 
 import torch
@@ -6,6 +9,7 @@
 origin_torch_compile = torch.compile
 
 from torch_npu.contrib import transfer_to_npu
+from torch_npu.utils._path_manager import PathManager
 
 torch.compile = origin_torch_compile
 
@@ -20,3 +24,41 @@ def has_triton():
     return False
 
 setattr(target_module, func_str, has_triton)
+
+
+# This may be used in aclgraph in the future
+# def set_atb_ops():
+#     from torch_npu.op_plugin import atb
+#     for api_name in atb._atb_ops.API_LIST:
+#         func = getattr(torch_npu, api_name)
+#         setattr(torch.ops.npu, api_name, func)
+#     _patch_atb_and_loadso()
+
+
+@lru_cache(None)
+def register_atb_extensions():
+    npu_path = pathlib.Path(torch_npu.__file__).parents[0]
+    atb_so_path = os.path.join(npu_path, "lib", "libop_plugin_atb.so")
+    try:
+        PathManager.check_directory_path_readable(atb_so_path)
+        torch.ops.load_library(atb_so_path)
+    except OSError as e:
+        nnal_ex = None
+        nnal_strerror = ""
+        if "libatb.so" in str(e):
+            nnal_strerror = (
+                "Please check that the nnal package is installed. "
+                "Please run 'source set_env.sh' in the NNAL installation path."
+            )
+        if "undefined symbol" in str(e):
+            nnal_strerror = (
+                "Please check the version of the NNAL package. "
+                "An undefined symbol was found, "
+                "which may be caused by a version mismatch between NNAL and torch_npu."
+            )
+        nnal_ex = OSError(e.errno, nnal_strerror)
+        nnal_ex.__traceback__ = e.__traceback__
+        raise nnal_ex from e
+
+
+register_atb_extensions()
diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py
index b9035d86..c39e1789 100644
--- a/dlinfer/vendor/ascend/torch_npu_ops.py
+++ b/dlinfer/vendor/ascend/torch_npu_ops.py
@@ -105,7 +105,7 @@ def prefill_attention(
         if (attn_mask is None or len(attn_mask) == 0)
         else attn_mask[0].to(torch.bool)
     )
-    attn_output.view(query.shape)[:] = torch.ops.npu.npu_fusion_attention(
+    attn_output[:] = torch.ops.npu.npu_fusion_attention(
         query,
         key,
         value,
@@ -168,10 +168,15 @@ def quant_int8(x, x_scale, x_offset):
         key = quant_int8(key, k_scales_zeros[0], k_scales_zeros[1])
         value = quant_int8(value, v_scales_zeros[0], v_scales_zeros[1])
 
-    key_cache_reshaped = key_cache.view(block_total, head, dim)
-    value_cache_reshaped = value_cache.view(block_total, head, dim)
-    torch.ops.npu.npu_scatter_nd_update_(key_cache_reshaped, kv_indices, key)
-    torch.ops.npu.npu_scatter_nd_update_(value_cache_reshaped, kv_indices, value)
+    is_mla = key.shape[-1] != value.shape[-1]
+    if is_mla:
+        key_cache_reshaped = key_cache.view(block_total, head, dim)
+        torch.ops.npu.npu_scatter_nd_update_(key_cache_reshaped, kv_indices, key)
+    else:
+        key_cache_reshaped = key_cache.view(block_total, head, dim)
+        value_cache_reshaped = value_cache.view(block_total, head, dim)
+        torch.ops.npu.npu_scatter_nd_update_(key_cache_reshaped, kv_indices, key)
+        torch.ops.npu.npu_scatter_nd_update_(value_cache_reshaped, kv_indices, value)
     return key_cache, value_cache
@@ -214,38 +219,53 @@ def paged_decode_attention(
     if isinstance(block_table, torch.Tensor) and block_table.dtype != torch.int32:
         block_table = block_table.to(torch.int32)
 
-    bs, _, dim = query.shape
-    block_num = key_cache.size(0)
+    is_mla = key_cache.shape[-1] != value_cache.shape[-1]
     query = query.contiguous()
     attn_output = attn_output.contiguous()
-    query = query.view(bs, 1, num_q_heads * dim)
-    key_cache = key_cache.view(block_num, block_size, -1)
-    value_cache = value_cache.view(block_num, block_size, -1)
-    scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(dim)
-    torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
-        query,
-        key_cache,
-        value_cache,
-        attn_output.view_as(query),
-        padding_mask=None,
-        atten_mask=None,
-        actual_seq_lengths=kv_seq_len.tolist(),
-        antiquant_scale=kv_scales,
-        antiquant_offset=kv_zeros,
-        block_table=block_table,
-        dequant_scale1=None,
-        quant_scale1=None,
-        dequant_scale2=None,
-        quant_scale2=None,
-        quant_offset2=None,
-        num_heads=num_q_heads,
-        scale_value=scale_value,
-        input_layout="BSH",
-        num_key_value_heads=num_kv_heads,
-        block_size=block_size,
-        inner_precise=1,
-    )
+    scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
+    if is_mla:
+        v_head_size = value_cache.shape[-1]
+        torch.ops.atb._npu_paged_attention_mla(
+            query=query,
+            key_cache=key_cache,
+            num_kv_heads=num_kv_heads,
+            num_heads=num_q_heads,
+            scale_value=softmax_scale,
+            block_table=block_table,
+            context_lens=kv_seq_len,
+            mla_vheadsize=v_head_size,
+            out=attn_output,
+        )
+    else:
+        bs, _, dim = query.shape
+        block_num = key_cache.size(0)
+        query = query.view(bs, 1, num_q_heads * dim)
+        key_cache = key_cache.view(block_num, block_size, -1)
+        value_cache = value_cache.view(block_num, block_size, -1)
+        torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
+            query,
+            key_cache,
+            value_cache,
+            attn_output.view_as(query),
+            padding_mask=None,
+            atten_mask=None,
+            actual_seq_lengths=kv_seq_len.tolist(),
+            antiquant_scale=kv_scales,
+            antiquant_offset=kv_zeros,
+            block_table=block_table,
+            dequant_scale1=None,
+            quant_scale1=None,
+            dequant_scale2=None,
+            quant_scale2=None,
+            quant_offset2=None,
+            num_heads=num_q_heads,
+            scale_value=scale_value,
+            input_layout="BSH",
+            num_key_value_heads=num_kv_heads,
+            block_size=block_size,
+            inner_precise=1,
+        )
     return attn_output
@@ -282,35 +302,53 @@ def paged_prefill_attention(
     if block_table.dtype != torch.int32:
         block_table = block_table.to(torch.int32)
 
-    kv_seq_len_list = kv_seq_len.tolist()
-    scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
-    query = query.contiguous().view(query.shape[0], 1, -1)
-    block_num = key_cache.size(0)
-    key_cache = key_cache.view(block_num, block_size, -1)
-    value_cache = value_cache.view(block_num, block_size, -1)
-    torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
-        query,
-        key_cache,
-        value_cache,
-        attn_output,
-        padding_mask=None,
-        atten_mask=attn_mask[0],
-        actual_seq_lengths=kv_seq_len_list,
-        antiquant_scale=kv_scales,
-        antiquant_offset=kv_zeros,
-        block_table=block_table,
-        dequant_scale1=None,
-        quant_scale1=None,
-        dequant_scale2=None,
-        quant_scale2=None,
-        quant_offset2=None,
-        num_heads=num_q_heads,
-        scale_value=scale_value,
-        input_layout="BSH",
-        num_key_value_heads=num_kv_heads,
-        block_size=block_size,
-        inner_precise=1,
-    )
+    is_mla = key_cache.shape[-1] != value_cache.shape[-1]
+    query = query.contiguous()
+    attn_output = attn_output.contiguous()
+    scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
+
+    if is_mla:
+        v_head_size = value_cache.shape[-1]
+        torch.ops.atb._npu_paged_attention_mla(
+            query=query,
+            key_cache=key_cache,
+            num_kv_heads=num_kv_heads,
+            num_heads=num_q_heads,
+            scale_value=softmax_scale,
+            block_table=block_table,
+            context_lens=kv_seq_len,
+            mla_vheadsize=v_head_size,
+            out=attn_output,
+        )
+    else:
+        bs, _, dim = query.shape
+        block_num = key_cache.size(0)
+        query = query.view(bs, 1, num_q_heads * dim)
+        key_cache = key_cache.view(block_num, block_size, -1)
+        value_cache = value_cache.view(block_num, block_size, -1)
+        torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
+            query,
+            key_cache,
+            value_cache,
+            attn_output,
+            padding_mask=None,
+            atten_mask=attn_mask[0],
+            actual_seq_lengths=kv_seq_len.tolist(),
+            antiquant_scale=kv_scales,
+            antiquant_offset=kv_zeros,
+            block_table=block_table,
+            dequant_scale1=None,
+            quant_scale1=None,
+            dequant_scale2=None,
+            quant_scale2=None,
+            quant_offset2=None,
+            num_heads=num_q_heads,
+            scale_value=scale_value,
+            input_layout="BSH",
+            num_key_value_heads=num_kv_heads,
+            block_size=block_size,
+            inner_precise=1,
+        )
     return attn_output