Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions vllm_gaudi/extension/bucketing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,14 @@ def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size,
self.fallback_seq_base_step = 32
self.fallback_blocks_base_step = 32

def get_bucketing_strategy(self):
def get_bucketing_strategy(self, file_path=None):
strategy = None
# TODO - we can use different strategies for decode and prompt
if file_path:
from vllm_gaudi.extension.bucketing.linear import (
FileBucketingStrategy)
strategy = FileBucketingStrategy()
return strategy
use_exponential_bucketing = True if \
get_config().VLLM_EXPONENTIAL_BUCKETING == None else \
get_config().VLLM_EXPONENTIAL_BUCKETING
Expand All @@ -76,7 +81,8 @@ def get_bucketing_strategy(self):

def generate_prompt_buckets(self):
if self.initialized:
strategy = self.get_bucketing_strategy()
prompt_buckets_file = get_context().VLLM_PROMPT_BUCKETING_FILE
strategy = self.get_bucketing_strategy(prompt_buckets_file)

self.prompt_buckets = strategy.get_prompt_buckets(
max_num_prefill_seqs = self.max_num_prefill_seqs,
Expand All @@ -91,7 +97,8 @@ def generate_prompt_buckets(self):

def generate_decode_buckets(self):
if self.initialized:
strategy = self.get_bucketing_strategy()
decode_buckets_file = get_context().VLLM_DECODE_BUCKETING_FILE
strategy = self.get_bucketing_strategy(decode_buckets_file)

self.decode_buckets = strategy.get_decode_buckets(
max_num_seqs = self.max_num_seqs,
Expand Down
67 changes: 67 additions & 0 deletions vllm_gaudi/extension/bucketing/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import math
import os.path


class FileBucketingStrategy():
    '''
    Bucketing strategy that reads its buckets from a plain-text file
    instead of generating them.

    Files are passed through flags:
        - VLLM_PROMPT_BUCKETING_FILE
        - VLLM_DECODE_BUCKETING_FILE
    Valid files list one bucket per line, in this order:
        (batch_size, query_length, number_of_context_blocks)
    The file itself is parsed by read_buckets_file().
    '''

    def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
                           max_num_batched_tokens, max_model_len):
        '''
        Return the sorted list of prompt buckets from the prompt bucketing
        file, with buckets that exceed the given limits removed.

        A bucket (bs, query, ctx) is dropped when any of these holds:
            - query + ctx * block_size > max_num_batched_tokens
            - bs > max_num_prefill_seqs
            - bs * ceil(max_model_len / block_size) > max_model_len
        '''
        all_buckets = read_buckets_file(True)

        # Loop-invariant: number of blocks needed for one full-length
        # sequence.
        blocks_per_seq = math.ceil(max_model_len / block_size)

        # Verify buckets - remove the ones that are not valid
        prompt_buckets = []
        for bucket in all_buckets:
            bs, query, ctx = bucket
            # NOTE(review): the last clause compares a block count against
            # max_model_len (a token count) - kept as authored, confirm
            # that this is the intended limit.
            if query + ctx * block_size > max_num_batched_tokens \
                    or bs > max_num_prefill_seqs \
                    or bs * blocks_per_seq > max_model_len:
                # TODO include conti pa
                continue
            prompt_buckets.append(bucket)

        return sorted(prompt_buckets)

    def get_decode_buckets(self, max_num_seqs, block_size,
                           max_num_batched_tokens, max_model_len,
                           num_max_blocks):
        '''
        Return the sorted list of decode buckets from the decode bucketing
        file.

        The limit arguments are accepted so the signature matches what the
        caller passes, but no validation is applied to decode buckets yet.
        '''
        # TODO: validate decode buckets against the limits, as is done for
        # prompt buckets.
        return sorted(read_buckets_file(False))


def read_buckets_file(is_prompt):
    '''
    Read bucket definitions from the prompt or decode bucketing file.

    The file path is taken from VLLM_PROMPT_BUCKETING_FILE when is_prompt
    is true, otherwise from VLLM_DECODE_BUCKETING_FILE.

    Each line holds one bucket as integers separated by commas and/or
    whitespace. Lines that are empty, do not start with a digit, or contain
    a non-integer value are skipped. A line with three values is stored
    as-is; a line with two values gets 0 appended as the third value; any
    other count is skipped.

    Returns a list of 3-tuples of ints, in file order.
    Raises AssertionError when the configured path is not an existing file.
    '''
    # NOTE(review): get_context and logger are not imported in the visible
    # part of this module, and other flags are read through get_config() in
    # bucketing/common.py - confirm the accessor and add the imports.
    file_name = get_context().VLLM_PROMPT_BUCKETING_FILE if is_prompt \
        else get_context().VLLM_DECODE_BUCKETING_FILE
    phase = 'prompt' if is_prompt else 'decode'

    assert os.path.isfile(file_name), \
        f"File for {phase} buckets config doesn't exist"

    all_buckets = []
    with open(file_name, "r") as f:
        for line in f:
            bucket = line.strip()
            # Only lines starting with a digit are bucket definitions;
            # everything else (blank lines, comments, headers) is ignored.
            if not bucket or not bucket[0].isdigit():
                continue

            try:
                new_bucket = [
                    int(value)
                    for value in bucket.replace(",", " ").split()
                ]
            except ValueError:
                continue

            if len(new_bucket) == 3:
                all_buckets.append(tuple(new_bucket))
            elif len(new_bucket) == 2:
                # Two values given: the third value defaults to 0.
                all_buckets.append((new_bucket[0], new_bucket[1], 0))
            # skip other invalid configs

    if not all_buckets:
        logger().info(f"No buckets found in {file_name} file for {phase}")
    return all_buckets
2 changes: 2 additions & 0 deletions vllm_gaudi/extension/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def get_user_flags():
Env('VLLM_USE_V1', boolean),
Env('VLLM_ENABLE_EXPERIMENTAL_FLAGS', boolean),
Env('VLLM_EXPONENTIAL_BUCKETING', boolean),
Env('VLLM_PROMPT_BUCKETING_FILE', str),
Env('VLLM_DECODE_BUCKETING_FILE', str),
Env('VLLM_PROMPT_BS_BUCKET_MIN', int),
Env('VLLM_PROMPT_BS_BUCKET_STEP', int),
Env('VLLM_PROMPT_BS_BUCKET_MAX', int),
Expand Down
Loading