6 changes: 2 additions & 4 deletions vllm_ascend/attention/mla_v1.py

@@ -543,9 +543,8 @@ def build(
                     device=input_positions.device)
                 input_positions = torch.cat(
                     [input_positions, position_padding])
-                actual_seq_lengths_q = query_start_loc[1:].tolist(
-                ) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs +
-                                                     num_reqs_pad_size]
+                actual_seq_lengths_q = actual_seq_lengths_q + self.runner.actual_seq_lengths_q[
+                    num_reqs:num_reqs + num_reqs_pad_size]
             else:
                 seq_lens_list = seq_lens.tolist()
                 # mtp torchair + PD scenario: the last element of actual_seq_lengths_q must equal batch_size (num_tokens)
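A minimal sketch (not part of the PR; all values are illustrative) of what the padding above produces: `actual_seq_lengths_q` must remain a cumulative sequence whose last entry equals the padded token count the torchair graph expects.

```python
# Illustrative only: mirrors the padding in the hunk above with made-up values.
query_start_loc = [0, 2, 4, 6]              # cumulative query offsets, 3 real requests
actual_seq_lengths_q = query_start_loc[1:]  # [2, 4, 6]

# Pre-computed entries the runner keeps for padded (dummy) requests.
runner_actual_seq_lengths_q = [2, 4, 6, 8, 10]
num_reqs, num_reqs_pad_size = 3, 2

actual_seq_lengths_q = actual_seq_lengths_q + runner_actual_seq_lengths_q[
    num_reqs:num_reqs + num_reqs_pad_size]
print(actual_seq_lengths_q)  # [2, 4, 6, 8, 10] -> last element == padded num_tokens
```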
@@ -959,7 +958,6 @@ def _forward_decode(
                 AscendAttentionState.SpecDecoding,
                 AscendAttentionState.ChunkedPrefill
         ]:
-            assert num_tokens % (1 + self.spec_token_num) == 0
             input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1)
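For context, a hedged sketch of the TND reshape on this path. The deleted assert required `num_tokens` to be divisible by `(1 + self.spec_token_num)`, which cannot hold for ChunkedPrefill batches; the view itself works for any token count. Shapes below are illustrative.

```python
import torch

# Illustrative shapes: 5 tokens is not divisible by (1 + spec_token_num) when
# spec_token_num == 1, which the deleted assert would have rejected.
num_tokens, num_heads, head_dim = 5, 8, 64
q_nope = torch.randn(num_tokens, num_heads * head_dim)

# [bs * q_seq_len, num_heads_per_rank, dim] -- same view as in the hunk above.
q_nope = q_nope.view(num_tokens, num_heads, -1)
print(q_nope.shape)  # torch.Size([5, 8, 64])
```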
3 changes: 3 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py

@@ -976,6 +976,9 @@ def _process_reqs(
         # We assume this is the decode stage, where prefill occurs but only one token misses the cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+                # support deepseek mtp spec decode in disaggregated-prefill scenario
+                attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
             attn_state = AscendAttentionState.SpecDecoding
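A self-contained sketch (stand-in enum and hypothetical helper name, not the runner's API) of the selection order this hunk changes: an all-ones `num_scheduled_tokens` normally means pure decode, but with `deepseek_mtp` speculative decoding in a disaggregated-prefill setup the SpecDecoding attention path is forced instead.

```python
from enum import Enum

import numpy as np


class AscendAttentionState(Enum):  # minimal stand-in for the real enum
    DecodeOnly = 1
    SpecDecoding = 2


def pick_attn_state(num_scheduled_tokens: np.ndarray,
                    speculative_method) -> AscendAttentionState:
    """Hypothetical helper mirroring the branch in the hunk above."""
    if np.all(num_scheduled_tokens == 1):
        if speculative_method == 'deepseek_mtp':
            # deepseek mtp spec decode in the disaggregated-prefill scenario
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.DecodeOnly
    # The real code falls through to further num_valid_tokens checks here.
    return AscendAttentionState.SpecDecoding


print(pick_attn_state(np.ones(4, dtype=np.int64), 'deepseek_mtp'))
# AscendAttentionState.SpecDecoding
```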
2 changes: 2 additions & 0 deletions vllm_ascend/worker/worker_v1.py

@@ -22,6 +22,7 @@
 import torch
 import torch.nn as nn
 import torch_npu
+from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
 from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
@@ -69,6 +70,7 @@ def __init__(
         # Register ops at worker init.
         from vllm_ascend import ops
         ops.register_dummy_fusion_op()
+        _register_atb_extensions()
         # init ascend config
         init_ascend_config(vllm_config)

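A hedged sketch of the init-time registration this hunk adds. The guarded import is mine so the snippet stays runnable on machines without the Ascend stack; in the worker itself the import sits at module level, as shown above.

```python
# Sketch only: torch_npu is assumed present on real Ascend workers.
try:
    from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
except ImportError:
    def _register_atb_extensions() -> None:
        print("torch_npu unavailable; skipping ATB extension registration")

# Called once during worker __init__, right after the dummy fusion ops,
# so the ATB-backed kernels are registered before any model code runs.
_register_atb_extensions()
```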