 from vllm_gaudi.utils import (HPUCompileConfig, is_fake_hpu, async_h2d_copy)
 from vllm_gaudi.v1.attention.backends.hpu_attn import HPUAttentionMetadataV1
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds,
-                             LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.utils import bind_kv_cache
 from vllm_gaudi.v1.worker.hpu_input_batch import InputBatch
@@ -113,8 +113,7 @@ def __init__(
         self._sampled_token_ids = sampled_token_ids

         # TODO: Change to non_blocking once it is working
-        self._sampled_token_ids_cpu = self._sampled_token_ids.to(
-            'cpu', non_blocking=False)
+        self._sampled_token_ids_cpu = self._sampled_token_ids.to('cpu', non_blocking=False)

     def get_output(self) -> ModelRunnerOutput:
         """Copy the device tensors to the host and return a ModelRunnerOutput.
@@ -134,8 +133,7 @@ def get_output(self) -> ModelRunnerOutput:
             valid_sampled_token_ids[i].clear()

         output = self._model_runner_output
-        output.sampled_token_ids[:len(valid_sampled_token_ids
-                                      )] = valid_sampled_token_ids
+        output.sampled_token_ids[:len(valid_sampled_token_ids)] = valid_sampled_token_ids
         return output


@@ -2024,8 +2022,8 @@ def _prepare_unified_decode_inputs(self, num_decodes, num_scheduled_tokens) -> D
             logits_indices=logits_indices_t,
             attn_metadata=attn_metadata,
         )
-    def _prepare_input_ids(self, total_num_scheduled_tokens: int,
-                           cu_num_tokens: np.ndarray) -> None:
+
+    def _prepare_input_ids(self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray) -> None:
         """Prepare the input IDs for the current batch.

         Carefully handles the `prev_sampled_token_ids` which can be cached
@@ -2046,8 +2044,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         indices_match = True
         max_flattened_index = -1
         for req_id, cur_index in self.input_batch.req_id_to_index.items():
-            if req_id in self.input_batch.\
-                prev_sampled_token_ids_invalid_indices:
+            if req_id in self.input_batch.prev_sampled_token_ids_invalid_indices:
                 # This request was in the previous batch but its
                 # prev_sampled_token_ids is invalid
                 continue
@@ -2068,26 +2065,18 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
             # Common-case optimization: the batch is unchanged
             # and no reordering happened.
             # The indices are both the same permutation of 0..N-1
-            self.input_ids_cpu[:len(flattened_indices)].copy_(
-                prev_sampled_token_ids[:len(flattened_indices)])
+            self.input_ids_cpu[:len(flattened_indices)].copy_(prev_sampled_token_ids[:len(flattened_indices)])
             return

         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking
-        input_ids_index_tensor = torch.tensor(flattened_indices,
-                                              dtype=torch.int64,
-                                              device="cpu")
+        input_ids_index_tensor = torch.tensor(flattened_indices, dtype=torch.int64, device="cpu")
         if prev_sampled_token_ids.size(0) <= len(prev_common_req_indices):
-            prev_common_req_indices = prev_common_req_indices[:
-                                                              prev_sampled_token_ids
-                                                              .size(0)]
-        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
-                                                      dtype=torch.int64,
-                                                      device="cpu")
-        self.input_ids_cpu.scatter_(
-            dim=0,
-            index=input_ids_index_tensor,
-            src=prev_sampled_token_ids[prev_common_req_indices_tensor])
+            prev_common_req_indices = prev_common_req_indices[:prev_sampled_token_ids.size(0)]
+        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices, dtype=torch.int64, device="cpu")
+        self.input_ids_cpu.scatter_(dim=0,
+                                    index=input_ids_index_tensor,
+                                    src=prev_sampled_token_ids[prev_common_req_indices_tensor])

     def _prepare_inputs(
         self,
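Side note on the scatter pattern reflowed in the hunk above: it writes the previous step's sampled tokens into the flat CPU input-id buffer in a single call, with no Python loop over requests. A minimal standalone sketch, using made-up buffer sizes and index lists in place of the runner's real state:

    import torch

    # Invented stand-ins: a flat CPU buffer of input ids, last step's sampled
    # tokens, and the destination slot of each request in the new batch.
    input_ids_cpu = torch.zeros(8, dtype=torch.int64)
    prev_sampled_token_ids = torch.tensor([11, 22, 33])
    flattened_indices = [0, 3, 5]
    prev_common_req_indices = [0, 1, 2]

    index = torch.tensor(flattened_indices, dtype=torch.int64)
    rows = torch.tensor(prev_common_req_indices, dtype=torch.int64)
    # One scatter_ call places every reused token into its slot.
    input_ids_cpu.scatter_(dim=0, index=index, src=prev_sampled_token_ids[rows])
    print(input_ids_cpu)  # tensor([11,  0,  0, 22,  0, 33,  0,  0])
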
@@ -2112,8 +2101,7 @@ def _prepare_inputs(
         cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np)
         token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1])
-        cu_num_tokens, arange = self._get_cumsum_and_arange(
-            num_scheduled_tokens)
+        cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
                            torch.from_numpy(token_indices),
@@ -2663,12 +2651,10 @@ def execute_model(
                 # add the last token position
                 if logits_indices.shape[0] < len(req_id):
                     if structured_output:
-                        logits_append = torch.tensor(
-                            [torch.sum(prompt_len) - 1],
-                            device=token_ids.device,
-                            dtype=torch.int32)
-                        logits_indices = torch.cat(
-                            [logits_indices, logits_append])
+                        logits_append = torch.tensor([torch.sum(prompt_len) - 1],
+                                                     device=token_ids.device,
+                                                     dtype=torch.int32)
+                        logits_indices = torch.cat([logits_indices, logits_append])
                     elif self.use_async_scheduling:
                         # Discard partial prefill logits for async scheduling
                         # Depends on 1 decode token/batch
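For context on the structured-output branch reflowed above: the appended entry simply adds the final prompt-token position to the list of positions whose logits are gathered. A tiny sketch with invented lengths (not the runner's real values):

    import torch

    prompt_len = torch.tensor([4, 3])                      # invented per-request prompt lengths
    logits_indices = torch.tensor([3], dtype=torch.int32)  # positions already selected
    logits_append = torch.tensor([torch.sum(prompt_len) - 1], dtype=torch.int32)
    logits_indices = torch.cat([logits_indices, logits_append])
    print(logits_indices)  # tensor([3, 6], dtype=torch.int32)
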
@@ -2850,16 +2836,10 @@ def execute_model(
             # For async scheduling: keep tokens on HPU and avoid CPU sync
             # Concatenate decode and prefill tokens on HPU
             if decode_sampled_token_ids or prefill_sampled_token_ids:
-                decode_sampled_token_ids = [
-                    tensor[:num_decodes] for tensor in decode_sampled_token_ids
-                ]
-                sampled_token_ids = torch.cat(decode_sampled_token_ids +
-                                              prefill_sampled_token_ids).view(
-                                                  -1, 1)
+                decode_sampled_token_ids = [tensor[:num_decodes] for tensor in decode_sampled_token_ids]
+                sampled_token_ids = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).view(-1, 1)
             else:
-                sampled_token_ids = torch.empty((0, 1),
-                                                dtype=torch.int32,
-                                                device=self.device)
+                sampled_token_ids = torch.empty((0, 1), dtype=torch.int32, device=self.device)

             # Copy some objects so they don't get modified after returning.
             # This is important when using async scheduling.
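The branch above slices each decode output down to num_decodes (dropping any rows beyond it, e.g. bucket padding) and concatenates decode and prefill tokens into one column tensor. A self-contained sketch with fabricated values standing in for the real HPU tensors:

    import torch

    num_decodes = 2
    decode_sampled_token_ids = [torch.tensor([5, 7, 0])]   # fabricated; last entry is padding
    prefill_sampled_token_ids = [torch.tensor([9])]
    decode_sampled_token_ids = [t[:num_decodes] for t in decode_sampled_token_ids]
    sampled_token_ids = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).view(-1, 1)
    print(sampled_token_ids)  # [[5], [7], [9]]
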
@@ -2868,6 +2848,7 @@ def execute_model(
             self.input_batch.req_id_to_index.copy()

         if self.use_async_scheduling:
+            assert not self.speculative_config, "Speculative decoding not supported with async scheduling"
             self.input_batch.prev_sampled_token_ids = \
                 sampled_token_ids.flatten().to("cpu", non_blocking=True)
             # self.input_batch.prev_sampled_token_ids_invalid_indices
@@ -2876,16 +2857,13 @@ def execute_model(
                 invalid_req_indices_set
             self.input_batch.prev_req_id_to_index = {
                 req_id: i
-                for i, req_id in enumerate(self.input_batch.req_ids)
-                if i not in invalid_req_indices_set
+                for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set
             }

             # For the output, create placeholder sampled_token_ids
             # (will be filled during serialization)
             max_req_index = max(self.input_batch.req_id_to_index.values())
-            postprocessed_sampled_token_ids = [[]
-                                               for _ in range(max_req_index +
-                                                              1)]
+            postprocessed_sampled_token_ids = [[] for _ in range(max_req_index + 1)]
         else:
             # From this point onward, all operations are done on CPU.
             # We already have tokens. Let's copy the data to