Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 38 additions & 13 deletions tensorrt_llm/scaffolding/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import time
from typing import List, Mapping, Tuple, Type
from typing import List, Mapping, Optional, Tuple, Type

from pydantic import BaseModel

from tensorrt_llm.scaffolding.load_generation_strategy import \
LoadGenerationStrategy
from tensorrt_llm.scaffolding.scaffolding_llm import (ScaffoldingLlm,
ScaffoldingResult)
from tensorrt_llm.scaffolding.task_collection import (TaskCollection,
Expand All @@ -21,26 +23,37 @@ async def enqueue_requests(input_queue, requests):
await input_queue.put(None)


async def process_request(scaffolding_llm, request, output_queue, semaphore):
async def process_request(scaffolding_llm: ScaffoldingLlm, request,
output_queue, target_time, semaphore):
async with semaphore:
wait_time = target_time - time.time()
if wait_time > 0:
await asyncio.sleep(wait_time)

request_start_time = time.time()
result = scaffolding_llm.generate_async(request.prompt)
await result.aresult()
request_execution_time = time.time() - request_start_time
await output_queue.put((result, request_execution_time))
await output_queue.put(
(result, request_start_time, request_execution_time))


async def run_scaffolding_llm(scaffolding_llm, input_queue, output_queue,
concurrency):
semaphore = asyncio.Semaphore(concurrency)
strategy: LoadGenerationStrategy):
semaphore = strategy.get_semaphore()
time_generator = strategy.request_times()

tasks = set()

while True:
request = await input_queue.get()
if request is None:
break
target_time = await time_generator.__anext__()

task = asyncio.create_task(
process_request(scaffolding_llm, request, output_queue, semaphore))
process_request(scaffolding_llm, request, output_queue, target_time,
semaphore))
tasks.add(task)
task.add_done_callback(tasks.discard)

Expand All @@ -63,10 +76,19 @@ def wrapper_prototype_controller_with_task_collection(scaffolding_llm,


async def async_scaffolding_benchmark(
scaffolding_llm: ScaffoldingLlm,
task_collection_types: Mapping[str, Type[TaskCollection]],
requests: List[ScaffoldingBenchRequest],
concurrency: int) -> Tuple[List[ScaffoldingResult], List[float], float]:
scaffolding_llm: ScaffoldingLlm,
task_collection_types: Mapping[str, Type[TaskCollection]],
requests: List[ScaffoldingBenchRequest],
concurrency: Optional[int] = None,
strategy: Optional[LoadGenerationStrategy] = None
) -> Tuple[List[ScaffoldingResult], List[float], List[float], float]:
if strategy is None:
if concurrency is None:
raise ValueError("Must provide either 'strategy' or 'concurrency'")
from tensorrt_llm.scaffolding.load_generation_strategy import \
ConcurrentStrategy
strategy = ConcurrentStrategy(concurrency=concurrency)

wrapper_prototype_controller_with_task_collection(scaffolding_llm,
task_collection_types)

Expand All @@ -76,20 +98,23 @@ async def async_scaffolding_benchmark(
start_time = time.time()
results = []
requests_execution_time = []
requests_start_time = []

enqueue_task = asyncio.create_task(enqueue_requests(input_queue, requests))

run_scaffolding_llm_task = asyncio.create_task(
run_scaffolding_llm(scaffolding_llm, input_queue, output_queue,
concurrency))
strategy))

while True:
try:
item = await asyncio.wait_for(output_queue.get(), timeout=1.0)
if item is None:
break
result, request_execution_time = item
result, request_start_time, request_execution_time = item
results.append(result)
requests_execution_time.append(request_execution_time)
requests_start_time.append(request_start_time)
except asyncio.TimeoutError:
continue

Expand All @@ -98,4 +123,4 @@ async def async_scaffolding_benchmark(
enqueue_task.result()
run_scaffolding_llm_task.result()

return results, requests_execution_time, total_time
return results, requests_start_time, requests_execution_time, total_time
Loading