
Commit 0c15a87

Fix accuracy issue
Signed-off-by: Tianmu Li <[email protected]>
1 parent 48fec5e commit 0c15a87

File tree

1 file changed: +39 -37 lines changed

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 39 additions & 37 deletions
@@ -130,11 +130,12 @@ def get_output(self) -> ModelRunnerOutput:
         valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
         del self._sampled_token_ids
         for i in self._invalid_req_indices:
-            if i< len(valid_sampled_token_ids):
+            if i < len(valid_sampled_token_ids):
                 valid_sampled_token_ids[i].clear()

         output = self._model_runner_output
-        output.sampled_token_ids = valid_sampled_token_ids
+        output.sampled_token_ids[:len(valid_sampled_token_ids
+                                      )] = valid_sampled_token_ids
         return output

@@ -2034,6 +2035,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         if self.input_batch.prev_sampled_token_ids is None:
             return

+        prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
         # Async scheduling case, where some decode requests from the previous
         # iteration won't have entries in input_ids_cpu and need to be copied
         # on the GPU from prev_sampled_token_ids.
@@ -2044,10 +2046,11 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         indices_match = True
         max_flattened_index = -1
         for req_id, cur_index in self.input_batch.req_id_to_index.items():
-            # if req_id in self.input_batch.prev_sampled_token_ids_invalid_indices:
-            #     # This request was in the previous batch but its
-            #     # prev_sampled_token_ids is invalid
-            #     continue
+            if req_id in self.input_batch.\
+                    prev_sampled_token_ids_invalid_indices:
+                # This request was in the previous batch but its
+                # prev_sampled_token_ids is invalid
+                continue
             if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
                 prev_common_req_indices.append(prev_index)
                 # We need to compute the flattened input_ids index of the
@@ -2061,33 +2064,30 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
             # No requests in common with the previous iteration
             # So input_ids_cpu will have all the input ids.
             return
-        if indices_match and max_flattened_index == (
-                num_commmon_tokens - 1):
+        if indices_match and max_flattened_index == (num_commmon_tokens - 1):
             # Common-case optimization: the batch is unchanged
             # and no reordering happened.
             # The indices are both the same permutation of 0..N-1
             self.input_ids_cpu[:len(flattened_indices)].copy_(
-                self.input_batch.
                 prev_sampled_token_ids[:len(flattened_indices)])
             return

         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking
         input_ids_index_tensor = torch.tensor(flattened_indices,
-                                              dtype=torch.int64,
-                                              device="cpu")
-        prev_common_req_indices_tensor = torch.tensor(
-            prev_common_req_indices,
-            dtype=torch.int64,
-            device="cpu")
-        src_tensor = self.input_batch.prev_sampled_token_ids
-        # logger.info(f"Scattering prev_common_req_indices_tensor: {prev_common_req_indices_tensor} from src_tensor: {len(src_tensor)} "
-        #             f"to input_ids_index_tensor: {input_ids_index_tensor}")
+                                              dtype=torch.int64,
+                                              device="cpu")
+        if prev_sampled_token_ids.size(0) <= len(prev_common_req_indices):
+            prev_common_req_indices = prev_common_req_indices[:
+                                                              prev_sampled_token_ids
+                                                              .size(0)]
+        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
+                                                      dtype=torch.int64,
+                                                      device="cpu")
         self.input_ids_cpu.scatter_(
             dim=0,
             index=input_ids_index_tensor,
-            src=self.input_batch.
-            prev_sampled_token_ids[prev_common_req_indices_tensor])
+            src=prev_sampled_token_ids[prev_common_req_indices_tensor])

     def _prepare_inputs(
         self,
@@ -2118,7 +2118,7 @@ def _prepare_inputs(
             0,
             torch.from_numpy(token_indices),
             out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the GPU.
+        # Copy the tensors for async scheduling
         self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
         ###############################################

@@ -2662,15 +2662,17 @@ def execute_model(
             # If logits_indices is smaller than req_id,
             # add the last token position
             if logits_indices.shape[0] < len(req_id):
-                if structured_output or self.use_async_scheduling:
-                    logits_append = torch.tensor([torch.sum(prompt_len) - 1],
-                                                 device=token_ids.device,
-                                                 dtype=torch.int32)
-                    logits_indices = torch.cat([logits_indices, logits_append])
+                if structured_output:
+                    logits_append = torch.tensor(
+                        [torch.sum(prompt_len) - 1],
+                        device=token_ids.device,
+                        dtype=torch.int32)
+                    logits_indices = torch.cat(
+                        [logits_indices, logits_append])
                 elif self.use_async_scheduling:
                     # Discard partial prefill logits for async scheduling
                     # Depends on 1 decode token/batch
-                    invalid_req_indices.append(num_decodes+idx)
+                    invalid_req_indices.append(num_decodes + idx)
             htorch.core.mark_step()
             _, sample_hidden_states, logits_device = \
                 self._execute_model_generic(
@@ -2848,6 +2850,9 @@ def execute_model(
         # For async scheduling: keep tokens on HPU and avoid CPU sync
         # Concatenate decode and prefill tokens on HPU
         if decode_sampled_token_ids or prefill_sampled_token_ids:
+            decode_sampled_token_ids = [
+                tensor[:num_decodes] for tensor in decode_sampled_token_ids
+            ]
             sampled_token_ids = torch.cat(decode_sampled_token_ids +
                                           prefill_sampled_token_ids).view(
                                               -1, 1)
@@ -2865,13 +2870,10 @@ def execute_model(
         if self.use_async_scheduling:
             self.input_batch.prev_sampled_token_ids = \
                 sampled_token_ids.flatten().to("cpu", non_blocking=True)
-
             # self.input_batch.prev_sampled_token_ids_invalid_indices
             invalid_req_indices_set = set(invalid_req_indices)
             self.input_batch.prev_sampled_token_ids_invalid_indices = \
                 invalid_req_indices_set
-            # logger.info(f"set: {invalid_req_indices_set}, "
-            #             f"self.input_batch.req_ids: {self.input_batch.req_ids}, ")
             self.input_batch.prev_req_id_to_index = {
                 req_id: i
                 for i, req_id in enumerate(self.input_batch.req_ids)
@@ -2880,9 +2882,10 @@ def execute_model(

             # For the output, create placeholder sampled_token_ids
             # (will be filled during serialization)
-
-            postprocessed_sampled_token_ids = [[] for _ in range(num_reqs)]
-
+            max_req_index = max(self.input_batch.req_id_to_index.values())
+            postprocessed_sampled_token_ids = [[]
+                                               for _ in range(max_req_index +
+                                                              1)]
         else:
             # From this point onward, all operations are done on CPU.
             # We already have tokens. Let's copy the data to
@@ -2926,8 +2929,6 @@ def execute_model(
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         for req_idx, sampled_ids in enumerate(postprocessed_sampled_token_ids[:num_reqs]):
-            # if self.use_async_scheduling:
-            #     sampled_ids = [-1]  # placeholder
             if not sampled_ids:
                 continue

@@ -2958,7 +2959,7 @@ def execute_model(
             logprobs = None

         model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids_output_copy, # CHECK
+            req_ids=req_ids_output_copy,  # CHECK
             req_id_to_index=req_id_to_index_output_copy,
             sampled_token_ids=postprocessed_sampled_token_ids,
             logprobs=logprobs,
@@ -2970,7 +2971,8 @@ def execute_model(
             return AsyncHPUModelRunnerOutput(
                 model_runner_output=model_runner_output,
                 sampled_token_ids=sampled_token_ids,
-                invalid_req_indices=invalid_req_indices,)
+                invalid_req_indices=[],
+            )
         return model_runner_output

     def load_model(self) -> None:
