Commit 68609db

more pre-commit fix
Signed-off-by: Tianmu Li <[email protected]>
1 parent 3032bdd commit 68609db

File tree

1 file changed: +5, -11 lines changed


vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 5 additions & 11 deletions
@@ -2044,7 +2044,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, cu_num_tokens: np.
         indices_match = True
         max_flattened_index = -1
         for req_id, cur_index in self.input_batch.req_id_to_index.items():
-            if req_id in self.input_batch.prev_sampled_token_ids_invalid_indices:
+            if (self.input_batch.prev_sampled_token_ids_invalid_indices is not None
+                    and req_id in self.input_batch.prev_sampled_token_ids_invalid_indices):
                 # This request was in the previous batch but its
                 # prev_sampled_token_ids is invalid
                 continue
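
The hunk above only makes the membership test None-safe: the added parentheses guard against prev_sampled_token_ids_invalid_indices being None, since `req_id in None` raises a TypeError. A minimal sketch of the same guard pattern, with a hypothetical helper name and a plain Optional[set] standing in for the input-batch field:

from typing import Optional

def skip_request(req_id: str, invalid_ids: Optional[set[str]]) -> bool:
    # Only test membership when the optional set is actually present;
    # `req_id in None` would raise a TypeError.
    return invalid_ids is not None and req_id in invalid_ids

assert skip_request("req-0", {"req-0"}) is True
assert skip_request("req-0", None) is False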

@@ -2847,6 +2848,8 @@ def execute_model(
         req_id_to_index_output_copy = \
             self.input_batch.req_id_to_index.copy()

+        max_req_index = max(self.input_batch.req_id_to_index.values())
+        postprocessed_sampled_token_ids: list[list[int]] = [[] for _ in range(max_req_index + 1)]
         if self.use_async_scheduling:
             assert not self.speculative_config, "Speculative decoding not supported with async scheduling"
             self.input_batch.prev_sampled_token_ids = \
@@ -2859,11 +2862,7 @@ def execute_model(
                 req_id: i
                 for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set
             }
-
-            # For the output, create placeholder sampled_token_ids
-            # (will be filled during serialization)
-            max_req_index = max(self.input_batch.req_id_to_index.values())
-            postprocessed_sampled_token_ids = [[] for _ in range(max_req_index + 1)]
+            # For the output, postprocessed_sampled_token_ids will be filled during serialization
         else:
             # From this point onward, all operations are done on CPU.
             # We already have tokens. Let's copy the data to
@@ -2874,9 +2873,6 @@ def execute_model(
             sampled_token_ids_list = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).tolist()
             sampled_token_requests = \
                 decode_sampled_requests + prefill_sampled_requests
-            max_req_index = max(self.input_batch.req_id_to_index.values())
-            postprocessed_sampled_token_ids: list[list]
-            postprocessed_sampled_token_ids = [[] for _ in range(max_req_index + 1)]
             # NOTE(Chendi): in post-processing, spec_decode might
             # return more than 1 token during decode.
             start_idx = 0
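
The execute_model hunks above hoist max_req_index and the postprocessed_sampled_token_ids placeholders out of the use_async_scheduling branch, so the async and non-async paths now share a single initialization instead of building the same list in each branch. A standalone sketch of that initialization, assuming req_id_to_index maps request ids to dense batch indices (the helper name is illustrative, not the runner's API):

def build_placeholder_token_ids(req_id_to_index: dict[str, int]) -> list[list[int]]:
    # One empty slot per batch index, sized by the largest index present.
    max_req_index = max(req_id_to_index.values())
    return [[] for _ in range(max_req_index + 1)]

placeholders = build_placeholder_token_ids({"req-a": 0, "req-b": 2})
assert placeholders == [[], [], []]
# Async scheduling: the slots stay empty and are filled later, during serialization.
# Non-async path: the runner fills the slots from the sampled token ids.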

@@ -2928,12 +2924,10 @@ def execute_model(
                 req_state.output_token_ids.extend(sampled_ids)

         # Create output.
-        all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
         # prompt_logprobs_dict: dict[
         #     str, Optional[LogprobsTensors]] = self._get_prompt_logprobs_dict(
         #     prefill_hidden_states_device, scheduler_output)
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
-        all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
         logprobs = None

         model_runner_output = ModelRunnerOutput(
