From b4b908a6fee5832cbf842fbbcff25719f3057542 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 11 Aug 2025 09:47:34 +0300 Subject: [PATCH 01/13] Support Data Parallel - enable profile run Signed-off-by: Wuxun Zhang --- examples/data_parallel.py | 254 ++++++++++++++++++ .../device_communicators/hpu_communicator.py | 45 ++++ vllm_gaudi/platform.py | 1 + vllm_gaudi/v1/worker/hpu_model_runner.py | 61 ++++- vllm_gaudi/v1/worker/hpu_worker.py | 40 ++- 5 files changed, 392 insertions(+), 9 deletions(-) create mode 100644 examples/data_parallel.py diff --git a/examples/data_parallel.py b/examples/data_parallel.py new file mode 100644 index 00000000..ea4f659a --- /dev/null +++ b/examples/data_parallel.py @@ -0,0 +1,254 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Usage: +Single node: + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 + +Multi-node: + Node 0 (assume the node has ip of 10.99.48.128): + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 \ + --node-size=2 \ + --node-rank=0 \ + --master-addr=10.99.48.128 \ + --master-port=13345 + Node 1: + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 \ + --node-size=2 \ + --node-rank=1 \ + --master-addr=10.99.48.128 \ + --master-port=13345 +""" + +import os +from time import sleep +import torch + +from vllm import LLM, SamplingParams +from vllm.utils import get_open_port + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser(description="Data Parallel Inference") + parser.add_argument( + "--model", + type=str, + default="ibm-research/PowerMoE-3b", + help="Model name or path", + ) + parser.add_argument( + "--dp-size", type=int, default=2, help="Data parallel size" + ) + parser.add_argument( + "--tp-size", type=int, default=2, help="Tensor parallel size" + ) + parser.add_argument( + "--node-size", type=int, default=1, help="Total number of nodes" + ) + parser.add_argument( + "--node-rank", type=int, default=0, help="Rank of the current node" + ) + parser.add_argument( + "--master-addr", type=str, default="", help="Master node IP address" + ) + parser.add_argument( + "--master-port", type=int, default=0, help="Master node port" + ) + parser.add_argument( + "--enforce-eager", + action="store_true", + help="Enforce eager mode execution.", + ) + parser.add_argument( + "--trust-remote-code", action="store_true", help="Trust remote code." + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=64, + help=( + "Maximum number of sequences to be processed in a single iteration." 
+ ), + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.8, + help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), + ) + parser.add_argument( + "--random-input", + action="store_true", + help="Use random generated input tokens.", + ) + return parser.parse_args() + + +def generate_random_token_ids(repeat=1) -> list[int]: + """ + For testing different seuquence length in data parallel scenario + """ + candidate_lens = [130, 560] + prompts = [] + for num_tokens in candidate_lens: + tokens = torch.randint( + low=0, high=10000, size=(num_tokens,), dtype=torch.int32 + ) + [prompts.append(tokens.tolist()) for _ in range(repeat)] + return prompts + + +def main( + model, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + GPUs_per_dp_rank, + enforce_eager, + trust_remote_code, + max_num_seqs, + gpu_memory_utilization, +): + os.environ["VLLM_DP_RANK"] = str(global_dp_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) + os.environ["VLLM_DP_SIZE"] = str(dp_size) + os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip + os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) + + # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the + # engine processes. + + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 40 + + # generate prompts with different length to demonstrate DP aware padding. + if args.random_input: + prompts = generate_random_token_ids(40) + + # with DP, each rank should process different prompts. + # usually all the DP ranks process a full dataset, + # and each rank processes a different part of the dataset. + floor = len(prompts) // dp_size + remainder = len(prompts) % dp_size + + # Distribute prompts into even groups. + def start(rank): + return rank * floor + min(rank, remainder) + + prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)] + if len(prompts) == 0: + # if any rank has no prompts to process, + # we need to set a placeholder prompt + prompts = ["Placeholder"] + print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts") + # Create a sampling params object. + # since we are doing data parallel, every rank can have different + # sampling params. here we set different max_tokens for different + # ranks for demonstration. + sampling_params = SamplingParams( + temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2] + ) + + # Create an LLM. + llm = LLM( + model=model, + tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, + enable_expert_parallel=True, + trust_remote_code=trust_remote_code, + max_num_seqs=max_num_seqs, + gpu_memory_utilization=gpu_memory_utilization, + ) + if not args.random_input: + outputs = llm.generate(prompts, sampling_params) + else: + outputs = llm.generate(None, sampling_params, prompts) + # Print the outputs. + for i, output in enumerate(outputs): + if i >= 5: + # print only 5 outputs + break + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " + f"Generated text: {generated_text!r}" + ) + + # Give engines time to pause their processing loops before exiting. 
+ sleep(1) + + +if __name__ == "__main__": + args = parse_args() + + dp_size = args.dp_size + tp_size = args.tp_size + node_size = args.node_size + node_rank = args.node_rank + + if node_size == 1: + dp_master_ip = "127.0.0.1" + dp_master_port = get_open_port() + else: + dp_master_ip = args.master_addr + dp_master_port = args.master_port + + assert dp_size % node_size == 0, "dp_size should be divisible by node_size" + dp_per_node = dp_size // node_size + + from multiprocessing import Process + + procs = [] + for local_dp_rank, global_dp_rank in enumerate( + range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node) + ): + proc = Process( + target=main, + args=( + args.model, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + tp_size, + args.enforce_eager, + args.trust_remote_code, + args.max_num_seqs, + args.gpu_memory_utilization, + ), + ) + proc.start() + procs.append(proc) + exit_code = 0 + for proc in procs: + proc.join(timeout=300) + if proc.exitcode is None: + print( + f"Killing process {proc.pid} that didn't stop within 5 minutes." + ) + proc.kill() + exit_code = 1 + elif proc.exitcode: + exit_code = proc.exitcode + + exit(exit_code) diff --git a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py index 6bdaa43b..e447c623 100644 --- a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py +++ b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py @@ -5,10 +5,30 @@ from vllm.distributed.device_communicators.base_device_communicator \ import DeviceCommunicatorBase +from vllm.distributed.parallel_state import get_dp_group +from vllm.forward_context import get_forward_context import habana_frameworks.torch as htorch # noqa: F401 +def naive_multicast(x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor: + assert x.dim() == 2, "Input tensor must be 2D" + dp_rank = get_dp_group().rank_in_group + dp_world_size = get_dp_group().world_size + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1] + end = cu_tokens_across_dp_cpu[dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(dp_world_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + get_dp_group().broadcast(buffer[start:end, :], idx) + return buffer + + class HpuCommunicator(DeviceCommunicatorBase): def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: @@ -41,3 +61,28 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: input_size[dim], ) + input_size[dim + 1:]) return output_tensor + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + all-gather based dispatch for HPUCommunicator. 
+ """ + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + hidden_states_across_dp = naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits_across_dp = naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + return hidden_states_across_dp, router_logits_across_dp + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + dp_rank = get_dp_group().rank_in_group + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1] + end = cu_tokens_across_dp_cpu[dp_rank] + + all_hidden_states = get_dp_group().all_reduce(hidden_states) + hidden_states = all_hidden_states[start:end, :] + return hidden_states diff --git a/vllm_gaudi/platform.py b/vllm_gaudi/platform.py index 61aa6380..e5d4f756 100644 --- a/vllm_gaudi/platform.py +++ b/vllm_gaudi/platform.py @@ -31,6 +31,7 @@ class HpuPlatform(Platform): supported_quantization: list[str] = [ "compressed-tensors", "fp8", "inc", "awq_hpu", "gptq_hpu" ] + simple_compile_backend = "hpu_backend" @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index d2efc73b..65e343a2 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -30,7 +30,7 @@ from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend from vllm.config import (VllmConfig, update_config) -from vllm.forward_context import set_forward_context +from vllm.forward_context import set_forward_context, DPMetadata from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import get_sampler @@ -58,6 +58,7 @@ from vllm_gaudi.v1.worker.hpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.distributed.parallel_state import get_pp_group + from vllm.model_executor.models.interfaces import supports_transcription from vllm.model_executor.models.interfaces_base import ( is_pooling_model, is_text_generation_model) @@ -423,7 +424,8 @@ def forward(self, *args, **kwargs): if model_mm_kwargs is not None: kwargs.update(model_mm_kwargs) - with set_forward_context(attn_meta, self.vllm_config): + num_input_tokens = input_ids.size(0) * input_ids.size(1) + with set_forward_context(attn_meta, self.vllm_config, num_tokens=num_input_tokens): hidden_states = self.model(*args, **kwargs) if self._rotary_prepare_cos_sin is not None: self._reset_rotary_cos_sin() @@ -1401,6 +1403,17 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, merge_contents(all_batch_contents[-1], new_batch_contents) else: all_batch_contents.append(new_batch_contents) + + if (len(all_batch_contents[0].req_ids) > 0): + num_prefill_batches = len(all_batch_contents) + else: + # no real prefill batches + num_prefill_batches = 0 + + num_pad = self.get_dp_padding(num_prefill_batches) + if num_pad > 0: + for _ in range(num_pad): + all_batch_contents.append(BatchContents()) return all_batch_contents def _make_attn_bias(self, context_groups, token_groups): @@ -1469,6 +1482,11 @@ def _form_prefill_batch(self, contents): target_bs, target_seq, target_blocks = self._get_prompt_bucketing_fn()( query_lens, num_context_blocks) + # dp aware padding + target_bs += self.get_dp_padding(target_bs) + target_seq += 
self.get_dp_padding(target_seq) + target_blocks += self.get_dp_padding(target_blocks) + # NOTE: If model does not support multimodal inputs, we pad here. # For models with multimodal support, we may want to get embeddings # for the valid tokens before padding. @@ -1564,7 +1582,6 @@ def _form_prefill_batch(self, contents): def _prepare_prefill_inputs( self, num_prefills, num_decodes, num_scheduled_tokens: list[int]) -> PrefillInputData: - all_batch_contents = self._extract_prefill_batch_contents( num_prefills, num_decodes, num_scheduled_tokens) all_batches = [ @@ -1602,6 +1619,9 @@ def _prepare_decode_inputs(self, padded_batch_size = self.bucketing_manager.find_decode_bucket( num_decodes, sum(num_blocks))[0] + # dp aware padding + padded_batch_size += self.get_dp_padding(padded_batch_size) + num_tokens_per_req = num_scheduled_tokens[:num_decodes] num_tokens = max(num_tokens_per_req) total_num_scheduled_tokens = sum(num_tokens_per_req) @@ -1965,6 +1985,30 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata, if not seen and not warmup_mode: logger.warning("Configuration: %s was not warmed-up!", cfg) + def get_dp_padding(self, + num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + dp_size = self.vllm_config.parallel_config.data_parallel_size + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + + # For DP: Don't pad when setting enforce_eager. + # This lets us set enforce_eager on the prefiller in a P/D setup and + # still use CUDA graphs (enabled by this padding) on the decoder. + # + # TODO(tms) : There are many cases where padding is enabled for + # prefills, causing unnecessary and excessive padding of activations. + + # skip padding for non PD disagg case to avoid padding on prefill batch + # size and decode batch size + if dp_size == 1 or self.vllm_config.model_config.enforce_eager or ( + self.vllm_config.kv_transfer_config is None + or self.vllm_config.kv_transfer_config.kv_connector is None): + return 0 + + num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + num_tokens, dp_size, dp_rank) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() + return max_tokens_across_dp_cpu - num_tokens + def _execute_model_generic(self, token_ids, position_ids, @@ -2282,6 +2326,8 @@ def execute_model( ######################### PREFILLS ######################### if num_prefills > 0: + # Wuxun: merged prefill forward if enabled + # 2D bucketing or merged prefill bucketing htorch.core.mark_step() for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata, logits_indices, @@ -3051,7 +3097,6 @@ def __del__(self): @torch.inference_mode() def profile_run(self) -> None: - return """Profile to measure peak memory during forward pass.""" # use an empty tensor instead of `None`` to force Dynamo to pass @@ -3070,6 +3115,14 @@ def profile_run(self) -> None: self._execute_dummy_scenario( (self.max_prefill_batch_size, max_seq_len, 0), None) + def _dummy_run(self, max_num_batched_tokens: int) -> None: + assert max_num_batched_tokens == 1 + prompt_cfg = None + decode_cfg = 1, 1 + # add dummy decode run + self._execute_dummy_scenario(prompt_cfg, decode_cfg) + return + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
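For orientation, the DP-aware padding rule added to this file can be pictured with the minimal sketch below. It is an illustrative aside and not part of the patch: in the real helper the per-rank token counts come from DPMetadata.num_tokens_across_dp(), whereas here they are passed in as a plain list so the example runs without torch.distributed or an HPU.

def get_dp_padding_sketch(num_tokens: int, num_tokens_across_dp: list[int]) -> int:
    """Tokens this rank must pad so every DP rank launches the same shape."""
    return max(num_tokens_across_dp) - num_tokens

# Example: three DP ranks were scheduled 96, 128 and 64 tokens respectively.
tokens_per_rank = [96, 128, 64]
for rank, local_tokens in enumerate(tokens_per_rank):
    pad = get_dp_padding_sketch(local_tokens, tokens_per_rank)
    print(f"DP rank {rank}: {local_tokens} tokens + {pad} pad -> {local_tokens + pad}")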
diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py index cdcefee0..bd28a8fc 100644 --- a/vllm_gaudi/v1/worker/hpu_worker.py +++ b/vllm_gaudi/v1/worker/hpu_worker.py @@ -178,14 +178,16 @@ def determine_available_memory(self) -> int: single_kv_block_size_bytes = 0 for layer_name, layer_spec in kv_cache_spec.items(): if isinstance(layer_spec, FullAttentionSpec): - dtype = layer_spec.dtype + # dtype = layer_spec.dtype # Use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value ``None``. - hpu_k_cache = torch.tensor([], dtype=dtype, device='hpu') - hpu_v_cache = torch.tensor([], dtype=dtype, device='hpu') + # hpu_k_cache = torch.tensor([], dtype=dtype, device='hpu') + # hpu_v_cache = torch.tensor([], dtype=dtype, device='hpu') - kv_caches[layer_name] = (hpu_k_cache, hpu_v_cache) + # kv_caches[layer_name] = (hpu_k_cache, hpu_v_cache) + # avoid issue of reading kv cache during profiling + kv_caches[layer_name] = None single_kv_block_size_bytes += layer_spec.page_size_bytes @@ -303,6 +305,9 @@ def profile(self, is_start: bool = True): else: self.profiler.stop() + def execute_dummy_batch(self) -> None: + self.model_runner._dummy_run(1) + def init_worker_distributed_environment( parallel_config: ParallelConfig, @@ -318,9 +323,34 @@ def init_worker_distributed_environment( backend='hccl') ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + expected_size = parallel_config.world_size *\ + parallel_config.data_parallel_size + if torch_world_size != expected_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size * " + "parallel_config.data_parallel_size " + f"({torch_world_size} vs. 
{expected_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + backend = 'hccl' + torch.distributed.init_process_group( + backend=backend, + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + dummy_tensor_hpu = torch.ones(1).to('hpu') torch.distributed.all_reduce(dummy_tensor_hpu) - assert dummy_tensor_hpu.item() == parallel_config.world_size + assert dummy_tensor_hpu.item( + ) == parallel_config.world_size * parallel_config.data_parallel_size ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) From 90532c8d6ea24ec179ac49fd3c495af3dbbb2b04 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 20 Aug 2025 12:04:27 +0300 Subject: [PATCH 02/13] fix Signed-off-by: Wuxun Zhang --- .../distributed/device_communicators/hpu_communicator.py | 3 --- vllm_gaudi/v1/worker/hpu_model_runner.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py index e447c623..b0427cbb 100644 --- a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py +++ b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py @@ -65,9 +65,6 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - all-gather based dispatch for HPUCommunicator. - """ cu_tokens_across_dp_cpu = get_forward_context( ).dp_metadata.cu_tokens_across_dp_cpu hidden_states_across_dp = naive_multicast(hidden_states, diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 65e343a2..4e884f79 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2326,8 +2326,6 @@ def execute_model( ######################### PREFILLS ######################### if num_prefills > 0: - # Wuxun: merged prefill forward if enabled - # 2D bucketing or merged prefill bucketing htorch.core.mark_step() for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata, logits_indices, From cdcc3cc3e8118973d537a5859be40fbc05174a2a Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 20 Aug 2025 18:28:08 +0300 Subject: [PATCH 03/13] fix dummy run Signed-off-by: Wuxun Zhang --- vllm_gaudi/v1/worker/hpu_model_runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 4e884f79..0694d960 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -425,7 +425,9 @@ def forward(self, *args, **kwargs): kwargs.update(model_mm_kwargs) num_input_tokens = input_ids.size(0) * input_ids.size(1) - with set_forward_context(attn_meta, self.vllm_config, num_tokens=num_input_tokens): + with set_forward_context(attn_meta, + self.vllm_config, + num_tokens=num_input_tokens): hidden_states = self.model(*args, **kwargs) if self._rotary_prepare_cos_sin is not None: self._reset_rotary_cos_sin() @@ -3110,13 +3112,14 @@ def profile_run(self) -> None: max_seq_len = math.ceil( (self.max_num_tokens // self.max_prefill_batch_size) / self.block_size) * self.block_size + max_seq_len = min(max_seq_len, self.max_model_len) self._execute_dummy_scenario( 
(self.max_prefill_batch_size, max_seq_len, 0), None) def _dummy_run(self, max_num_batched_tokens: int) -> None: assert max_num_batched_tokens == 1 prompt_cfg = None - decode_cfg = 1, 1 + decode_cfg = 1, 1, 1 # add dummy decode run self._execute_dummy_scenario(prompt_cfg, decode_cfg) return From 5620188c2f0174d44981f66fdeb596836f4ce36a Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Sun, 24 Aug 2025 15:59:35 +0300 Subject: [PATCH 04/13] fix lazy hang Signed-off-by: Wuxun Zhang --- .../device_communicators/hpu_communicator.py | 71 +++++--- vllm_gaudi/v1/worker/hpu_model_runner.py | 164 ++++++++++++------ vllm_gaudi/v1/worker/hpu_worker.py | 10 +- 3 files changed, 163 insertions(+), 82 deletions(-) diff --git a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py index b0427cbb..1d482306 100644 --- a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py +++ b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py @@ -1,35 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional import torch import torch.distributed as dist +from torch.distributed import ProcessGroup from vllm.distributed.device_communicators.base_device_communicator \ import DeviceCommunicatorBase -from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context +from vllm.distributed.parallel_state import get_dp_group import habana_frameworks.torch as htorch # noqa: F401 -def naive_multicast(x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor: - assert x.dim() == 2, "Input tensor must be 2D" - dp_rank = get_dp_group().rank_in_group - dp_world_size = get_dp_group().world_size - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1] - end = cu_tokens_across_dp_cpu[dp_rank] - buffer[start:end, :].copy_(x) - for idx in range(dp_world_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - get_dp_group().broadcast(buffer[start:end, :], idx) - return buffer +class HpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + self.dp_group = None + self.dp_rank = 0 + self.dp_world_size = 1 + # assume EP is enabled along with DP + if "ep" in unique_name: + self.dp_group = get_dp_group() + self.dp_rank = self.dp_group.rank_in_group + self.dp_world_size = self.dp_group.world_size -class HpuCommunicator(DeviceCommunicatorBase): + def naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor: + assert x.dim() == 2, "Input tensor must be 2D" + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(self.dp_world_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + self.dp_group.broadcast(buffer[start:end, :], idx) + return buffer def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge @@ -67,19 +83,22 @@ def 
dispatch( router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: cu_tokens_across_dp_cpu = get_forward_context( ).dp_metadata.cu_tokens_across_dp_cpu - hidden_states_across_dp = naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - router_logits_across_dp = naive_multicast(router_logits, - cu_tokens_across_dp_cpu) + hidden_states_across_dp = self.naive_multicast( + hidden_states, cu_tokens_across_dp_cpu) + router_logits_across_dp = self.naive_multicast( + router_logits, cu_tokens_across_dp_cpu) return hidden_states_across_dp, router_logits_across_dp def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - dp_rank = get_dp_group().rank_in_group + if htorch.utils.internal.is_lazy(): + htorch.core.mark_step() cu_tokens_across_dp_cpu = get_forward_context( ).dp_metadata.cu_tokens_across_dp_cpu - start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1] - end = cu_tokens_across_dp_cpu[dp_rank] - all_hidden_states = get_dp_group().all_reduce(hidden_states) + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + + all_hidden_states = self.dp_group.all_reduce(hidden_states) hidden_states = all_hidden_states[start:end, :] return hidden_states diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 0694d960..774c43cf 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1412,11 +1412,8 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, # no real prefill batches num_prefill_batches = 0 - num_pad = self.get_dp_padding(num_prefill_batches) - if num_pad > 0: - for _ in range(num_pad): - all_batch_contents.append(BatchContents()) - return all_batch_contents + num_pad_across_dp = self.get_dp_padding(num_prefill_batches) + return all_batch_contents, num_pad_across_dp def _make_attn_bias(self, context_groups, token_groups): dtype = self.dtype @@ -1484,11 +1481,6 @@ def _form_prefill_batch(self, contents): target_bs, target_seq, target_blocks = self._get_prompt_bucketing_fn()( query_lens, num_context_blocks) - # dp aware padding - target_bs += self.get_dp_padding(target_bs) - target_seq += self.get_dp_padding(target_seq) - target_blocks += self.get_dp_padding(target_blocks) - # NOTE: If model does not support multimodal inputs, we pad here. # For models with multimodal support, we may want to get embeddings # for the valid tokens before padding. 
@@ -1581,35 +1573,49 @@ def _form_prefill_batch(self, contents): logits_indices=[logits_indices], logits_requests=[logits_requests]) + def _create_dummy_prefill_batch_contents( + self, num_prefills: int) -> list[PrefillInputData]: + req_id = -1 + context_len = 0 + query_len = 128 + prompt_tokens = 128 + token_ids = list(int(i) for i in range(prompt_tokens)) + num_blocks = round_up(context_len + query_len, + self.block_size) // self.block_size + blocks = [0] * num_blocks + num_output_logits = context_len + query_len - prompt_tokens + 1 + logits_positions = list(range(query_len - num_output_logits, + query_len)) + + new_batch_contents = BatchContents( + req_ids=[req_id], + token_ids=[token_ids], + context_lens=[context_len], + blocks=[blocks], + logits_positions=[logits_positions], + ) + + outputs = [ + self._form_prefill_batch(new_batch_contents) + for _ in range(num_prefills) + ] + return outputs + def _prepare_prefill_inputs( self, num_prefills, num_decodes, - num_scheduled_tokens: list[int]) -> PrefillInputData: - all_batch_contents = self._extract_prefill_batch_contents( + num_scheduled_tokens: list[int]) -> tuple[PrefillInputData, int]: + all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents( num_prefills, num_decodes, num_scheduled_tokens) all_batches = [ self._form_prefill_batch(bc) for bc in all_batch_contents ] merge_contents(all_batches[0], *all_batches[1:]) - return all_batches[0] - - def _prepare_decode_inputs(self, - num_decodes, - num_scheduled_tokens, - scheduler_output=None) -> DecodeInputData: - # Decodes run as one single padded batch with shape [batch, 1] - # - # We need to set _PAD_SLOT_ID for the padding tokens in the - # slot_mapping, such that the attention KV cache insertion - # logic knows to ignore those indicies. Otherwise, the - # padding data can be dummy since we have a causal mask. - - block_table_cpu_tensor = self.input_batch.block_table[ - 0].get_cpu_tensor() - if num_decodes == 0: - return DecodeInputData(num_decodes=0) - # BLOCK_TABLE [batch, max_num_blocks_per_req] - context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes] + return all_batches[0], num_pad_across_dp + def _create_decode_input_data( + self, num_decodes, num_scheduled_tokens, context_lens, + block_table_cpu_tensor, num_computed_tokens_cpu, + token_ids_cpu) -> tuple[DecodeInputData, int]: # NOTE(kzawora): the +1 is what causes this entire thing to work, # as in the paged attention, we don't fetch just the context from cache, # but also kvs for the current token @@ -1622,7 +1628,8 @@ def _prepare_decode_inputs(self, num_decodes, sum(num_blocks))[0] # dp aware padding - padded_batch_size += self.get_dp_padding(padded_batch_size) + num_pad_across_dp = self.get_dp_padding(padded_batch_size) + padded_batch_size += num_pad_across_dp num_tokens_per_req = num_scheduled_tokens[:num_decodes] num_tokens = max(num_tokens_per_req) @@ -1812,7 +1819,42 @@ def _prepare_decode_inputs(self, block_size=self.block_size, query_start_loc=query_start_loc, ), - spec_decode_metadata=spec_decode_metadata) + spec_decode_metadata=spec_decode_metadata), num_pad_across_dp + + def _prepare_decode_inputs( + self, num_decodes, + num_scheduled_tokens) -> tuple[DecodeInputData, int]: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. 
+ + num_pad_across_dp = self.get_dp_padding(num_decodes) + if num_decodes == 0: + return DecodeInputData(num_decodes=0), num_pad_across_dp + # BLOCK_TABLE [batch, max_num_blocks_per_req] + context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes] + block_table_cpu_tensor = self.input_batch.block_table[ + 0].get_cpu_tensor() + return self._create_decode_input_data( + num_decodes, num_scheduled_tokens, context_lens, + block_table_cpu_tensor, self.input_batch.num_computed_tokens_cpu, + self.input_batch.token_ids_cpu) + + def _create_dummy_decode_input_data(self) -> DecodeInputData: + # create dummy decode input data with batch size 1 + context_lens = [128] + block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID], + dtype=torch.int32).reshape(1, -1) + num_computed_tokens_cpu = np.array([128], dtype=np.int32) + token_ids = np.array(list(int(i) for i in range(context_lens[0]))) + + return self._create_decode_input_data(1, [1], context_lens, + block_table_cpu_tensor, + num_computed_tokens_cpu, + token_ids)[0] def _get_cumsum_and_arange( self, @@ -1992,18 +2034,7 @@ def get_dp_padding(self, dp_size = self.vllm_config.parallel_config.data_parallel_size dp_rank = self.vllm_config.parallel_config.data_parallel_rank - # For DP: Don't pad when setting enforce_eager. - # This lets us set enforce_eager on the prefiller in a P/D setup and - # still use CUDA graphs (enabled by this padding) on the decoder. - # - # TODO(tms) : There are many cases where padding is enabled for - # prefills, causing unnecessary and excessive padding of activations. - - # skip padding for non PD disagg case to avoid padding on prefill batch - # size and decode batch size - if dp_size == 1 or self.vllm_config.model_config.enforce_eager or ( - self.vllm_config.kv_transfer_config is None - or self.vllm_config.kv_transfer_config.kv_connector is None): + if dp_size == 1: return 0 num_tokens_across_dp = DPMetadata.num_tokens_across_dp( @@ -2020,7 +2051,6 @@ def _execute_model_generic(self, warmup_mode=False, inputs_embeds=None, model_mm_kwargs=None): - # FORWARD. batch_size = token_ids.size(0) seq_len = self._seq_len(attn_metadata) @@ -2309,9 +2339,11 @@ def execute_model( num_prefills = len(pd_info.prompt_req_ids) num_reqs = num_decodes + num_prefills with self.profiler.record_event('internal', 'prepare_input_tensors'): - prefill_data, decode_data = self._prepare_inputs( + prefill_input_data, decode_input_data = self._prepare_inputs( scheduler_output, num_prefills, num_decodes) - # FIXME(kzawora): Currently there's no handling of logprobs. Fix that + prefill_data, num_pad_prefill_batch_across_dp = prefill_input_data + decode_data, num_pad_decode_batch_across_dp = decode_input_data + #FIXME(kzawora): Currently there's no handling of logprobs. Fix that # later. 
prefill_sampled_token_ids = [] prefill_sampled_requests = [] @@ -2377,6 +2409,7 @@ def execute_model( model_mm_kwargs=model_mm_kwargs, warmup_mode=warmup_mode) htorch.core.mark_step() + # Skip separate sampling for structured output if structured_output: logits_prompt.append(logits_device) @@ -2407,9 +2440,27 @@ def execute_model( prompt_batch_idx=idx, is_prompt=True) self.profiler.record_counter(self.event_start, counters) + if self.is_driver_worker and self.profiler.enabled: self.profiler_counter_helper.reset_prompt_seq_stats() + else: + if num_pad_prefill_batch_across_dp > 0: + htorch.core.mark_step() + dummy_prefill_input_data_list = self._create_dummy_prefill_batch_contents( + num_pad_prefill_batch_across_dp) + for dummy_prefill_input_data in dummy_prefill_input_data_list: + htorch.core.mark_step() + _, dummy_logits_device = \ + self._execute_model_generic( + dummy_prefill_input_data.token_ids[0], + dummy_prefill_input_data.position_ids[0], + dummy_prefill_input_data.attn_metadata[0], + dummy_prefill_input_data.logits_indices[0], + self.kv_caches, + warmup_mode=warmup_mode) + htorch.core.mark_step() + ######################### DECODES ######################### # Decodes run as one single batch with [padded_decode_bs, 1] if num_decodes > 0: @@ -2486,6 +2537,19 @@ def execute_model( prompt_batch_idx=None, is_prompt=False) self.profiler.record_counter(self.event_start, counters) + else: + if num_pad_decode_batch_across_dp > 0: + dummy_decode_input_data = self._create_dummy_decode_input_data( + ) + htorch.core.mark_step() + _, dummy_logits_device = self._execute_model_generic( + dummy_decode_input_data.token_ids, + dummy_decode_input_data.position_ids, + dummy_decode_input_data.attn_metadata, + dummy_decode_input_data.logits_indices, + self.kv_caches, + warmup_mode=warmup_mode) + htorch.core.mark_step() ################## Spec Decode ################## # work on spec decode if max_gen_len > 1 @@ -2641,6 +2705,7 @@ def execute_model( prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] pooler_output=[], ) + return model_runner_output def load_model(self) -> None: @@ -3097,8 +3162,8 @@ def __del__(self): @torch.inference_mode() def profile_run(self) -> None: + return """Profile to measure peak memory during forward pass.""" - # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. # the `dtype` argument does not matter, and we use `float32` as @@ -3112,7 +3177,6 @@ def profile_run(self) -> None: max_seq_len = math.ceil( (self.max_num_tokens // self.max_prefill_batch_size) / self.block_size) * self.block_size - max_seq_len = min(max_seq_len, self.max_model_len) self._execute_dummy_scenario( (self.max_prefill_batch_size, max_seq_len, 0), None) diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py index bd28a8fc..7e694d24 100644 --- a/vllm_gaudi/v1/worker/hpu_worker.py +++ b/vllm_gaudi/v1/worker/hpu_worker.py @@ -178,16 +178,14 @@ def determine_available_memory(self) -> int: single_kv_block_size_bytes = 0 for layer_name, layer_spec in kv_cache_spec.items(): if isinstance(layer_spec, FullAttentionSpec): - # dtype = layer_spec.dtype + dtype = layer_spec.dtype # Use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value ``None``. 
- # hpu_k_cache = torch.tensor([], dtype=dtype, device='hpu') - # hpu_v_cache = torch.tensor([], dtype=dtype, device='hpu') + hpu_k_cache = torch.tensor([], dtype=dtype, device='hpu') + hpu_v_cache = torch.tensor([], dtype=dtype, device='hpu') - # kv_caches[layer_name] = (hpu_k_cache, hpu_v_cache) - # avoid issue of reading kv cache during profiling - kv_caches[layer_name] = None + kv_caches[layer_name] = (hpu_k_cache, hpu_v_cache) single_kv_block_size_bytes += layer_spec.page_size_bytes From 7a029da557ac51cc3e3100a47b1f72b9d01a1c7f Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 25 Aug 2025 10:06:06 +0300 Subject: [PATCH 05/13] add dp padding for prefill bs/seqlen/blocks Signed-off-by: Wuxun Zhang --- .../device_communicators/hpu_communicator.py | 28 ++++-- vllm_gaudi/v1/worker/hpu_model_runner.py | 86 +++++++++++-------- 2 files changed, 73 insertions(+), 41 deletions(-) diff --git a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py index 1d482306..b5a299ea 100644 --- a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py +++ b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py @@ -81,12 +81,28 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - hidden_states_across_dp = self.naive_multicast( - hidden_states, cu_tokens_across_dp_cpu) - router_logits_across_dp = self.naive_multicast( - router_logits, cu_tokens_across_dp_cpu) + input_size = hidden_states.size() + # Allocate output tensor. + output_size = list(input_size) + output_size[0] *= self.dp_world_size + hidden_states_across_dp = torch.empty(output_size, + dtype=hidden_states.dtype, + device=hidden_states.device) + torch.distributed.all_gather_into_tensor( + hidden_states_across_dp, + hidden_states, + group=self.dp_group.device_group) + + router_logits_size = router_logits.size() + router_logits_output_size = list(router_logits_size) + router_logits_output_size[0] *= self.dp_world_size + router_logits_across_dp = torch.empty(router_logits_output_size, + dtype=router_logits.dtype, + device=router_logits.device) + torch.distributed.all_gather_into_tensor( + router_logits_across_dp, + router_logits, + group=self.dp_group.device_group) return hidden_states_across_dp, router_logits_across_dp def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 774c43cf..3ef2fa3f 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1481,6 +1481,10 @@ def _form_prefill_batch(self, contents): target_bs, target_seq, target_blocks = self._get_prompt_bucketing_fn()( query_lens, num_context_blocks) + target_bs += self.get_dp_padding(target_bs) + target_seq += self.get_dp_padding(target_seq) + target_blocks += self.get_dp_padding(target_blocks) + # NOTE: If model does not support multimodal inputs, we pad here. # For models with multimodal support, we may want to get embeddings # for the valid tokens before padding. 
@@ -1602,15 +1606,23 @@ def _create_dummy_prefill_batch_contents( return outputs def _prepare_prefill_inputs( - self, num_prefills, num_decodes, - num_scheduled_tokens: list[int]) -> tuple[PrefillInputData, int]: + self, num_prefills, num_decodes, num_scheduled_tokens: list[int] + ) -> tuple[PrefillInputData, Optional[PrefillInputData]]: all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents( num_prefills, num_decodes, num_scheduled_tokens) all_batches = [ self._form_prefill_batch(bc) for bc in all_batch_contents ] merge_contents(all_batches[0], *all_batches[1:]) - return all_batches[0], num_pad_across_dp + + dummy_prefill_input_batches = None + if num_pad_across_dp > 0: + dummy_prefill_input_batches = self._create_dummy_prefill_batch_contents( + num_pad_across_dp) + merge_contents(dummy_prefill_input_batches[0], + *dummy_prefill_input_batches[1:]) + return all_batches[0], dummy_prefill_input_batches[ + 0] if dummy_prefill_input_batches else None def _create_decode_input_data( self, num_decodes, num_scheduled_tokens, context_lens, @@ -1822,8 +1834,8 @@ def _create_decode_input_data( spec_decode_metadata=spec_decode_metadata), num_pad_across_dp def _prepare_decode_inputs( - self, num_decodes, - num_scheduled_tokens) -> tuple[DecodeInputData, int]: + self, num_decodes, num_scheduled_tokens + ) -> tuple[DecodeInputData, Optional[DecodeInputData]]: # Decodes run as one single padded batch with shape [batch, 1] # # We need to set _PAD_SLOT_ID for the padding tokens in the @@ -1833,28 +1845,31 @@ def _prepare_decode_inputs( num_pad_across_dp = self.get_dp_padding(num_decodes) if num_decodes == 0: - return DecodeInputData(num_decodes=0), num_pad_across_dp - # BLOCK_TABLE [batch, max_num_blocks_per_req] - context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes] - block_table_cpu_tensor = self.input_batch.block_table[ - 0].get_cpu_tensor() + if num_pad_across_dp > 0: + dummy_decode_input_data = self._create_dummy_decode_input_data( + ) + return DecodeInputData(num_decodes=0), dummy_decode_input_data + return DecodeInputData(num_decodes=0), None return self._create_decode_input_data( - num_decodes, num_scheduled_tokens, context_lens, - block_table_cpu_tensor, self.input_batch.num_computed_tokens_cpu, + num_decodes, num_scheduled_tokens, + self.input_batch.num_computed_tokens_cpu[:num_decodes], + self.input_batch.block_table[0].get_cpu_tensor(), + self.input_batch.num_computed_tokens_cpu, self.input_batch.token_ids_cpu) def _create_dummy_decode_input_data(self) -> DecodeInputData: # create dummy decode input data with batch size 1 + num_dummy_decodes = 1 + num_dummy_scheduled_tokens = [1] context_lens = [128] block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID], dtype=torch.int32).reshape(1, -1) num_computed_tokens_cpu = np.array([128], dtype=np.int32) token_ids = np.array(list(int(i) for i in range(context_lens[0]))) - return self._create_decode_input_data(1, [1], context_lens, - block_table_cpu_tensor, - num_computed_tokens_cpu, - token_ids)[0] + return self._create_decode_input_data( + num_dummy_decodes, num_dummy_scheduled_tokens, context_lens, + block_table_cpu_tensor, num_computed_tokens_cpu, token_ids)[0] def _get_cumsum_and_arange( self, @@ -1965,7 +1980,7 @@ def _prepare_inputs( scheduler_output: "SchedulerOutput", num_prefills, num_decodes, - ) -> tuple[PrefillInputData, Optional[DecodeInputData]]: + ): total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -2341,8 +2356,10 @@ def 
execute_model( with self.profiler.record_event('internal', 'prepare_input_tensors'): prefill_input_data, decode_input_data = self._prepare_inputs( scheduler_output, num_prefills, num_decodes) - prefill_data, num_pad_prefill_batch_across_dp = prefill_input_data - decode_data, num_pad_decode_batch_across_dp = decode_input_data + prefill_data, dummy_prefill_input_data_batches_across_dp = prefill_input_data + num_pad_prefill_batch_across_dp = 0 if dummy_prefill_input_data_batches_across_dp is None else len( + dummy_prefill_input_data_batches_across_dp.request_ids) + decode_data, dummy_decode_input_data_across_dp = decode_input_data #FIXME(kzawora): Currently there's no handling of logprobs. Fix that # later. prefill_sampled_token_ids = [] @@ -2409,7 +2426,6 @@ def execute_model( model_mm_kwargs=model_mm_kwargs, warmup_mode=warmup_mode) htorch.core.mark_step() - # Skip separate sampling for structured output if structured_output: logits_prompt.append(logits_device) @@ -2446,17 +2462,19 @@ def execute_model( else: if num_pad_prefill_batch_across_dp > 0: - htorch.core.mark_step() - dummy_prefill_input_data_list = self._create_dummy_prefill_batch_contents( - num_pad_prefill_batch_across_dp) - for dummy_prefill_input_data in dummy_prefill_input_data_list: + for idx, ( + req_id, prompt_len, token_ids, position_ids, + attn_metadata, logits_indices, + logits_requests) in enumerate( + zip(*shallow_tuple( + dummy_prefill_input_data_batches_across_dp))): htorch.core.mark_step() _, dummy_logits_device = \ self._execute_model_generic( - dummy_prefill_input_data.token_ids[0], - dummy_prefill_input_data.position_ids[0], - dummy_prefill_input_data.attn_metadata[0], - dummy_prefill_input_data.logits_indices[0], + token_ids, + position_ids, + attn_metadata, + logits_indices, self.kv_caches, warmup_mode=warmup_mode) htorch.core.mark_step() @@ -2538,15 +2556,13 @@ def execute_model( is_prompt=False) self.profiler.record_counter(self.event_start, counters) else: - if num_pad_decode_batch_across_dp > 0: - dummy_decode_input_data = self._create_dummy_decode_input_data( - ) + if dummy_decode_input_data_across_dp is not None: htorch.core.mark_step() _, dummy_logits_device = self._execute_model_generic( - dummy_decode_input_data.token_ids, - dummy_decode_input_data.position_ids, - dummy_decode_input_data.attn_metadata, - dummy_decode_input_data.logits_indices, + dummy_decode_input_data_across_dp.token_ids, + dummy_decode_input_data_across_dp.position_ids, + dummy_decode_input_data_across_dp.attn_metadata, + dummy_decode_input_data_across_dp.logits_indices, self.kv_caches, warmup_mode=warmup_mode) htorch.core.mark_step() From 5d44aee51034f431de9da831513f3054c92562b1 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 25 Aug 2025 11:33:07 +0300 Subject: [PATCH 06/13] add dp into ci test Signed-off-by: Wuxun Zhang --- tests/full_tests/ci_tests.sh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/full_tests/ci_tests.sh b/tests/full_tests/ci_tests.sh index 8f46eabd..18a63842 100644 --- a/tests/full_tests/ci_tests.sh +++ b/tests/full_tests/ci_tests.sh @@ -48,3 +48,13 @@ if [ $? 
-ne 0 ]; then exit -1 fi echo "Test with structured outputs passed" + +# DP2 +echo "Testing data parallel size 2 with vllm-hpu plugin v1" +echo HABANA_VISIBLE_DEVICES=all VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 python -u vllm-gaudi/examples/data_parallel.py --dp-size 2 --tp-size 2 +HABANA_VISIBLE_DEVICES=all VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 python -u vllm-gaudi/examples/data_parallel.py --dp-size 2 --tp-size 2 +if [ $? -ne 0 ]; then + echo "Error: Test failed for data parallel size 2" >&2 + exit -1 +fi +echo "Test with data parallel size 2 passed" From 37c4485666ae9f53d7bc948c56a419b28f534e46 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Tue, 26 Aug 2025 06:16:24 +0300 Subject: [PATCH 07/13] use reduce_scatter instead of all_reduce Signed-off-by: Wuxun Zhang --- .../device_communicators/hpu_communicator.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py index b5a299ea..f7d6b5c0 100644 --- a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py +++ b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py @@ -81,6 +81,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.dim() == 2, "Input hidden states must be 2D" input_size = hidden_states.size() # Allocate output tensor. output_size = list(input_size) @@ -108,13 +109,22 @@ def dispatch( def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: if htorch.utils.internal.is_lazy(): htorch.core.mark_step() + assert hidden_states.dim() == 2, "Input hidden states must be 2D" cu_tokens_across_dp_cpu = get_forward_context( ).dp_metadata.cu_tokens_across_dp_cpu - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] + # assume num tokens is padded across DP ranks + assert cu_tokens_across_dp_cpu[ + 0] * self.dp_world_size == cu_tokens_across_dp_cpu[-1] - all_hidden_states = self.dp_group.all_reduce(hidden_states) - hidden_states = all_hidden_states[start:end, :] + local_hidden_states = torch.empty( + (cu_tokens_across_dp_cpu[0], hidden_states.size(-1)), + device=hidden_states.device, + dtype=hidden_states.dtype) + + torch.distributed.reduce_scatter_tensor( + local_hidden_states, + hidden_states, + group=self.dp_group.device_group) + hidden_states = local_hidden_states return hidden_states From 97a84b03f5b90794869b22b40f56b230cfdbab45 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 27 Aug 2025 11:39:48 +0300 Subject: [PATCH 08/13] fix dummy prefill batch for eager Signed-off-by: Wuxun Zhang --- vllm_gaudi/v1/worker/hpu_model_runner.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 3ef2fa3f..e9c6a864 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -133,6 +133,14 @@ class BatchContents: def get_num_tokens(self): return [len(t) for t in self.token_ids] + def clone(self): + return BatchContents( + req_ids=self.req_ids.copy(), + token_ids=[t.copy() for t in self.token_ids], + context_lens=self.context_lens.copy(), + blocks=[b.copy() for b in self.blocks], + logits_positions=[lp.copy() for lp in 
self.logits_positions]) + # TODO(kzawora): remove this @dataclass @@ -1600,7 +1608,7 @@ def _create_dummy_prefill_batch_contents( ) outputs = [ - self._form_prefill_batch(new_batch_contents) + self._form_prefill_batch(new_batch_contents.clone()) for _ in range(num_prefills) ] return outputs From 213f54b6d71c1a693c696a97dea91d5735f57584 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 3 Sep 2025 11:22:24 +0300 Subject: [PATCH 09/13] fix rebase error Signed-off-by: Wuxun Zhang --- .../device_communicators/hpu_communicator.py | 22 ++------ vllm_gaudi/v1/worker/hpu_model_runner.py | 55 +++++++++++-------- 2 files changed, 35 insertions(+), 42 deletions(-) diff --git a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py index f7d6b5c0..e40602f3 100644 --- a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py +++ b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py @@ -8,7 +8,7 @@ from vllm.distributed.device_communicators.base_device_communicator \ import DeviceCommunicatorBase from vllm.forward_context import get_forward_context -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group import habana_frameworks.torch as htorch # noqa: F401 @@ -22,7 +22,7 @@ def __init__(self, unique_name: str = ""): super().__init__(cpu_group, device, device_group, unique_name) - self.dp_group = None + self.dp_group: Optional[GroupCoordinator] = None self.dp_rank = 0 self.dp_world_size = 1 # assume EP is enabled along with DP @@ -31,22 +31,6 @@ def __init__(self, self.dp_rank = self.dp_group.rank_in_group self.dp_world_size = self.dp_group.world_size - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor: - assert x.dim() == 2, "Input tensor must be 2D" - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - buffer[start:end, :].copy_(x) - for idx in range(self.dp_world_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - self.dp_group.broadcast(buffer[start:end, :], idx) - return buffer - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used @@ -81,6 +65,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.dp_group is not None assert hidden_states.dim() == 2, "Input hidden states must be 2D" input_size = hidden_states.size() # Allocate output tensor. 
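For orientation, the dispatch/combine pattern this communicator converges on (all-gather of the equally padded token batches before the MoE experts, reduce-scatter of the expert outputs afterwards) can be simulated on CPU as below. This is an illustrative aside and not part of the patch: the real code uses torch.distributed.all_gather_into_tensor and reduce_scatter_tensor over HCCL, while here the DP ranks are emulated with plain lists and torch.cat/sum.

import torch

def dispatch_sim(per_rank_hidden: list[torch.Tensor]) -> list[torch.Tensor]:
    # Every rank receives the concatenation of all ranks' equally padded
    # token batches, i.e. a tensor of shape [world_size * T, H].
    gathered = torch.cat(per_rank_hidden, dim=0)
    return [gathered.clone() for _ in per_rank_hidden]

def combine_sim(per_rank_expert_out: list[torch.Tensor],
                tokens_per_rank: int) -> list[torch.Tensor]:
    # reduce_scatter == elementwise sum across ranks, after which each rank
    # keeps only its own contiguous slice of the summed tensor.
    summed = torch.stack(per_rank_expert_out, dim=0).sum(dim=0)
    return [summed[r * tokens_per_rank:(r + 1) * tokens_per_rank]
            for r in range(len(per_rank_expert_out))]

world_size, tokens, hidden = 2, 4, 8
local = [torch.randn(tokens, hidden) for _ in range(world_size)]
dispatched = dispatch_sim(local)            # each rank now holds [8, 8]
combined = combine_sim(dispatched, tokens)  # back to [4, 8] per rank
# With an identity "expert", combine(dispatch(x)) equals world_size * x.
assert torch.allclose(combined[0], world_size * local[0])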
@@ -109,6 +94,7 @@ def dispatch( def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: if htorch.utils.internal.is_lazy(): htorch.core.mark_step() + assert self.dp_group is not None assert hidden_states.dim() == 2, "Input hidden states must be 2D" cu_tokens_across_dp_cpu = get_forward_context( ).dp_metadata.cu_tokens_across_dp_cpu diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index e9c6a864..3fe0ca45 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1587,7 +1587,7 @@ def _form_prefill_batch(self, contents): def _create_dummy_prefill_batch_contents( self, num_prefills: int) -> list[PrefillInputData]: - req_id = -1 + req_id = str(-1) context_len = 0 query_len = 128 prompt_tokens = 128 @@ -1616,8 +1616,9 @@ def _create_dummy_prefill_batch_contents( def _prepare_prefill_inputs( self, num_prefills, num_decodes, num_scheduled_tokens: list[int] ) -> tuple[PrefillInputData, Optional[PrefillInputData]]: - all_batch_contents, num_pad_across_dp = self._extract_prefill_batch_contents( - num_prefills, num_decodes, num_scheduled_tokens) + all_batch_contents, num_pad_across_dp = \ + self._extract_prefill_batch_contents( + num_prefills, num_decodes, num_scheduled_tokens) all_batches = [ self._form_prefill_batch(bc) for bc in all_batch_contents ] @@ -1625,17 +1626,20 @@ def _prepare_prefill_inputs( dummy_prefill_input_batches = None if num_pad_across_dp > 0: - dummy_prefill_input_batches = self._create_dummy_prefill_batch_contents( - num_pad_across_dp) + dummy_prefill_input_batches = \ + self._create_dummy_prefill_batch_contents(num_pad_across_dp) merge_contents(dummy_prefill_input_batches[0], *dummy_prefill_input_batches[1:]) return all_batches[0], dummy_prefill_input_batches[ 0] if dummy_prefill_input_batches else None def _create_decode_input_data( - self, num_decodes, num_scheduled_tokens, context_lens, - block_table_cpu_tensor, num_computed_tokens_cpu, - token_ids_cpu) -> tuple[DecodeInputData, int]: + self, + num_decodes, + num_scheduled_tokens, + context_lens, + block_table_cpu_tensor, + scheduler_output=None) -> tuple[DecodeInputData, int]: # NOTE(kzawora): the +1 is what causes this entire thing to work, # as in the paged attention, we don't fetch just the context from cache, # but also kvs for the current token @@ -1842,7 +1846,10 @@ def _create_decode_input_data( spec_decode_metadata=spec_decode_metadata), num_pad_across_dp def _prepare_decode_inputs( - self, num_decodes, num_scheduled_tokens + self, + num_decodes, + num_scheduled_tokens, + scheduler_output=None ) -> tuple[DecodeInputData, Optional[DecodeInputData]]: # Decodes run as one single padded batch with shape [batch, 1] # @@ -1861,9 +1868,7 @@ def _prepare_decode_inputs( return self._create_decode_input_data( num_decodes, num_scheduled_tokens, self.input_batch.num_computed_tokens_cpu[:num_decodes], - self.input_batch.block_table[0].get_cpu_tensor(), - self.input_batch.num_computed_tokens_cpu, - self.input_batch.token_ids_cpu) + self.input_batch.block_table[0].get_cpu_tensor(), scheduler_output) def _create_dummy_decode_input_data(self) -> DecodeInputData: # create dummy decode input data with batch size 1 @@ -1872,12 +1877,13 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData: context_lens = [128] block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID], dtype=torch.int32).reshape(1, -1) - num_computed_tokens_cpu = np.array([128], dtype=np.int32) - token_ids = np.array(list(int(i) for i in 
range(context_lens[0]))) + # num_computed_tokens_cpu = np.array([128], dtype=np.int32) + # token_ids = np.array(list(int(i) for i in range(context_lens[0]))) - return self._create_decode_input_data( - num_dummy_decodes, num_dummy_scheduled_tokens, context_lens, - block_table_cpu_tensor, num_computed_tokens_cpu, token_ids)[0] + return self._create_decode_input_data(num_dummy_decodes, + num_dummy_scheduled_tokens, + context_lens, + block_table_cpu_tensor)[0] def _get_cumsum_and_arange( self, @@ -2052,8 +2058,7 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata, if not seen and not warmup_mode: logger.warning("Configuration: %s was not warmed-up!", cfg) - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + def get_dp_padding(self, num_tokens: int) -> int: dp_size = self.vllm_config.parallel_config.data_parallel_size dp_rank = self.vllm_config.parallel_config.data_parallel_rank @@ -2364,9 +2369,11 @@ def execute_model( with self.profiler.record_event('internal', 'prepare_input_tensors'): prefill_input_data, decode_input_data = self._prepare_inputs( scheduler_output, num_prefills, num_decodes) - prefill_data, dummy_prefill_input_data_batches_across_dp = prefill_input_data - num_pad_prefill_batch_across_dp = 0 if dummy_prefill_input_data_batches_across_dp is None else len( - dummy_prefill_input_data_batches_across_dp.request_ids) + prefill_data, \ + dummy_prefill_input_data_batches_across_dp = prefill_input_data + num_pad_prefill_batch_across_dp = \ + 0 if dummy_prefill_input_data_batches_across_dp is None \ + else len(dummy_prefill_input_data_batches_across_dp.request_ids) decode_data, dummy_decode_input_data_across_dp = decode_input_data #FIXME(kzawora): Currently there's no handling of logprobs. Fix that # later. 
@@ -2477,7 +2484,7 @@ def execute_model( zip(*shallow_tuple( dummy_prefill_input_data_batches_across_dp))): htorch.core.mark_step() - _, dummy_logits_device = \ + _, _, dummy_logits_device = \ self._execute_model_generic( token_ids, position_ids, @@ -2566,7 +2573,7 @@ def execute_model( else: if dummy_decode_input_data_across_dp is not None: htorch.core.mark_step() - _, dummy_logits_device = self._execute_model_generic( + _, _, dummy_logits_device = self._execute_model_generic( dummy_decode_input_data_across_dp.token_ids, dummy_decode_input_data_across_dp.position_ids, dummy_decode_input_data_across_dp.attn_metadata, From ea44413217650f8426936b83a2048df77eef5240 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 3 Sep 2025 16:47:06 +0300 Subject: [PATCH 10/13] fix ci error Signed-off-by: Wuxun Zhang --- vllm_gaudi/v1/worker/hpu_model_runner.py | 48 +++++++++++------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 3fe0ca45..1ba5ee23 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -1633,13 +1633,12 @@ def _prepare_prefill_inputs( return all_batches[0], dummy_prefill_input_batches[ 0] if dummy_prefill_input_batches else None - def _create_decode_input_data( - self, - num_decodes, - num_scheduled_tokens, - context_lens, - block_table_cpu_tensor, - scheduler_output=None) -> tuple[DecodeInputData, int]: + def _create_decode_input_data(self, + num_decodes, + num_scheduled_tokens, + context_lens, + block_table_cpu_tensor, + scheduler_output=None) -> DecodeInputData: # NOTE(kzawora): the +1 is what causes this entire thing to work, # as in the paged attention, we don't fetch just the context from cache, # but also kvs for the current token @@ -1652,8 +1651,7 @@ def _create_decode_input_data( num_decodes, sum(num_blocks))[0] # dp aware padding - num_pad_across_dp = self.get_dp_padding(padded_batch_size) - padded_batch_size += num_pad_across_dp + padded_batch_size += self.get_dp_padding(padded_batch_size) num_tokens_per_req = num_scheduled_tokens[:num_decodes] num_tokens = max(num_tokens_per_req) @@ -1843,7 +1841,7 @@ def _create_decode_input_data( block_size=self.block_size, query_start_loc=query_start_loc, ), - spec_decode_metadata=spec_decode_metadata), num_pad_across_dp + spec_decode_metadata=spec_decode_metadata) def _prepare_decode_inputs( self, @@ -1868,7 +1866,8 @@ def _prepare_decode_inputs( return self._create_decode_input_data( num_decodes, num_scheduled_tokens, self.input_batch.num_computed_tokens_cpu[:num_decodes], - self.input_batch.block_table[0].get_cpu_tensor(), scheduler_output) + self.input_batch.block_table[0].get_cpu_tensor(), + scheduler_output), None def _create_dummy_decode_input_data(self) -> DecodeInputData: # create dummy decode input data with batch size 1 @@ -1877,13 +1876,10 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData: context_lens = [128] block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID], dtype=torch.int32).reshape(1, -1) - # num_computed_tokens_cpu = np.array([128], dtype=np.int32) - # token_ids = np.array(list(int(i) for i in range(context_lens[0]))) - return self._create_decode_input_data(num_dummy_decodes, num_dummy_scheduled_tokens, context_lens, - block_table_cpu_tensor)[0] + block_table_cpu_tensor) def _get_cumsum_and_arange( self, @@ -2570,17 +2566,6 @@ def execute_model( prompt_batch_idx=None, is_prompt=False) self.profiler.record_counter(self.event_start, counters) 
- else: - if dummy_decode_input_data_across_dp is not None: - htorch.core.mark_step() - _, _, dummy_logits_device = self._execute_model_generic( - dummy_decode_input_data_across_dp.token_ids, - dummy_decode_input_data_across_dp.position_ids, - dummy_decode_input_data_across_dp.attn_metadata, - dummy_decode_input_data_across_dp.logits_indices, - self.kv_caches, - warmup_mode=warmup_mode) - htorch.core.mark_step() ################## Spec Decode ################## # work on spec decode if max_gen_len > 1 @@ -2617,6 +2602,17 @@ def execute_model( spec_decode_metadata, spec_decode_common_attn_metadata, decode_data)[:num_decodes] ################## Spec Decode end ################## + else: + if dummy_decode_input_data_across_dp is not None: + htorch.core.mark_step() + _, _, dummy_logits_device = self._execute_model_generic( + dummy_decode_input_data_across_dp.token_ids, + dummy_decode_input_data_across_dp.position_ids, + dummy_decode_input_data_across_dp.attn_metadata, + dummy_decode_input_data_across_dp.logits_indices, + self.kv_caches, + warmup_mode=warmup_mode) + htorch.core.mark_step() if structured_output: # Scheduler places cached before prompt From 3e1576484eee5d13c42512eca5b49412afacd4eb Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 8 Sep 2025 08:31:30 +0300 Subject: [PATCH 11/13] fix precommit issue Signed-off-by: Wuxun Zhang --- examples/data_parallel.py | 2 +- pyproject.toml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/data_parallel.py b/examples/data_parallel.py index ea4f659a..097fb0e5 100644 --- a/examples/data_parallel.py +++ b/examples/data_parallel.py @@ -97,7 +97,7 @@ def parse_args(): def generate_random_token_ids(repeat=1) -> list[int]: """ - For testing different seuquence length in data parallel scenario + For testing different sequence length in data parallel scenario """ candidate_lens = [130, 560] prompts = [] diff --git a/pyproject.toml b/pyproject.toml index c8e34d99..5afe3ad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ include = ["vllm_gaudi"] [tool.yapfignore] ignore_patterns = [ - "build/**", + "build/**", + "examples/**", "vllm_gaudi/extension/**" # NOTE(kzawora): re-enable this once extension refactor is ready ] From daf05f25a1d3e9112ae5a2b447ec8686523d91dc Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 8 Sep 2025 15:36:32 +0300 Subject: [PATCH 12/13] address comments Signed-off-by: Wuxun Zhang --- vllm_gaudi/v1/worker/hpu_model_runner.py | 55 +++++++++++------------- vllm_gaudi/v1/worker/hpu_worker.py | 23 ---------- 2 files changed, 26 insertions(+), 52 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index d0b77a5a..ac808770 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2479,24 +2479,22 @@ def execute_model( if self.is_driver_worker and self.profiler.enabled: self.profiler_counter_helper.reset_prompt_seq_stats() - else: - if num_pad_prefill_batch_across_dp > 0: - for idx, ( - req_id, prompt_len, token_ids, position_ids, - attn_metadata, logits_indices, - logits_requests) in enumerate( - zip(*shallow_tuple( - dummy_prefill_input_data_batches_across_dp))): - htorch.core.mark_step() - _, _, dummy_logits_device = \ - self._execute_model_generic( - token_ids, - position_ids, - attn_metadata, - logits_indices, - self.kv_caches, - warmup_mode=warmup_mode) - htorch.core.mark_step() + elif num_pad_prefill_batch_across_dp > 0: + for idx, (req_id, prompt_len, token_ids, 
position_ids, + attn_metadata, logits_indices, + logits_requests) in enumerate( + zip(*shallow_tuple( + dummy_prefill_input_data_batches_across_dp))): + htorch.core.mark_step() + _, _, dummy_logits_device = \ + self._execute_model_generic( + token_ids, + position_ids, + attn_metadata, + logits_indices, + self.kv_caches, + warmup_mode=warmup_mode) + htorch.core.mark_step() ######################### DECODES ######################### # Decodes run as one single batch with [padded_decode_bs, 1] @@ -2610,17 +2608,16 @@ def execute_model( spec_decode_metadata, spec_decode_common_attn_metadata, decode_data)[:num_decodes] ################## Spec Decode end ################## - else: - if dummy_decode_input_data_across_dp is not None: - htorch.core.mark_step() - _, _, dummy_logits_device = self._execute_model_generic( - dummy_decode_input_data_across_dp.token_ids, - dummy_decode_input_data_across_dp.position_ids, - dummy_decode_input_data_across_dp.attn_metadata, - dummy_decode_input_data_across_dp.logits_indices, - self.kv_caches, - warmup_mode=warmup_mode) - htorch.core.mark_step() + elif dummy_decode_input_data_across_dp is not None: + htorch.core.mark_step() + _, _, dummy_logits_device = self._execute_model_generic( + dummy_decode_input_data_across_dp.token_ids, + dummy_decode_input_data_across_dp.position_ids, + dummy_decode_input_data_across_dp.attn_metadata, + dummy_decode_input_data_across_dp.logits_indices, + self.kv_caches, + warmup_mode=warmup_mode) + htorch.core.mark_step() if structured_output: # Scheduler places cached before prompt diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py index 7e694d24..a79cc09d 100644 --- a/vllm_gaudi/v1/worker/hpu_worker.py +++ b/vllm_gaudi/v1/worker/hpu_worker.py @@ -322,29 +322,6 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - if torch.distributed.is_initialized(): - torch_world_size = torch.distributed.get_world_size() - expected_size = parallel_config.world_size *\ - parallel_config.data_parallel_size - if torch_world_size != expected_size: - raise RuntimeError( - "torch.distributed is already initialized but the torch world " - "size does not match parallel_config.world_size * " - "parallel_config.data_parallel_size " - f"({torch_world_size} vs. 
{expected_size}).") - elif not distributed_init_method: - raise ValueError( - "distributed_init_method must be set if torch.distributed " - "is not already initialized") - else: - backend = 'hccl' - torch.distributed.init_process_group( - backend=backend, - world_size=parallel_config.world_size, - rank=rank, - init_method=distributed_init_method, - ) - dummy_tensor_hpu = torch.ones(1).to('hpu') torch.distributed.all_reduce(dummy_tensor_hpu) assert dummy_tensor_hpu.item( From a7ca26424a79dd61523b86b4b1980dd6d3d7557d Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Tue, 9 Sep 2025 18:05:34 +0300 Subject: [PATCH 13/13] fix missing args Signed-off-by: Wuxun Zhang --- vllm_gaudi/v1/worker/hpu_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 8abae085..2e21397a 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2949,6 +2949,8 @@ def execute_model( attn_metadata, logits_indices, self.kv_caches, + None, + None, warmup_mode=warmup_mode) htorch.core.mark_step() @@ -3068,6 +3070,8 @@ def execute_model( dummy_decode_input_data_across_dp.attn_metadata, dummy_decode_input_data_across_dp.logits_indices, self.kv_caches, + None, + None, warmup_mode=warmup_mode) htorch.core.mark_step()