Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions dlinfer/vendor/ascend/pytorch_patch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# Copyright (c) 2024, DeepLink. All rights reserved.
import os
import pathlib
from functools import lru_cache
from packaging import version

import torch
import torch_npu

origin_torch_compile = torch.compile
from torch_npu.contrib import transfer_to_npu
from torch_npu.utils._path_manager import PathManager

torch.compile = origin_torch_compile

Expand All @@ -20,3 +24,40 @@ def has_triton():
return False

setattr(target_module, func_str, has_triton)


# def set_atb_ops():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment for this part of code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

# from torch_npu.op_plugin import atb
# for api_name in atb._atb_ops.API_LIST:
# func = getattr(torch_npu, api_name)
# setattr(torch.ops.npu, api_name, func)
# _patch_atb_and_loadso()


@lru_cache(None)
def register_atb_extensions():
npu_path = pathlib.Path(torch_npu.__file__).parents[0]
atb_so_path = os.path.join(npu_path, "lib", "libop_plugin_atb.so")
try:
PathManager.check_directory_path_readable(atb_so_path)
torch.ops.load_library(atb_so_path)
except OSError as e:
nnal_ex = None
nnal_strerror = ""
if "libatb.so" in str(e):
nnal_strerror = (
"Please check that the nnal package is installed. "
"Please run 'source set_env.sh' in the NNAL installation path."
)
if "undefined symbol" in str(e):
nnal_strerror = (
"Please check the version of the NNAL package. "
"An undefined symbol was found, "
"which may be caused by a version mismatch between NNAL and torch_npu."
)
nnal_ex = OSError(e.errno, nnal_strerror)
nnal_ex.__traceback__ = e.__traceback__
raise nnal_ex from e


register_atb_extensions()
162 changes: 100 additions & 62 deletions dlinfer/vendor/ascend/torch_npu_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def prefill_attention(
if (attn_mask is None or len(attn_mask) == 0)
else attn_mask[0].to(torch.bool)
)
attn_output.view(query.shape)[:] = torch.ops.npu.npu_fusion_attention(
attn_output[:] = torch.ops.npu.npu_fusion_attention(
query,
key,
value,
Expand Down Expand Up @@ -168,10 +168,15 @@ def quant_int8(x, x_scale, x_offset):
key = quant_int8(key, k_scales_zeros[0], k_scales_zeros[1])
value = quant_int8(value, v_scales_zeros[0], v_scales_zeros[1])

key_cache_reshaped = key_cache.view(block_total, head, dim)
value_cache_reshaped = value_cache.view(block_total, head, dim)
torch.ops.npu.npu_scatter_nd_update_(key_cache_reshaped, kv_indices, key)
torch.ops.npu.npu_scatter_nd_update_(value_cache_reshaped, kv_indices, value)
is_mla = key.shape[-1] != value.shape[-1]
if is_mla:
key_cache_reshaped = key_cache.view(block_total, head, dim)
torch.ops.npu.npu_scatter_nd_update_(key_cache_reshaped, kv_indices, key)
else:
key_cache_reshaped = key_cache.view(block_total, head, dim)
value_cache_reshaped = value_cache.view(block_total, head, dim)
torch.ops.npu.npu_scatter_nd_update_(key_cache_reshaped, kv_indices, key)
torch.ops.npu.npu_scatter_nd_update_(value_cache_reshaped, kv_indices, value)
return key_cache, value_cache


Expand Down Expand Up @@ -214,38 +219,53 @@ def paged_decode_attention(
if isinstance(block_table, torch.Tensor) and block_table.dtype != torch.int32:
block_table = block_table.to(torch.int32)

bs, _, dim = query.shape
block_num = key_cache.size(0)
is_mla = key_cache.shape[-1] != value_cache.shape[-1]
query = query.contiguous()
attn_output = attn_output.contiguous()
query = query.view(bs, 1, num_q_heads * dim)
key_cache = key_cache.view(block_num, block_size, -1)
value_cache = value_cache.view(block_num, block_size, -1)
scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(dim)

torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
query,
key_cache,
value_cache,
attn_output.view_as(query),
padding_mask=None,
atten_mask=None,
actual_seq_lengths=kv_seq_len.tolist(),
antiquant_scale=kv_scales,
antiquant_offset=kv_zeros,
block_table=block_table,
dequant_scale1=None,
quant_scale1=None,
dequant_scale2=None,
quant_scale2=None,
quant_offset2=None,
num_heads=num_q_heads,
scale_value=scale_value,
input_layout="BSH",
num_key_value_heads=num_kv_heads,
block_size=block_size,
inner_precise=1,
)
if is_mla:
v_head_size = value_cache.shape[-1]
torch.ops.atb._npu_paged_attention_mla(
query=query,
key_cache=key_cache,
num_kv_heads=num_kv_heads,
num_heads=num_q_heads,
scale_value=softmax_scale,
block_table=block_table,
context_lens=kv_seq_len,
mla_vheadsize=v_head_size,
out=attn_output,
)
else:
bs, _, dim = query.shape
block_num = key_cache.size(0)
query = query.view(bs, 1, num_q_heads * dim)
key_cache = key_cache.view(block_num, block_size, -1)
value_cache = value_cache.view(block_num, block_size, -1)
torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
query,
key_cache,
value_cache,
attn_output.view_as(query),
padding_mask=None,
atten_mask=None,
actual_seq_lengths=kv_seq_len.tolist(),
antiquant_scale=kv_scales,
antiquant_offset=kv_zeros,
block_table=block_table,
dequant_scale1=None,
quant_scale1=None,
dequant_scale2=None,
quant_scale2=None,
quant_offset2=None,
num_heads=num_q_heads,
scale_value=scale_value,
input_layout="BSH",
num_key_value_heads=num_kv_heads,
block_size=block_size,
inner_precise=1,
)
return attn_output


Expand Down Expand Up @@ -282,35 +302,53 @@ def paged_prefill_attention(
if block_table.dtype != torch.int32:
block_table = block_table.to(torch.int32)

kv_seq_len_list = kv_seq_len.tolist()
scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
query = query.contiguous().view(query.shape[0], 1, -1)
block_num = key_cache.size(0)
key_cache = key_cache.view(block_num, block_size, -1)
value_cache = value_cache.view(block_num, block_size, -1)
torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
query,
key_cache,
value_cache,
attn_output,
padding_mask=None,
atten_mask=attn_mask[0],
actual_seq_lengths=kv_seq_len_list,
antiquant_scale=kv_scales,
antiquant_offset=kv_zeros,
block_table=block_table,
dequant_scale1=None,
quant_scale1=None,
dequant_scale2=None,
quant_scale2=None,
quant_offset2=None,
num_heads=num_q_heads,
scale_value=scale_value,
input_layout="BSH",
num_key_value_heads=num_kv_heads,
block_size=block_size,
inner_precise=1,
)
is_mla = key_cache.shape[-1] != value_cache.shape[-1]
query = query.contiguous()
attn_output = attn_output.contiguous()
scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(dim)

if is_mla:
v_head_size = value_cache.shape[-1]
torch.ops.atb._npu_paged_attention_mla(
query=query,
key_cache=key_cache,
num_kv_heads=num_kv_heads,
num_heads=num_q_heads,
scale_value=softmax_scale,
block_table=block_table,
context_lens=kv_seq_len,
mla_vheadsize=v_head_size,
out=attn_output,
)
else:
bs, _, dim = query.shape
block_num = key_cache.size(0)
query = query.view(bs, 1, num_q_heads * dim)
key_cache = key_cache.view(block_num, block_size, -1)
value_cache = value_cache.view(block_num, block_size, -1)
torch.ops.npu_ext.npu_incre_flash_attention_v4_out(
query,
key_cache,
value_cache,
attn_output,
padding_mask=None,
atten_mask=attn_mask[0],
actual_seq_lengths=kv_seq_len.tolist(),
antiquant_scale=kv_scales,
antiquant_offset=kv_zeros,
block_table=block_table,
dequant_scale1=None,
quant_scale1=None,
dequant_scale2=None,
quant_scale2=None,
quant_offset2=None,
num_heads=num_q_heads,
scale_value=scale_value,
input_layout="BSH",
num_key_value_heads=num_kv_heads,
block_size=block_size,
inner_precise=1,
)
return attn_output


Expand Down
Loading