Commit e4c3aca

fix attn args (#9)

1 parent: ad74481
1 file changed, 2 insertions(+), 2 deletions(-)

1 file changed

+2
-2
lines changed

lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py

@@ -33,7 +33,7 @@ def flash_context_attention(
             q_seq_len[i:i + 1],
             num_q_heads,
             num_kv_heads,
-            context.attention_mask[i:i + 1],
+            attn_mask=context.attention_mask[i:i + 1],
             attn_output=attn_output,
         )
     else:
@@ -51,7 +51,7 @@ def flash_context_attention(
             kv_seq_len[i:i + 1],
             num_q_heads,
             num_kv_heads,
-            context.attention_mask[i:i + 1],
+            attn_mask=context.attention_mask[i:i + 1],
             attn_output=attn_output,
         )
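The fix converts the positional attention-mask argument into the explicit attn_mask= keyword in both branches of flash_context_attention. Passing the mask by keyword guards against mis-binding: if the callee's signature has any optional parameter sitting between num_kv_heads and attn_mask, a positionally passed mask silently lands in the wrong slot. A minimal sketch of that failure mode, using a hypothetical attention_kernel signature (the real parameter list lives in the kernel wrapper called from paged_attention_fwd.py and is not reproduced here):

# Hypothetical signature, for illustration only: note the optional
# softmax_scale sitting between the required args and attn_mask.
def attention_kernel(q, k, v, q_seq_len, num_q_heads, num_kv_heads,
                     softmax_scale=None, attn_mask=None, attn_output=None):
    # Return the bindings so the two call styles can be compared.
    return {"softmax_scale": softmax_scale, "attn_mask": attn_mask}

mask = [[0, 1]]

# Positional call (pre-fix pattern): the mask silently binds to softmax_scale.
before = attention_kernel("q", "k", "v", 1, 8, 8, mask, attn_output=None)
assert before["attn_mask"] is None and before["softmax_scale"] is mask

# Keyword call (post-fix pattern): the binding is explicit and survives any
# future insertion or reordering of optional parameters.
after = attention_kernel("q", "k", "v", 1, 8, 8, attn_mask=mask,
                         attn_output=None)
assert after["attn_mask"] is mask and after["softmax_scale"] is None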
