@@ -2044,7 +2044,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, cu_num_tokens: np.
        indices_match = True
        max_flattened_index = -1
        for req_id, cur_index in self.input_batch.req_id_to_index.items():
-            if req_id in self.input_batch.prev_sampled_token_ids_invalid_indices:
+            if (self.input_batch.prev_sampled_token_ids_invalid_indices is not None
+                    and req_id in self.input_batch.prev_sampled_token_ids_invalid_indices):
                # This request was in the previous batch but its
                # prev_sampled_token_ids is invalid
                continue
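A minimal standalone sketch of the guard this hunk adds, with illustrative names and data rather than the runner's real state: membership is only tested after confirming that the container of invalid indices is not None.

```python
# Illustrative only: mirrors the fixed condition in _prepare_input_ids.
prev_sampled_token_ids_invalid_indices = None  # may be None or a set of request ids
req_id_to_index = {"req-a": 0, "req-b": 1}

for req_id, cur_index in req_id_to_index.items():
    # Without the "is not None" check, `req_id in None` raises TypeError.
    if (prev_sampled_token_ids_invalid_indices is not None
            and req_id in prev_sampled_token_ids_invalid_indices):
        # Skip requests whose previous sampled tokens are invalid.
        continue
    ...
```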
@@ -2847,6 +2848,8 @@ def execute_model(
        req_id_to_index_output_copy = \
            self.input_batch.req_id_to_index.copy()

+        max_req_index = max(self.input_batch.req_id_to_index.values())
+        postprocessed_sampled_token_ids: list[list[int]] = [[] for _ in range(max_req_index + 1)]
        if self.use_async_scheduling:
            assert not self.speculative_config, "Speculative decoding not supported with async scheduling"
            self.input_batch.prev_sampled_token_ids = \
@@ -2859,11 +2862,7 @@ def execute_model(
                req_id: i
                for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set
            }
-
-            # For the output, create placeholder sampled_token_ids
-            # (will be filled during serialization)
-            max_req_index = max(self.input_batch.req_id_to_index.values())
-            postprocessed_sampled_token_ids = [[] for _ in range(max_req_index + 1)]
+            # For the output, postprocessed_sampled_token_ids will be filled during serialization
        else:
            # From this point onward, all operations are done on CPU.
            # We already have tokens. Let's copy the data to
@@ -2874,9 +2873,6 @@ def execute_model(
            sampled_token_ids_list = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).tolist()
            sampled_token_requests = \
                decode_sampled_requests + prefill_sampled_requests
-            max_req_index = max(self.input_batch.req_id_to_index.values())
-            postprocessed_sampled_token_ids: list[list]
-            postprocessed_sampled_token_ids = [[] for _ in range(max_req_index + 1)]
            # NOTE(Chendi): in post-processing, spec_decode might
            # return more than 1 token during decode.
            start_idx = 0
@@ -2928,12 +2924,10 @@ def execute_model(
                req_state.output_token_ids.extend(sampled_ids)

            # Create output.
-            all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
            # prompt_logprobs_dict: dict[
            #     str, Optional[LogprobsTensors]] = self._get_prompt_logprobs_dict(
            #     prefill_hidden_states_device, scheduler_output)
            prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
-            all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
            logprobs = None

            model_runner_output = ModelRunnerOutput(
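The later hunks hoist the placeholder allocation above the `if self.use_async_scheduling:` branch so both paths share one pre-allocated list. A minimal sketch of that pattern with made-up data; only the shape of the logic mirrors the diff, the sampler output below is fabricated for illustration.

```python
req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}

# Pre-allocate one (possibly empty) token list per request slot once,
# instead of repeating the allocation in each branch.
max_req_index = max(req_id_to_index.values())
postprocessed_sampled_token_ids: list[list[int]] = [
    [] for _ in range(max_req_index + 1)
]

# Later, slots are filled by request index; unfilled slots stay empty lists.
sampled = {"req-a": [101], "req-c": [7, 8]}  # fake sampler output
for req_id, token_ids in sampled.items():
    postprocessed_sampled_token_ids[req_id_to_index[req_id]].extend(token_ids)

print(postprocessed_sampled_token_ids)  # [[101], [], [7, 8]]
```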