@@ -130,11 +130,12 @@ def get_output(self) -> ModelRunnerOutput:
         valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
         del self._sampled_token_ids
         for i in self._invalid_req_indices:
-            if i < len(valid_sampled_token_ids):
+            if i < len(valid_sampled_token_ids):
                 valid_sampled_token_ids[i].clear()

         output = self._model_runner_output
-        output.sampled_token_ids = valid_sampled_token_ids
+        output.sampled_token_ids[:len(valid_sampled_token_ids
+                                      )] = valid_sampled_token_ids
         return output

@@ -2034,6 +2035,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         if self.input_batch.prev_sampled_token_ids is None:
             return

+        prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
         # Async scheduling case, where some decode requests from the previous
         # iteration won't have entries in input_ids_cpu and need to be copied
         # on the GPU from prev_sampled_token_ids.
@@ -2044,10 +2046,11 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         indices_match = True
         max_flattened_index = -1
         for req_id, cur_index in self.input_batch.req_id_to_index.items():
-            # if req_id in self.input_batch.prev_sampled_token_ids_invalid_indices:
-            #     # This request was in the previous batch but its
-            #     # prev_sampled_token_ids is invalid
-            #     continue
+            if req_id in self.input_batch.\
+                    prev_sampled_token_ids_invalid_indices:
+                # This request was in the previous batch but its
+                # prev_sampled_token_ids is invalid
+                continue
             if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
                 prev_common_req_indices.append(prev_index)
                 # We need to compute the flattened input_ids index of the
@@ -2061,33 +2064,30 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
             # No requests in common with the previous iteration
             # So input_ids_cpu will have all the input ids.
             return
-        if indices_match and max_flattened_index == (
-                num_commmon_tokens - 1):
+        if indices_match and max_flattened_index == (num_commmon_tokens - 1):
             # Common-case optimization: the batch is unchanged
             # and no reordering happened.
             # The indices are both the same permutation of 0..N-1
             self.input_ids_cpu[:len(flattened_indices)].copy_(
-                self.input_batch.
                 prev_sampled_token_ids[:len(flattened_indices)])
             return

         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking
         input_ids_index_tensor = torch.tensor(flattened_indices,
-                                              dtype=torch.int64,
-                                              device="cpu")
-        prev_common_req_indices_tensor = torch.tensor(
-            prev_common_req_indices,
-            dtype=torch.int64,
-            device="cpu")
-        src_tensor = self.input_batch.prev_sampled_token_ids
-        # logger.info(f"Scattering prev_common_req_indices_tensor: {prev_common_req_indices_tensor} from src_tensor: {len(src_tensor)} "
-        #             f"to input_ids_index_tensor: {input_ids_index_tensor} ")
+                                              dtype=torch.int64,
+                                              device="cpu")
+        if prev_sampled_token_ids.size(0) <= len(prev_common_req_indices):
+            prev_common_req_indices = prev_common_req_indices[:
+                                                              prev_sampled_token_ids
+                                                              .size(0)]
+        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
+                                                      dtype=torch.int64,
+                                                      device="cpu")
         self.input_ids_cpu.scatter_(
             dim=0,
             index=input_ids_index_tensor,
-            src=self.input_batch.
-            prev_sampled_token_ids[prev_common_req_indices_tensor])
+            src=prev_sampled_token_ids[prev_common_req_indices_tensor])

     def _prepare_inputs(
         self,
@@ -2118,7 +2118,7 @@ def _prepare_inputs(
             0,
             torch.from_numpy(token_indices),
             out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the GPU.
+        # Copy the tensors for async scheduling
         self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
         ###############################################
@@ -2662,15 +2662,17 @@ def execute_model(
             # If logits_indices is smaller than req_id,
             # add the last token position
             if logits_indices.shape[0] < len(req_id):
-                if structured_output or self.use_async_scheduling:
-                    logits_append = torch.tensor([torch.sum(prompt_len) - 1],
-                                                 device=token_ids.device,
-                                                 dtype=torch.int32)
-                    logits_indices = torch.cat([logits_indices, logits_append])
+                if structured_output:
+                    logits_append = torch.tensor(
+                        [torch.sum(prompt_len) - 1],
+                        device=token_ids.device,
+                        dtype=torch.int32)
+                    logits_indices = torch.cat(
+                        [logits_indices, logits_append])
                 elif self.use_async_scheduling:
                     # Discard partial prefill logits for async scheduling
                     # Depends on 1 decode token/batch
-                    invalid_req_indices.append(num_decodes + idx)
+                    invalid_req_indices.append(num_decodes + idx)
             htorch.core.mark_step()
             _, sample_hidden_states, logits_device = \
                 self._execute_model_generic(
@@ -2848,6 +2850,9 @@ def execute_model(
             # For async scheduling: keep tokens on HPU and avoid CPU sync
             # Concatenate decode and prefill tokens on HPU
             if decode_sampled_token_ids or prefill_sampled_token_ids:
+                decode_sampled_token_ids = [
+                    tensor[:num_decodes] for tensor in decode_sampled_token_ids
+                ]
                 sampled_token_ids = torch.cat(decode_sampled_token_ids +
                                               prefill_sampled_token_ids).view(
                                                   -1, 1)
@@ -2865,13 +2870,10 @@ def execute_model(
             if self.use_async_scheduling:
                 self.input_batch.prev_sampled_token_ids = \
                     sampled_token_ids.flatten().to("cpu", non_blocking=True)
-
                 # self.input_batch.prev_sampled_token_ids_invalid_indices
                 invalid_req_indices_set = set(invalid_req_indices)
                 self.input_batch.prev_sampled_token_ids_invalid_indices = \
                     invalid_req_indices_set
-                # logger.info(f"set: {invalid_req_indices_set}, "
-                #             f"self.input_batch.req_ids: {self.input_batch.req_ids}, ")
                 self.input_batch.prev_req_id_to_index = {
                     req_id: i
                     for i, req_id in enumerate(self.input_batch.req_ids)
@@ -2880,9 +2882,10 @@ def execute_model(

                 # For the output, create placeholder sampled_token_ids
                 # (will be filled during serialization)
-
-                postprocessed_sampled_token_ids = [[] for _ in range(num_reqs)]
-
+                max_req_index = max(self.input_batch.req_id_to_index.values())
+                postprocessed_sampled_token_ids = [[]
+                                                   for _ in range(max_req_index +
+                                                                  1)]
             else:
                 # From this point onward, all operations are done on CPU.
                 # We already have tokens. Let's copy the data to
@@ -2926,8 +2929,6 @@ def execute_model(
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         for req_idx, sampled_ids in enumerate(postprocessed_sampled_token_ids[:num_reqs]):
-            # if self.use_async_scheduling:
-            #     sampled_ids = [-1] # placeholder
             if not sampled_ids:
                 continue
@@ -2958,7 +2959,7 @@ def execute_model(
         logprobs = None

         model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids_output_copy, # CHECK
+            req_ids=req_ids_output_copy,  # CHECK
             req_id_to_index=req_id_to_index_output_copy,
             sampled_token_ids=postprocessed_sampled_token_ids,
             logprobs=logprobs,
@@ -2970,7 +2971,8 @@ def execute_model(
             return AsyncHPUModelRunnerOutput(
                 model_runner_output=model_runner_output,
                 sampled_token_ids=sampled_token_ids,
-                invalid_req_indices=invalid_req_indices,)
+                invalid_req_indices=[],
+            )
         return model_runner_output

     def load_model(self) -> None:
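
Note: below is a minimal, standalone sketch (not the PR's code) of the scatter-based copy that the updated _prepare_input_ids performs for async scheduling: tokens sampled in the previous iteration are written into the flattened input_ids buffer at each common request's flattened position. The tensor names mirror the diff, but the shapes, values, and the use of plain CPU tensors are illustrative assumptions.

import torch

# Flattened input_ids buffer for the current step (one slot per scheduled token).
input_ids_cpu = torch.zeros(8, dtype=torch.int64)

# Tokens sampled for each request in the previous iteration, shape [num_prev_reqs, 1].
prev_sampled_token_ids = torch.tensor([[11], [22], [33]], dtype=torch.int64)

# For requests common to both iterations: their row in prev_sampled_token_ids and
# the flattened position where the sampled token must land in input_ids_cpu.
prev_common_req_indices = torch.tensor([0, 2], dtype=torch.int64)
flattened_indices = torch.tensor([3, 7], dtype=torch.int64)

# Scatter the previously sampled tokens into the input buffer in one call,
# analogous to self.input_ids_cpu.scatter_(...) in the diff.
input_ids_cpu.scatter_(
    dim=0,
    index=flattened_indices,
    src=prev_sampled_token_ids[prev_common_req_indices].flatten(),
)
print(input_ids_cpu)  # tensor([ 0,  0,  0, 11,  0,  0,  0, 33])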