From b903cd80f2fc101b891fb42cd006e6d2accb3b22 Mon Sep 17 00:00:00 2001
From: Agata Dobrzyniewicz
Date: Sun, 24 Aug 2025 00:26:16 +0300
Subject: [PATCH] Initial commit

Signed-off-by: Agata Dobrzyniewicz
---
Review notes (this section is ignored by git-am):
- The logger/get_config imports added to file.py are assumed to match the
  sibling bucketing modules - confirm the module paths.
- The bucketing-file flags are now read with get_config(), like the other
  VLLM_* flags visible in common.py - confirm get_context() was not intended.
- Blob ids on the "index" lines predate the review fixes.
- Indentation of unchanged context lines was reconstructed - confirm it
  against the tree before applying.

 vllm_gaudi/extension/bucketing/common.py | 13 +++--
 vllm_gaudi/extension/bucketing/file.py   | 83 ++++++++++++++++++++++++
 vllm_gaudi/extension/features.py         |  2 +
 3 files changed, 95 insertions(+), 3 deletions(-)
 create mode 100644 vllm_gaudi/extension/bucketing/file.py

diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py
index db6ae1bf..f5d518f7 100644
--- a/vllm_gaudi/extension/bucketing/common.py
+++ b/vllm_gaudi/extension/bucketing/common.py
@@ -58,9 +58,14 @@ def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size,
         self.fallback_seq_base_step = 32
         self.fallback_blocks_base_step = 32
 
-    def get_bucketing_strategy(self):
+    def get_bucketing_strategy(self, file_path=None):
         strategy = None
         # TODO - we can use different strategies for decode and prompt
+        if file_path:
+            from vllm_gaudi.extension.bucketing.file import (
+                FileBucketingStrategy)
+            strategy = FileBucketingStrategy()
+            return strategy
         use_exponential_bucketing = True if \
                 get_config().VLLM_EXPONENTIAL_BUCKETING == None else \
                 get_config().VLLM_EXPONENTIAL_BUCKETING
@@ -76,7 +81,8 @@ def get_bucketing_strategy(self):
 
     def generate_prompt_buckets(self):
         if self.initialized:
-            strategy = self.get_bucketing_strategy()
+            prompt_buckets_file = get_config().VLLM_PROMPT_BUCKETING_FILE
+            strategy = self.get_bucketing_strategy(prompt_buckets_file)
 
             self.prompt_buckets = strategy.get_prompt_buckets(
                         max_num_prefill_seqs = self.max_num_prefill_seqs,
@@ -91,7 +97,8 @@ def generate_prompt_buckets(self):
 
     def generate_decode_buckets(self):
         if self.initialized:
-            strategy = self.get_bucketing_strategy()
+            decode_buckets_file = get_config().VLLM_DECODE_BUCKETING_FILE
+            strategy = self.get_bucketing_strategy(decode_buckets_file)
 
             self.decode_buckets = strategy.get_decode_buckets(
                         max_num_seqs = self.max_num_seqs,
diff --git a/vllm_gaudi/extension/bucketing/file.py b/vllm_gaudi/extension/bucketing/file.py
new file mode 100644
index 00000000..0cf74786
--- /dev/null
+++ b/vllm_gaudi/extension/bucketing/file.py
@@ -0,0 +1,83 @@
+import math
+import os.path
+
+from vllm_gaudi.extension.logger import logger as logger
+from vllm_gaudi.extension.runtime import get_config
+
+
+class FileBucketingStrategy():
+    '''
+    FileBucketingStrategy allows to read buckets from a plain-text file.
+    Files can be passed through flags:
+        - VLLM_PROMPT_BUCKETING_FILE
+        - VLLM_DECODE_BUCKETING_FILE
+    Valid files list one bucket per line, as integers separated by commas
+    and/or whitespace, in this order:
+        batch_size, query_length, number_of_context_blocks
+    A line with only two values is read as (batch_size, query_length, 0).
+    Lines that do not start with a digit are skipped.
+    '''
+
+    def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
+                           max_num_batched_tokens, max_model_len):
+        all_buckets = read_buckets_file(True)
+
+        # Verify buckets - remove not valid
+        # NOTE(review): the last check compares a block count with a token
+        # count - confirm this is the intended limit.
+        prompt_buckets = []
+        for bucket in all_buckets:
+            bs, query, ctx = bucket
+            if (query + ctx * block_size > max_num_batched_tokens
+                    or bs > max_num_prefill_seqs
+                    or bs * math.ceil(max_model_len / block_size)
+                    > max_model_len):
+                #TODO include conti pa
+                continue
+            prompt_buckets.append(bucket)
+
+        return sorted(prompt_buckets)
+
+    def get_decode_buckets(self, max_num_seqs, block_size,
+                           max_num_batched_tokens, max_model_len,
+                           num_max_blocks):
+        # TODO - validate decode buckets against the limits passed in
+        return sorted(read_buckets_file(False))
+
+
+def read_buckets_file(is_prompt):
+    '''
+    Read (batch_size, query_length, context_blocks) buckets for the given
+    phase from the file configured for it.
+    '''
+    file_name = get_config().VLLM_PROMPT_BUCKETING_FILE if is_prompt \
+        else get_config().VLLM_DECODE_BUCKETING_FILE
+    phase = 'prompt' if is_prompt else 'decode'
+
+    # Explicit raise instead of assert: asserts are removed under python -O
+    if not file_name or not os.path.isfile(file_name):
+        raise FileNotFoundError(
+            f"File for {phase} buckets config doesn't exist: {file_name}")
+
+    all_buckets = []
+    with open(file_name, "r") as f:
+        for line in f:
+            bucket = line.strip()
+            if not bucket or not bucket[0].isdigit():
+                continue
+
+            try:
+                new_bucket = tuple(
+                    int(value) for value in bucket.replace(",", " ").split())
+            except ValueError:
+                continue
+
+            if len(new_bucket) == 3:
+                all_buckets.append(new_bucket)
+            elif len(new_bucket) == 2:
+                all_buckets.append(new_bucket + (0, ))
+            # skip other invalid configs
+
+    if len(all_buckets) < 1:
+        logger().info(f"No buckets found in {file_name} file for {phase}")
+    return all_buckets
diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py
index 8081582f..34f9a243 100644
--- a/vllm_gaudi/extension/features.py
+++ b/vllm_gaudi/extension/features.py
@@ -14,6 +14,8 @@ def get_user_flags():
         Env('VLLM_USE_V1', boolean),
         Env('VLLM_ENABLE_EXPERIMENTAL_FLAGS', boolean),
         Env('VLLM_EXPONENTIAL_BUCKETING', boolean),
+        Env('VLLM_PROMPT_BUCKETING_FILE', str),
+        Env('VLLM_DECODE_BUCKETING_FILE', str),
         Env('VLLM_PROMPT_BS_BUCKET_MIN', int),
         Env('VLLM_PROMPT_BS_BUCKET_STEP', int),
         Env('VLLM_PROMPT_BS_BUCKET_MAX', int),