diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 201250756a..223d78de76 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -543,9 +543,8 @@ def build(
                         device=input_positions.device)
                 input_positions = torch.cat(
                     [input_positions, position_padding])
-                actual_seq_lengths_q = query_start_loc[1:].tolist(
-                ) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs +
-                                                     num_reqs_pad_size]
+                actual_seq_lengths_q = actual_seq_lengths_q + self.runner.actual_seq_lengths_q[
+                    num_reqs:num_reqs + num_reqs_pad_size]
             else:
                 seq_lens_list = seq_lens.tolist()
                 # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
@@ -959,7 +958,6 @@ def _forward_decode(
                 AscendAttentionState.SpecDecoding,
                 AscendAttentionState.ChunkedPrefill
         ]:
-            assert num_tokens % (1 + self.spec_token_num) == 0
             input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8803bc0cde..d300c1ead8 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -976,6 +976,9 @@ def _process_reqs(
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+                # support deepseek mtp spec decode in disaggregated-prefill scenario
+                attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
             attn_state = AscendAttentionState.SpecDecoding
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 7a1d4d3d9d..2f3423c781 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -22,6 +22,7 @@
 import torch
 import torch.nn as nn
 import torch_npu
+from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
 from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
@@ -69,6 +70,7 @@ def __init__(
         # Register ops when worker init.
         from vllm_ascend import ops
         ops.register_dummy_fusion_op()
+        _register_atb_extensions()
 
         # init ascend config
         init_ascend_config(vllm_config)
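
For reviewers, a minimal standalone sketch of the attention-state selection that the model_runner_v1.py hunk changes: when every scheduled request has exactly one token, the state used to be DecodeOnly unconditionally, but with deepseek_mtp speculative decoding in a disaggregated-prefill setup the batch must still take the SpecDecoding path. The AscendAttentionState enum and select_attn_state helper below are simplified stand-ins written for illustration, not the real vllm_ascend classes.

# Hedged sketch of the branch order introduced by this PR; the enum and
# helper are simplified stand-ins, not the actual vllm_ascend implementation.
from enum import Enum, auto
from typing import Optional

import numpy as np


class AscendAttentionState(Enum):
    DecodeOnly = auto()
    SpecDecoding = auto()
    ChunkedPrefill = auto()


def select_attn_state(num_scheduled_tokens: np.ndarray,
                      num_valid_tokens: np.ndarray,
                      spec_method: Optional[str]) -> AscendAttentionState:
    """Pick the attention state for a batch (simplified illustration)."""
    if np.all(num_scheduled_tokens == 1):
        # Every request schedules one token: normally pure decode, but the
        # deepseek_mtp drafter still needs the spec-decoding kernel path
        # in the disaggregated-prefill scenario.
        if spec_method == 'deepseek_mtp':
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.DecodeOnly
    if np.all(num_valid_tokens == 1):
        return AscendAttentionState.SpecDecoding
    return AscendAttentionState.ChunkedPrefill


# Example: a one-token-per-request batch with deepseek_mtp configured now
# resolves to SpecDecoding instead of DecodeOnly.
print(select_attn_state(np.array([1, 1, 1]), np.array([1, 1, 1]),
                        'deepseek_mtp'))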