
Commit 3032bdd

pre-commit fix
Signed-off-by: Tianmu Li <[email protected]>
1 parent 0c15a87 commit 3032bdd


vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 25 additions & 47 deletions
@@ -47,8 +47,8 @@
 from vllm_gaudi.utils import (HPUCompileConfig, is_fake_hpu, async_h2d_copy)
 from vllm_gaudi.v1.attention.backends.hpu_attn import HPUAttentionMetadataV1
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds,
-                             LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.utils import bind_kv_cache
 from vllm_gaudi.v1.worker.hpu_input_batch import InputBatch
@@ -113,8 +113,7 @@ def __init__(
         self._sampled_token_ids = sampled_token_ids
 
         # TODO: Change to non_blocking once it is working
-        self._sampled_token_ids_cpu = self._sampled_token_ids.to(
-            'cpu', non_blocking=False)
+        self._sampled_token_ids_cpu = self._sampled_token_ids.to('cpu', non_blocking=False)
 
     def get_output(self) -> ModelRunnerOutput:
         """Copy the device tensors to the host and return a ModelRunnerOutput.
@@ -134,8 +133,7 @@ def get_output(self) -> ModelRunnerOutput:
             valid_sampled_token_ids[i].clear()
 
         output = self._model_runner_output
-        output.sampled_token_ids[:len(valid_sampled_token_ids
-                                       )] = valid_sampled_token_ids
+        output.sampled_token_ids[:len(valid_sampled_token_ids)] = valid_sampled_token_ids
         return output
 
 
@@ -2024,8 +2022,8 @@ def _prepare_unified_decode_inputs(self, num_decodes, num_scheduled_tokens) -> D
             logits_indices=logits_indices_t,
             attn_metadata=attn_metadata,
         )
-    def _prepare_input_ids(self, total_num_scheduled_tokens: int,
-                           cu_num_tokens: np.ndarray) -> None:
+
+    def _prepare_input_ids(self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray) -> None:
         """Prepare the input IDs for the current batch.
 
         Carefully handles the `prev_sampled_token_ids` which can be cached
@@ -2046,8 +2044,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         indices_match = True
         max_flattened_index = -1
         for req_id, cur_index in self.input_batch.req_id_to_index.items():
-            if req_id in self.input_batch.\
-                prev_sampled_token_ids_invalid_indices:
+            if req_id in self.input_batch.prev_sampled_token_ids_invalid_indices:
                 # This request was in the previous batch but its
                 # prev_sampled_token_ids is invalid
                 continue
@@ -2068,26 +2065,18 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
             # Common-case optimization: the batch is unchanged
             # and no reordering happened.
             # The indices are both the same permutation of 0..N-1
-            self.input_ids_cpu[:len(flattened_indices)].copy_(
-                prev_sampled_token_ids[:len(flattened_indices)])
+            self.input_ids_cpu[:len(flattened_indices)].copy_(prev_sampled_token_ids[:len(flattened_indices)])
             return
 
         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking
-        input_ids_index_tensor = torch.tensor(flattened_indices,
-                                              dtype=torch.int64,
-                                              device="cpu")
+        input_ids_index_tensor = torch.tensor(flattened_indices, dtype=torch.int64, device="cpu")
         if prev_sampled_token_ids.size(0) <= len(prev_common_req_indices):
-            prev_common_req_indices = prev_common_req_indices[:
-                                                              prev_sampled_token_ids
-                                                              .size(0)]
-        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
-                                                      dtype=torch.int64,
-                                                      device="cpu")
-        self.input_ids_cpu.scatter_(
-            dim=0,
-            index=input_ids_index_tensor,
-            src=prev_sampled_token_ids[prev_common_req_indices_tensor])
+            prev_common_req_indices = prev_common_req_indices[:prev_sampled_token_ids.size(0)]
+        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices, dtype=torch.int64, device="cpu")
+        self.input_ids_cpu.scatter_(dim=0,
+                                    index=input_ids_index_tensor,
+                                    src=prev_sampled_token_ids[prev_common_req_indices_tensor])
 
     def _prepare_inputs(
         self,
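For reference, the scatter at the end of this hunk follows the standard `Tensor.scatter_` pattern: each value of `src` is written into the destination at the position given by `index` along `dim=0`. A small self-contained sketch of that pattern, with hypothetical toy values rather than the runner's real buffers:

```python
import torch

# Destination: a flat CPU buffer of input IDs, one slot per scheduled token.
input_ids_cpu = torch.zeros(8, dtype=torch.int64)

# Tokens sampled for three requests in the previous step.
prev_sampled_token_ids = torch.tensor([101, 202, 303], dtype=torch.int64)

# Flattened positions in input_ids_cpu where each request's next input goes.
index = torch.tensor([2, 5, 7], dtype=torch.int64)

# scatter_ writes prev_sampled_token_ids[i] into input_ids_cpu[index[i]].
input_ids_cpu.scatter_(dim=0, index=index, src=prev_sampled_token_ids)
print(input_ids_cpu)  # tensor([  0,   0, 101,   0,   0, 202,   0, 303])
```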
@@ -2112,8 +2101,7 @@ def _prepare_inputs(
         cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np)
         token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1])
-        cu_num_tokens, arange = self._get_cumsum_and_arange(
-            num_scheduled_tokens)
+        cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
                            torch.from_numpy(token_indices),
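The surrounding code gathers token IDs for every scheduled position by flattening the per-request token table and indexing it with `pos + req_idx * row_len`. A minimal sketch of that addressing scheme with toy shapes; the names and values are illustrative, not the runner's actual buffers:

```python
import numpy as np
import torch

# Toy per-request token table: 3 requests, up to 4 tokens each.
token_ids_cpu_tensor = torch.arange(12, dtype=torch.int32).reshape(3, 4)
row_len = token_ids_cpu_tensor.shape[1]

# Request index and in-request position of every scheduled token.
req_indices = np.array([0, 0, 2, 2, 2])
positions_np = np.array([1, 2, 0, 1, 2])

# Flatten (request, position) into an index over the flattened table.
token_indices = positions_np + req_indices * row_len

out = torch.empty(len(token_indices), dtype=torch.int32)
torch.index_select(token_ids_cpu_tensor.flatten(), 0,
                   torch.from_numpy(token_indices), out=out)
print(out)  # tensor([ 1,  2,  8,  9, 10], dtype=torch.int32)
```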
@@ -2663,12 +2651,10 @@ def execute_model(
             # add the last token position
             if logits_indices.shape[0] < len(req_id):
                 if structured_output:
-                    logits_append = torch.tensor(
-                        [torch.sum(prompt_len) - 1],
-                        device=token_ids.device,
-                        dtype=torch.int32)
-                    logits_indices = torch.cat(
-                        [logits_indices, logits_append])
+                    logits_append = torch.tensor([torch.sum(prompt_len) - 1],
+                                                 device=token_ids.device,
+                                                 dtype=torch.int32)
+                    logits_indices = torch.cat([logits_indices, logits_append])
                 elif self.use_async_scheduling:
                     # Discard partial prefill logits for async scheduling
                     # Depends on 1 decode token/batch
@@ -2850,16 +2836,10 @@ def execute_model(
         # For async scheduling: keep tokens on HPU and avoid CPU sync
         # Concatenate decode and prefill tokens on HPU
         if decode_sampled_token_ids or prefill_sampled_token_ids:
-            decode_sampled_token_ids = [
-                tensor[:num_decodes] for tensor in decode_sampled_token_ids
-            ]
-            sampled_token_ids = torch.cat(decode_sampled_token_ids +
-                                          prefill_sampled_token_ids).view(
-                                              -1, 1)
+            decode_sampled_token_ids = [tensor[:num_decodes] for tensor in decode_sampled_token_ids]
+            sampled_token_ids = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).view(-1, 1)
         else:
-            sampled_token_ids = torch.empty((0, 1),
-                                            dtype=torch.int32,
-                                            device=self.device)
+            sampled_token_ids = torch.empty((0, 1), dtype=torch.int32, device=self.device)
 
         # Copy some objects so they don't get modified after returning.
         # This is important when using async scheduling.
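The async path in this hunk keeps sampling results on the device and normalizes them to a single column tensor of shape (num_reqs, 1); when nothing was sampled it still builds an empty tensor with the same second dimension, so downstream code always sees the same layout. A tiny illustration of that shape contract with toy tensors (device omitted, values hypothetical):

```python
import torch

decode_sampled_token_ids = [torch.tensor([11, 12, 13])]          # one decode batch
prefill_sampled_token_ids = [torch.tensor([21]), torch.tensor([22])]
num_decodes = 2  # only the first num_decodes decode slots are valid

decode_sampled_token_ids = [t[:num_decodes] for t in decode_sampled_token_ids]
sampled_token_ids = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).view(-1, 1)
print(sampled_token_ids.shape)  # torch.Size([4, 1])

# Empty case keeps the same (N, 1) layout.
empty = torch.empty((0, 1), dtype=torch.int32)
print(empty.shape)  # torch.Size([0, 1])
```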
@@ -2868,6 +2848,7 @@ def execute_model(
             self.input_batch.req_id_to_index.copy()
 
         if self.use_async_scheduling:
+            assert not self.speculative_config, "Speculative decoding not supported with async scheduling"
             self.input_batch.prev_sampled_token_ids = \
                 sampled_token_ids.flatten().to("cpu", non_blocking=True)
             # self.input_batch.prev_sampled_token_ids_invalid_indices
@@ -2876,16 +2857,13 @@ def execute_model(
                 invalid_req_indices_set
             self.input_batch.prev_req_id_to_index = {
                 req_id: i
-                for i, req_id in enumerate(self.input_batch.req_ids)
-                if i not in invalid_req_indices_set
+                for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set
             }
 
             # For the output, create placeholder sampled_token_ids
             # (will be filled during serialization)
             max_req_index = max(self.input_batch.req_id_to_index.values())
-            postprocessed_sampled_token_ids = [[]
-                                               for _ in range(max_req_index +
-                                                              1)]
+            postprocessed_sampled_token_ids = [[] for _ in range(max_req_index + 1)]
         else:
             # From this point onward, all operations are done on CPU.
             # We already have tokens. Let's copy the data to
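One note on the placeholder list in this hunk: a comprehension is used rather than `[[]] * n` because the latter would alias a single list object across every slot. A quick illustration in generic Python, not taken from the runner:

```python
n = 3

aliased = [[]] * n                    # every slot is the same list object
aliased[0].append(7)
print(aliased)                        # [[7], [7], [7]]

independent = [[] for _ in range(n)]  # what the runner builds
independent[0].append(7)
print(independent)                    # [[7], [], []]
```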
