Merged
201 changes: 158 additions & 43 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -75,6 +75,7 @@
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.sample.logits_processor import build_logitsprocs
from torch.nn.utils.rnn import pad_sequence
from vllm.sampling_params import SamplingParams

if TYPE_CHECKING:
import xgrammar as xgr
@@ -1253,12 +1254,11 @@ def _get_prompts_and_decodes(
return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
prompt_scheduled_tokens)

def _prepare_sampling(self,
batch_changed: bool,
request_ids: Union[None, list[str]] = None,
pad_to: Optional[int] = None,
logits_reqs=None) -> SamplingMetadata:
# Create the sampling metadata.
def _generate_req_id_output_token_ids_lst(
self,
request_ids: Optional[list[str]] = None,
pad_to: Optional[int] = None,
logits_reqs=None):
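# Build (req_id, output_token_ids) pairs, optionally restricted to
# request_ids and padded to pad_to entries (see the sketch after this hunk).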
req_id_output_token_ids: dict[str, list[int]] = \
{req_id: req.output_token_ids
for req_id, req in self.requests.items()}
@@ -1278,6 +1278,17 @@ def _prepare_sampling(self,
while len(req_id_output_token_ids_lst) < pad_to:
req_id_output_token_ids_lst.append(
req_id_output_token_ids_lst[0])
return req_id_output_token_ids_lst

def _prepare_sampling(self,
batch_changed: bool,
request_ids: Union[None, list[str]] = None,
pad_to: Optional[int] = None,
logits_reqs=None) -> SamplingMetadata:
# Create the sampling metadata.
req_id_output_token_ids_lst = \
self._generate_req_id_output_token_ids_lst(request_ids, \
pad_to, logits_reqs)
sampling_metadata = self.input_batch.make_selective_sampling_metadata(
req_id_output_token_ids_lst, skip_copy=not batch_changed)
return sampling_metadata
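
For context, the padding step in the extracted helper repeats the first live entry until the list matches the bucketed batch size. A minimal standalone sketch of that behavior (hypothetical request IDs and token lists, not taken from the diff):

pairs = [("req-0", [11, 42]), ("req-1", [7])]
pad_to = 4
while len(pairs) < pad_to:
    pairs.append(pairs[0])  # filler rows reuse the first request's entry
# pairs now has 4 entries; the last two duplicate ("req-0", [11, 42])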
@@ -2326,6 +2337,21 @@ def apply_grammar_bitmask(
logits.copy_(
logits_cpu.to(self.device, non_blocking=True).to(logits.dtype))

def _run_sampling(
self,
batch_changed: bool,
logits_device: torch.Tensor,
request_ids: Optional[list[str]] = None,
pad_to: Optional[int] = None,
logits_requests=None) -> tuple[torch.Tensor, SamplingMetadata]:
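# mark_step below flushes pending lazy-mode ops so the sampler executes
# as its own HPU graph segment, separate from the model forward pass.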
htorch.core.mark_step()
sampling_metadata = self._prepare_sampling(batch_changed, request_ids,
pad_to, logits_requests)
sampler_output = self.sampler(logits=logits_device,
sampling_metadata=sampling_metadata)
htorch.core.mark_step()
return sampler_output, sampling_metadata

def _pool(
self,
hidden_states: torch.Tensor,
@@ -2606,18 +2632,12 @@ def execute_model(
prefill_sampled_requests.extend(logits_requests)
else:
with self.profiler.record_event('internal', "sampler"):
sampling_metadata = self._prepare_sampling(
batch_changed,
req_id,
pad_to=logits_device.shape[0],
logits_reqs=logits_requests)
sampler_output = self.sampler(
logits=logits_device,
sampling_metadata=sampling_metadata)
sampler_output, _sampling_metadata = self._run_sampling(
batch_changed, logits_device, req_id,
logits_device.shape[0], logits_requests)
prefill_sampled_token_ids.append(
sampler_output.sampled_token_ids.flatten())
prefill_sampled_requests.extend(logits_requests)
htorch.core.mark_step()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model_generic' event
self.profiler.end()
@@ -2657,30 +2677,22 @@ def execute_model(
self.input_batch.req_ids[:num_decodes])
else:
with self.profiler.record_event('internal', "sampler"):
sampling_metadata = self._prepare_sampling(
batch_changed,
pd_info.decode_req_ids,
pad_to=logits_device.shape[0])
##### sampling #####
if decode_data.spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits_device,
sampling_metadata=sampling_metadata)
##### Sampling Start #####
spec_decode_metadata = decode_data.spec_decode_metadata
sampler_output, sampling_metadata = self._run_sampling(
batch_changed, logits_device
if spec_decode_metadata is None else logits_device[
spec_decode_metadata.bonus_logits_indices],
pd_info.decode_req_ids, logits_device.shape[0])

if spec_decode_metadata is None:
decode_sampled_token_ids.append(
sampler_output.sampled_token_ids.flatten())
else:
# Hanlding spec decode sampling.
spec_decode_metadata = decode_data.spec_decode_metadata
logits = logits_device
bonus_logits = logits[
spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
# Handling spec decode sampling.
bonus_token_ids = \
sampler_output.sampled_token_ids.squeeze()
target_logits = logits[
target_logits = logits_device[
spec_decode_metadata.target_logits_indices]

output_token_ids = self.rejection_sampler(
@@ -2695,7 +2707,6 @@ def execute_model(
self.input_batch.req_ids[:num_decodes])
##### Sampling End #####

htorch.core.mark_step()
if self.is_driver_worker and self.profiler.enabled:
# Stop recording 'execute_model' event
self.profiler.end()
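
To make the logits routing above easier to follow: under speculative decoding, one bonus row per request is sampled normally while the remaining target rows go to the rejection sampler. A minimal sketch with hypothetical shapes and indices (the metadata field names match the diff; the values do not come from it):

import torch

logits_device = torch.randn(6, 32000)       # 6 candidate rows, vocab 32000
bonus_idx = torch.tensor([2, 5])            # one bonus position per request
target_idx = torch.tensor([0, 1, 3, 4])     # draft positions to verify
bonus_logits = logits_device[bonus_idx]     # sampled -> bonus_token_ids
target_logits = logits_device[target_idx]   # scored by the rejection sampler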
@@ -2754,13 +2765,10 @@ def execute_model(
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
sampling_metadata = self._prepare_sampling(batch_changed,
pd_info.prompt_req_ids +
pd_info.decode_req_ids,
pad_to=logits.shape[0])
# sampling_metadata = self.input_batch.sampling_metadata
sampler_output = self.sampler(logits=logits,
sampling_metadata=sampling_metadata)
sampler_output, _sampling_metadata = self._run_sampling(
batch_changed, logits,
pd_info.prompt_req_ids + pd_info.decode_req_ids,
logits.shape[0])
# Deal with the case of incomplete prompt
for i in range(logits.shape[0] - num_decodes):
prefill_sampled_token_ids.append(
@@ -3032,6 +3040,111 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len, num_blocks):
f"free_mem:{free_mem}")
logger.info(msg)

def warmup_sampler(self):
"""
Warm up the sampler with different temperature, top-p, and top-k values.
"""
# Choose batch sizes for warmup based on bucketing
test_batch_sizes = list(
dict.fromkeys([0, 1] + [
bucket[0] for bucket in self.bucketing_manager.decode_buckets
]))

# Test different sampling configurations
sampling_configs = [
# (temperature, top_p, top_k, batch_changed)
(0.0, 1.0, 0, True), # Greedy sampling
(1.0, 1.0, 0, True), # Random sampling with temp=1.0
(0.7, 0.9, 50, True), # Common creative settings
(0.3, 0.95, 20, True), # Conservative settings
(1.2, 0.8, 100, True), # High temperature settings
(0.8, 0.85, 0, True), # Different top-p sampling
(0.0, 1.0, 0, False), # Greedy sampling
(1.0, 1.0, 0, False), # Random sampling with temp=1.0
(0.7, 0.9, 50, False), # Common creative settings
(0.3, 0.95, 20, False), # Conservative settings
(1.2, 0.8, 100, False), # High temperature settings
(0.8, 0.85, 0, False), # Different top-p sampling
]

logger.info(
"Warming up sampler with batch sizes: %s and following configs:",
test_batch_sizes)
for temp, top_p, top_k, batch_changed in sampling_configs:
logger.info("temp=%s, top_p=%s, top_k=%s, batch_changed=%s", temp,
top_p, top_k, batch_changed)
logger.info("Starting sampler warmup...")

for batch_size in test_batch_sizes:
dummy_hidden_states = torch.randn(batch_size,
self.hidden_size,
dtype=self.dtype,
device=self.device)
dummy_logits = self.model.compute_logits(dummy_hidden_states, None)

# Create dummy requests for this specific configuration
dummy_req_ids = [
f"warmup_req_{batch_size}_{i}"
for i in range(max(1, batch_size))
]

for i, req_id in enumerate(dummy_req_ids):
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=list(range(10)), # Dummy prompt
mm_kwargs=[],
mm_positions=[],
mm_hashes=[],
sampling_params=SamplingParams(),
pooling_params=None,
generator=None,
block_ids=[[0]],
num_computed_tokens=10,
output_token_ids=[],
)
self.input_batch.req_id_to_index[req_id] = i

for temp, top_p, top_k, batch_changed in sampling_configs:
# Add dummy requests to cache with consistent sampling params
for i, req_id in enumerate(dummy_req_ids):
self.requests[req_id].sampling_params = SamplingParams(
temperature=temp,
top_p=top_p,
top_k=top_k,
)

if temp == 0.0: # Greedy sampling
self.input_batch.greedy_reqs.add(req_id)
else: # Random sampling
self.input_batch.random_reqs.add(req_id)

self.input_batch.req_output_token_ids = [
item[1]
for item in self._generate_req_id_output_token_ids_lst(
dummy_req_ids, pad_to=batch_size)
]
self.input_batch.refresh_sampling_metadata()

_sampler_output, _sampling_metadata = self._run_sampling(
batch_changed=batch_changed,
logits_device=dummy_logits,
request_ids=dummy_req_ids,
pad_to=dummy_logits.shape[0])
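# Running the sampler once per (batch size, config) combination is
# meant to compile and cache the corresponding HPU graphs up front,
# so real requests do not pay first-use compilation latency.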

# Cleanup after sampling
self.input_batch.greedy_reqs = set()
self.input_batch.random_reqs = set()
self.input_batch.req_output_token_ids = []

# Cleanup after batch has been warmed up
self.input_batch.req_id_to_index = {}
self.requests = {}

# Final synchronization to ensure all operations are completed
torch.hpu.synchronize()

logger.info("Sampler warmup completed successfully")

def warmup_graphs(self,
buckets,
is_prompt,
@@ -3082,7 +3195,6 @@ def _add_dummy_request(self,
scheduled_tokens,
is_prompt,
block_id=0):
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import NewRequestData
num_blocks = round_up(total_tokens, self.block_size) // self.block_size
prompt_token_ids = list(range(total_tokens))
@@ -3341,6 +3453,9 @@ def warmup_model(self) -> None:
assert self.mem_margin is not None, \
("HabanaWorker.determine_num_available_blocks needs "
"to be called before warming up the model.")

self.warmup_sampler()

# TODO(kzawora): align_workers
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
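Aside on the batch-size selection in warmup_sampler above: dict.fromkeys gives an order-preserving de-duplication of the decode-bucket batch sizes. A small sketch with made-up buckets (the bucket contents are hypothetical):

decode_buckets = [(1, 128), (2, 128), (4, 256), (2, 256)]
test_batch_sizes = list(dict.fromkeys([0, 1] + [b[0] for b in decode_buckets]))
# -> [0, 1, 2, 4]; duplicates collapse, first-seen order is kept
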
4 changes: 3 additions & 1 deletion vllm_gaudi/v1/worker/hpu_worker.py
@@ -259,7 +259,9 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
self.compile_or_warm_up_model()

def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
# Don't run the warmup in eager mode or if the model is already warmed up
if not self.model_config.enforce_eager \
and not self.model_runner.graphed_buckets:
self.model_runner.warmup_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
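
The new guard in compile_or_warm_up_model reads as: warm up only when eager mode is not forced and no buckets have been graphed yet. A toy sketch of the truthiness involved (assuming graphed_buckets is an initially empty collection, as the check implies):

enforce_eager = False
graphed_buckets: set = set()  # empty before any warmup
if not enforce_eager and not graphed_buckets:
    pass  # warmup_model() would run here; afterwards the collection is
          # non-empty, so a repeated call skips the expensive warmup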