
Commit 0384cb2

zhaochaoxing authored and yangbofun committed
[ascend]Zcx/llama2 infer 910b (#1254)
* optimize lightllm
* fix promptFlashAttention on a+x
* add check for incre flash attention
* add description of the added function
1 parent 20decf8 commit 0384cb2

13 files changed: +864 -95 lines changed

.github/workflows/_runs-on-nv-step1.yml

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ jobs:
             && python main.py --mode gen_data" \
             || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 )
             source ~/Aoss_env.sh
-            ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
+            ads-cli cp ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
           elif [[ "${GETRUNNER}" == *diopi* ]];then
             ssh SH1424 """
             set -e
@@ -87,7 +87,7 @@ jobs:
             srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_V100} --time=20 --gres=gpu:1 bash -c 'python main.py --mode gen_data' \
             || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 )
             source ~/Aoss_env.sh
-            ads-cli cp s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
+            ads-cli cp ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/ s3://${Platform_ci_aoss_name}:${Platform_ci_aoss_url}@platform.aoss.cn-sh-01c.sensecoreapi-oss.cn${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/
             """
           else
             ln -s ${GEN_DATA_PATH}/${GEN_DATA}/diopi ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/diopi_test/python/cache/

diopi_test/python/configs/diopi_configs.py

Lines changed: 158 additions & 0 deletions
@@ -8305,6 +8305,36 @@
         ),
     ),
 
+    'rotary_emb_v2': dict(
+        name=['rotary_emb_v2'],
+        interface=['CustomizedTest'],
+        dtype=[np.float32, np.float16],
+        para=dict(
+            dim=[128,]
+        ),
+        tensor_para=dict(
+            gen_fn='Genfunc.randn',
+            args=[
+                {
+                    "ins": ['query'],
+                    "shape": ((8, 4096),),
+                },
+                {
+                    "ins": ['key'],
+                    "shape": ((8, 4096),),
+                },
+                {
+                    "ins": ['cos'],
+                    "shape": ((8, 1, 128),),
+                },
+                {
+                    "ins": ['sin'],
+                    "shape": ((8, 1, 128),),
+                },
+            ],
+        ),
+    ),
+
     'rms_norm_default': dict(
         name=['rms_norm'],
         atol=1e-4,
@@ -8551,6 +8581,134 @@
         ),
     ),
 
+    'prompt_flash_attention': dict(
+        name=['prompt_flash_attention'],
+        interface=['CustomizedTest'],
+        atol=1e-2,
+        rtol=1e-2,
+        para=dict(
+            maxInputLen=[2,],
+            actualSeqLengths=[[2,2],],
+            numHeads=[32,],
+            numKeyValueHeads=[32,],
+            dim=[128,],
+        ),
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["query"],
+                    "shape": ((4, 4096),),
+                    "dtype": [np.float16,],
+                },
+                {
+                    "ins": ["key"],
+                    "shape": ((4, 4096),),
+                    "dtype": [np.float16,],
+                },
+                {
+                    "ins": ["value"],
+                    "shape": ((4, 4096),),
+                    "dtype": [np.float16,],
+                },
+                {
+                    "ins": ["attenMask"],
+                    "value": ([[[False, True],
+                                [False, False]],
+                               [[False, True],
+                                [False, False]]],),
+                    "dtype": [np.bool_,],
+                    "gen_policy": "gen_tensor_by_value"
+                },
+            ]
+        ),
+    ),
+
+    'paged_attention': dict(
+        name=['paged_attention'],
+        interface=['CustomizedTest'],
+        atol=1e-2,
+        rtol=1e-2,
+        para=dict(
+            actualSeqLengths=[[150,],],
+            numHeads=[32,],
+            numKeyValueHeads=[32,],
+            dim=[128,],
+            blockSize=[128,],
+        ),
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["query"],
+                    "shape": ((1, 4096),),
+                    "dtype": [np.float16,],
+                },
+                {
+                    "ins": ["key"],
+                    "shape": ((1026, 4096),),
+                    "dtype": [np.float16,],
+                },
+                {
+                    "ins": ["value"],
+                    "shape": ((1026, 4096),),
+                    "dtype": [np.float16,],
+                },
+                {
+                    "ins": ["blockTable"],
+                    "value": ([[0, 1],],),
+                    "dtype": [np.int32,],
+                    "gen_policy": "gen_tensor_by_value"
+                },
+            ]
+        ),
+    ),
+
+    'apply_penalty_v2': dict(
+        name=['apply_penalty_v2'],
+        interface=['CustomizedTest'],
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ['logits'],
+                    "value": ([[0.1, 0.5, 0.4, 0.3, 0.5],
+                               [0.2, 0.4, 0.0, 0.0, 0.0],
+                               [0.3, 0.4, 0.5, 0.3, 0.0]],),
+                    "dtype": [np.float16, np.float32],
+                    "gen_policy": "gen_tensor_by_value"
+                },
+                {
+                    "ins": ["presence_penalty"],
+                    "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],),
+                    "dtype": [np.float16, np.float32],
+                    "gen_policy": "gen_tensor_by_value"
+                },
+                {
+                    "ins": ["frequency_penalty"],
+                    "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],),
+                    "dtype": [np.float16, np.float32],
+                    "gen_policy": "gen_tensor_by_value"
+                },
+                {
+                    "ins": ["repetition_penalty"],
+                    "value": ([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0],),
+                    "dtype": [np.float16, np.float32],
+                    "gen_policy": "gen_tensor_by_value"
+                },
+                {
+                    "ins": ["p_token_ids"],
+                    "value": ([0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11],),
+                    "dtype": [np.int32, np.int32],
+                    "gen_policy": "gen_tensor_by_value"
+                },
+                {
+                    "ins": ["p_token_counts"],
+                    "value": ([3, 3, 2, 2, 1, 3, 3, 3, 3, 2, 2],),
+                    "dtype": [np.int32, np.int32],
+                    "gen_policy": "gen_tensor_by_value"
+                },
+            ]
+        )
+    ),
+
     'token_attention': dict(
         name=['token_attention'],
         interface=['CustomizedTest'],
diopi_test/python/conformance/customized_test.py

Lines changed: 111 additions & 1 deletion
@@ -626,6 +626,108 @@ def context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len):
         )
         return out
 
+    def prompt_flash_attention(
+        query,
+        key,
+        value,
+        attenMask,
+        actualSeqLengths,
+        maxInputLen,
+        numHeads,
+        numKeyValueHeads,
+        dim,
+    ):
+        bs = len(actualSeqLengths)
+        xq = query.view(bs, maxInputLen, numHeads, dim).cuda()
+        keys = key.view(bs, maxInputLen, numKeyValueHeads, dim).cuda()
+        values = value.view(bs, maxInputLen, numKeyValueHeads, dim).cuda()
+        mask = (
+            torch.tril(torch.ones(maxInputLen, maxInputLen), diagonal=0)
+            .unsqueeze(0)
+            .unsqueeze(0)
+            .cuda()
+        )
+        mask = mask.masked_fill(mask == 0.0, -100000000.0)
+        mask = mask.repeat(bs, numHeads, 1, 1)
+        xq = xq.transpose(1, 2)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(dim)
+        scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
+        out = torch.matmul(scores, values).transpose(1, 2).contiguous()
+        return out.reshape(bs * maxInputLen, numHeads * dim)
+
+    def paged_attention(
+        query,
+        key,
+        value,
+        actualSeqLengths,
+        numHeads,
+        numKeyValueHeads,
+        dim,
+        blockTable,
+        blockSize,
+    ):
+        # q: BSH
+        b_loc = torch.arange(key.shape[0], dtype=torch.int32).reshape(1, -1).cuda()
+        batch = b_loc.shape[0]
+        xq = query.view(batch, 1, numHeads, dim).transpose(1, 2).cuda()
+        k = key.view(-1, numKeyValueHeads, dim).cuda()
+        v = value.view(-1, numKeyValueHeads, dim).cuda()
+        out = torch.empty([batch, numHeads, dim], device="cuda", dtype=query.dtype)
+        max_input_len = max(actualSeqLengths)
+        b_seq_len = torch.tensor(actualSeqLengths, dtype=torch.int32).cuda()
+        for i in range(batch):
+            k_loc = b_loc[i][
+                max_input_len
+                - b_seq_len[i]
+                + torch.arange(0, b_seq_len[i], device="cuda", dtype=torch.int32)
+            ]
+            key = k[k_loc, :].view(1, b_seq_len[i], numHeads, dim).transpose(1, 2)
+            logics = (
+                torch.matmul(xq[i, :], key.transpose(2, 3)) / math.sqrt(dim)
+            ).reshape(numHeads, b_seq_len[i])
+            v_loc = b_loc[i][
+                max_input_len
+                - b_seq_len[i]
+                + torch.arange(0, b_seq_len[i], device=logics.device, dtype=torch.int32)
+            ]
+            P = logics.softmax(-1).reshape(1, numHeads, 1, b_seq_len[i])
+            V = v[v_loc, :].view(1, b_seq_len[i], numHeads, dim).transpose(1, 2)
+            out[i, :] = torch.matmul(P, V).view(numHeads, dim)
+        return out.view(-1, numHeads * dim)
+
+    def apply_penalty_v2(
+        logits,
+        presence_penalty,
+        frequency_penalty,
+        repetition_penalty,
+        p_token_ids,
+        p_token_counts,
+    ):
+        batch = logits.shape[0]
+        logits = logits.view(-1)
+        cur_logits = logits.index_select(0, p_token_ids)
+        rep_logits = torch.where(
+            cur_logits > 0,
+            cur_logits / repetition_penalty,
+            cur_logits * repetition_penalty,
+        )
+        rep_logits = rep_logits - p_token_counts * frequency_penalty - presence_penalty
+        logits[p_token_ids] = rep_logits
+        return logits.view(batch, -1)
+
+    def rotary_emb_v2(query, key, cos, sin, dim):
+        query = query.view(query.shape[0], -1, dim)
+        key = key.view(key.shape[0], -1, dim)
+        q1, q2 = query.chunk(2, dim=-1)
+        query_rotate = torch.cat((-q2, q1), dim=-1)
+        query = query * cos + query_rotate * sin
+        k1, k2 = key.chunk(2, dim=-1)
+        key_rotate = torch.cat((-k2, k1), dim=-1)
+        key = key * cos + key_rotate * sin
+        return query.view(query.shape[0], -1), key.view(key.shape[0], -1)
+
     def attention(
         query,
         key,
@@ -742,5 +844,13 @@ def attention_varlen(
             start_idx = cu_seqlens_q[i]
             end_idx = cu_seqlens_q[i + 1]
             actual_seq_len = end_idx - start_idx
-            out[start_idx:end_idx, :, :] = out_paded[i, :actual_seq_len, :, :]  # BSND->TND
+            out[start_idx:end_idx, :, :] = out_paded[
+                i, :actual_seq_len, :, :
+            ]  # BSND->TND
+        return out
+
+    def nll_loss_v2(input, target, weight=None, ignore_index=-100, reduction="mean"):
+        out = torch.nn.functional.nll_loss(
+            input, target, weight, None, ignore_index, None, reduction
+        )
         return out
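The new prompt_flash_attention reference runs on CUDA (every tensor is moved with .cuda()), but its shape bookkeeping is easy to check on CPU. The sketch below repeats the same math with the shapes from the 'prompt_flash_attention' config case (bs=2, maxInputLen=2, numHeads=32, dim=128), using float32 instead of float16 and dropping the .cuda() moves so it runs on a CPU-only build; it is an illustration of the reference, not code from this commit.

import math
import torch
import torch.nn.functional as F

# Shapes from the 'prompt_flash_attention' case: two sequences of length 2,
# packed as (bs * maxInputLen, numHeads * dim) = (4, 4096).
bs, max_input_len, num_heads, dim = 2, 2, 32, 128
q = torch.randn(bs * max_input_len, num_heads * dim)   # float32 for a CPU-only run
k = torch.randn_like(q)
v = torch.randn_like(q)

# Same steps as the reference, minus the device moves.
xq = q.view(bs, max_input_len, num_heads, dim).transpose(1, 2)   # (2, 32, 2, 128)
xk = k.view(bs, max_input_len, num_heads, dim).transpose(1, 2)
xv = v.view(bs, max_input_len, num_heads, dim).transpose(1, 2)
mask = torch.tril(torch.ones(max_input_len, max_input_len))      # causal pattern
mask = mask.masked_fill(mask == 0.0, -100000000.0)               # future positions get a huge negative bias
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(dim)   # (2, 32, 2, 2)
scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
out = torch.matmul(scores, xv).transpose(1, 2).contiguous()
print(out.reshape(bs * max_input_len, num_heads * dim).shape)    # torch.Size([4, 4096])

Note that the reference builds this causal mask from maxInputLen and never reads its attenMask argument; presumably the attenMask tensor in the config is consumed by the device-side implementation under test.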

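Similarly, apply_penalty_v2 is plain indexing and arithmetic over a flattened (batch * vocab) logits buffer, so the concrete values from the 'apply_penalty_v2' config case can be stepped through on CPU. The walk-through below mirrors the reference line by line; it is illustrative only and not part of the commit.

import torch

# Concrete inputs from the 'apply_penalty_v2' case in diopi_configs.py above.
logits = torch.tensor([[0.1, 0.5, 0.4, 0.3, 0.5],
                       [0.2, 0.4, 0.0, 0.0, 0.0],
                       [0.3, 0.4, 0.5, 0.3, 0.0]], dtype=torch.float32)
presence_penalty = torch.tensor([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0])
frequency_penalty = torch.tensor([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0])
repetition_penalty = torch.tensor([0.1, 0.1, 0.1, 0.1, 0.8, 0.8, 0.8, 0.8, 1.0, 1.0, 1.0])
p_token_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11])
p_token_counts = torch.tensor([3, 3, 2, 2, 1, 3, 3, 3, 3, 2, 2])

batch = logits.shape[0]
flat = logits.view(-1)                                   # (15,) = batch * vocab
cur = flat.index_select(0, p_token_ids)                  # logits of already-seen tokens
# Repetition penalty: positive logits are divided, non-positive ones multiplied.
rep = torch.where(cur > 0, cur / repetition_penalty, cur * repetition_penalty)
# Frequency penalty scales with how often a token appeared; presence penalty is flat.
rep = rep - p_token_counts * frequency_penalty - presence_penalty
flat[p_token_ids] = rep                                  # scatter back into the flat buffer
print(flat.view(batch, -1))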