From c538d0fa2f9fc7620925706b1a7d63d480cef3e3 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Sat, 30 Aug 2025 02:15:24 +0300 Subject: [PATCH] update bucket Signed-off-by: Chendi.Xue --- vllm_gaudi/extension/bucketing/common.py | 8 ++++++-- vllm_gaudi/extension/bucketing/linear.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index 786d0d18..2d7d154c 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -10,7 +10,7 @@ from vllm_gaudi.extension.runtime import get_config -def calc_fallback_value(n: int, base_step: int): +def calc_fallback_value(n: int, base_step: int, warmup_max: int = None) -> int: """ Calculate next bucket for yet unbucketized value""" if n <= 1: return n @@ -30,6 +30,9 @@ def calc_fallback_value(n: int, base_step: int): # => bucket_size = ceil(4001^1/3) * 32 = 16 * 32 = 512 # => next_value = round_up(4001, 512) = 4096 bucket_size = math.ceil(math.pow(n, power)) * base_step + num_blocks = math.ceil(n / bucket_size) * bucket_size + if warmup_max is not None and num_blocks > warmup_max and warmup_max >= n: + bucket_size = warmup_max return math.ceil(n / bucket_size) * bucket_size @@ -121,7 +124,8 @@ def generate_fallback_bucket(self, batch_size, seq_len, ctx): if self.num_hpu_blocks is None: new_ctx = 0 else: - new_ctx = min(calc_fallback_value(ctx, self.fallback_blocks_base_step), + decode_block_max = self.decode_buckets[-1][2] if len(self.decode_buckets) > 0 else self.decode_block_max + new_ctx = min(calc_fallback_value(ctx, self.fallback_blocks_base_step, decode_block_max), self.num_hpu_blocks) return (new_batch_size, new_seq_len, new_ctx) diff --git a/vllm_gaudi/extension/bucketing/linear.py b/vllm_gaudi/extension/bucketing/linear.py index abd9a745..f0eaef99 100644 --- a/vllm_gaudi/extension/bucketing/linear.py +++ b/vllm_gaudi/extension/bucketing/linear.py @@ -197,6 +197,7 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, last_bucket = max_blocks for bs in bs_buckets: max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size) + print(f"{bs=} {max_model_len=} {block_size=} {max_blocks_including_max_model_len=}, {block_buckets=}") for blocks in block_buckets: if bs > blocks: # Skip a dummy case when bs > blocks, which cannot occur in real execution @@ -204,6 +205,7 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, if not use_contiguous_pa and blocks > max_blocks_including_max_model_len: # Skip case when user wants to have bigger blocks than max model len # case cn only occur with contiguous PA + buckets.append((bs, 1, max_blocks_including_max_model_len)) continue if blocks >= last_bucket: buckets.append((bs, 1, last_bucket))