
Commit 3e725c3

IzzyPutterman authored and jQizhang committed
Specdec Bench: vLLM reqid, SGL path, conc > 1 metric fix (NVIDIA#541)
## What does this PR do?

**SGLang**: Fix to actually pass the draft model path to the engine.

**vLLM**: Fix for multiturn runs so that request_id strings no longer overlap.

**Acceptance Rate**: Fix for a potential race condition when writing back the acceptance rate on multiturn datasets.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

Signed-off-by: Izzy Putterman <[email protected]>
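For context, the acceptance-rate fix below replaces append-order bookkeeping with explicit per-request, per-turn keys. A minimal sketch of why order-based writes break once concurrency exceeds 1 (the names here are illustrative, not the benchmark's API):

```python
# Sketch: with concurrency > 1, steps from different requests interleave,
# so appending to "the latest turn" can write into the wrong request's slot.
# Keying by (request_id, turn_id) gives every write its own bucket.
prompt_ar: dict[int, dict[int, list[int]]] = {}

def record_step(request_id: int, turn_id: int, accepted_lengths: list[int]) -> None:
    # setdefault is safe under any interleaving of requests and turns
    prompt_ar.setdefault(request_id, {}).setdefault(turn_id, []).extend(accepted_lengths)

record_step(1, 0, [3, 2])  # request 1 happens to report first
record_step(0, 0, [4])     # request 0 still lands in its own slot
print(prompt_ar)           # {1: {0: [3, 2]}, 0: {0: [4]}}
```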
1 parent 263007a · commit 3e725c3

File tree

12 files changed: +48 −39 lines


examples/specdec_bench/run.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -49,12 +49,14 @@ async def process_single_request(request, i):
         if request.system_prompt is not None:
             messages.append({"role": "system", "content": request.system_prompt})
 
-        for question in request.turns:
+        for turn_id, question in enumerate(request.turns):
             messages.append({"role": "user", "content": question})
             entry_encoded = encode_chat(tokenizer, messages)
 
             # Run the async runner.run directly
-            output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
+            output_tokens = await runner.run(
+                entry_encoded, max_length, end_id, request_id=i, turn_id=turn_id
+            )
             output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
             output_text = postprocess(output_text)
             messages.append({"role": "assistant", "content": output_text})
```
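The loop now derives a turn index with enumerate and threads it into the runner alongside the request index. A standalone illustration of the calling pattern (the run coroutine here is a stand-in, not the benchmark's Model class):

```python
import asyncio

async def run(prompt_ids, max_length, end_id, request_id, turn_id):
    # Stand-in for Model.run: just reports which (request, turn) it served.
    return {"request_id": request_id, "turn_id": turn_id, "n_prompt": len(prompt_ids)}

async def process_single_request(turns, request_id):
    outputs = []
    for turn_id, question in enumerate(turns):
        outputs.append(await run(question, 64, 0, request_id=request_id, turn_id=turn_id))
    return outputs

print(asyncio.run(process_single_request([[1, 2, 3], [4, 5]], request_id=0)))
```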

examples/specdec_bench/specdec_bench/metrics/aa_timing.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -34,7 +34,7 @@ def __init__(self, base_tokenizer):
         self.base_tokenizer = base_tokenizer
         self.total_tokens = []
 
-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         self.timing.append(step_outputs["token_times"])
         target_tokens = [
             t for tok_list in step_outputs["output_ids"] for tok in tok_list for t in tok
```

examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py

Lines changed: 18 additions & 13 deletions
```diff
@@ -22,15 +22,17 @@
 class AcceptanceRate(Metric):
     def __init__(self):
         super().__init__()
-        self.prompt_ar = []
+        self.prompt_ar = {}
         self.name = "acceptance_rate"
 
-    def process_step(self, step_outputs, new_turn=True):
-        if new_turn:
-            self.prompt_ar.append([])
+    def process_step(self, step_outputs, request_id, turn_id):
+        if request_id not in self.prompt_ar:
+            self.prompt_ar[request_id] = {}
+        if turn_id not in self.prompt_ar[request_id]:
+            self.prompt_ar[request_id][turn_id] = []
         for i, beam_output in enumerate(step_outputs["output_ids"]):
             for output_id_iter in beam_output:
-                self.prompt_ar[-1].append(len(output_id_iter))
+                self.prompt_ar[request_id][turn_id].append(len(output_id_iter))
 
     def _get_lengths(self, turn, lengths):
         for j in turn:
@@ -55,16 +57,19 @@ def _process_lengths(self, lengths):
             running_len -= v
 
     def process_final(self, text_outputs):
-        i = 0
+        all_ar = []
         lengths = {}
         self.out["Request_AR"] = {}
-        while i < len(self.prompt_ar):
-            turn_1 = self.prompt_ar[i]
-            self.out["Request_AR"][i] = sum(turn_1) / len(turn_1)
-            self._get_lengths(turn_1, lengths)
-            print(i, self.out["Request_AR"][i])
-            i += 1
-        average_ar = sum(self.out["Request_AR"].values()) / len(self.out["Request_AR"])
+        self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0]))
+        for request_id, turns in self.prompt_ar.items():
+            self.out["Request_AR"][request_id] = {}
+            for turn_id, turn in turns.items():
+                ar = sum(turn) / len(turn)
+                self.out["Request_AR"][request_id][turn_id] = ar
+                all_ar.append(ar)
+                self._get_lengths(turn, lengths)
+                print(request_id, turn_id, self.out["Request_AR"][request_id][turn_id])
+        average_ar = sum(all_ar) / len(all_ar)
         print("Average AR:", average_ar)
         self.out["Average_AR"] = average_ar
         self._process_lengths(lengths)
```
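With the nested layout, Request_AR is keyed by request and then turn, and Average_AR becomes the mean over every turn's AR rather than over per-request entries. A quick self-contained check of that arithmetic (toy accepted-token counts, not benchmark output):

```python
prompt_ar = {
    0: {0: [3, 1, 2], 1: [4, 4]},  # request 0: accepted tokens per step, two turns
    1: {0: [2, 2]},                # request 1: one turn
}

all_ar, request_ar = [], {}
for request_id, turns in sorted(prompt_ar.items()):
    request_ar[request_id] = {}
    for turn_id, turn in turns.items():
        ar = sum(turn) / len(turn)  # mean accepted tokens per decode step
        request_ar[request_id][turn_id] = ar
        all_ar.append(ar)

print(request_ar)                 # {0: {0: 2.0, 1: 4.0}, 1: {0: 2.0}}
print(sum(all_ar) / len(all_ar))  # 2.666...
```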

examples/specdec_bench/specdec_bench/metrics/base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -24,7 +24,7 @@ def __init__(self):
         self.out = {}
         self.name = "metric"
 
-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         raise NotImplementedError
 
     def process_final(self, text_outputs):
```

examples/specdec_bench/specdec_bench/metrics/mtbench.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -35,16 +35,16 @@ def process_final(self, text_outputs):
         i = 0
         lengths = {}
         self.out["Request_AR"] = {}
-        while i < len(self.prompt_ar):
-            turn_1 = self.prompt_ar[i]
-            turn_2 = self.prompt_ar[i + 1]
-            q_id = i // 2
+        self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0]))
+        for request_id, turns in self.prompt_ar.items():
+            turn_1 = turns[0]
+            turn_2 = turns[1]
+            q_id = request_id
             mtbench_topic = MTBENCH_TOPICS[q_id // 10]
-            self.out["Request_AR"][q_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2)
+            self.out["Request_AR"][request_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2)
             self._get_lengths(turn_1, lengths)
             self._get_lengths(turn_2, lengths)
             print(mtbench_topic, sum(turn_1 + turn_2) / len(turn_1 + turn_2))
-            i += 2
         per_category = [[] for _ in range(len(MTBENCH_TOPICS))]
         for q_id, ar in self.out["Request_AR"].items():
             per_category[q_id // 10].append(ar)
```
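The q_id // 10 bucketing relies on MT-Bench grouping ten questions per topic, so sorted request ids map directly onto categories. A toy version of the per-category averaging (topic list truncated for illustration):

```python
MTBENCH_TOPICS = ["writing", "roleplay", "reasoning"]  # truncated example list

request_ar = {0: 2.0, 5: 3.0, 12: 4.0, 25: 2.5}  # toy per-request ARs

per_category = [[] for _ in MTBENCH_TOPICS]
for q_id, ar in request_ar.items():
    per_category[q_id // 10].append(ar)  # ten questions per topic

for topic, ars in zip(MTBENCH_TOPICS, per_category):
    if ars:
        print(topic, sum(ars) / len(ars))
```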

examples/specdec_bench/specdec_bench/metrics/timing.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -26,7 +26,7 @@ def __init__(self, tp_size):
         self.total_tokens = []
         self.tp_size = tp_size
 
-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         self.timing.append(step_outputs["token_times"])
         self.total_tokens.append(
             sum([sum([len(j) for j in i]) for i in step_outputs["output_ids"]])
@@ -42,8 +42,9 @@ def process_final(self, text_outputs):
         self.out["Output TPS"] = sum(self.total_tokens) / (end_time - start_time)
         self.out["Output TPS/gpu"] = self.out["Output TPS"] / self.tp_size
         for tokens, times in zip(self.total_tokens, self.timing):
-            e2e_time.append(times[-1] - times[0])
-            ttft_time.append(times[1] - times[0])
+            if len(times) > 1:
+                e2e_time.append(times[-1] - times[0])
+                ttft_time.append(times[1] - times[0])
             if len(times) > 2:
                 gen_tp_time.append((tokens - 1) / (times[-1] - times[1]))
                 tpot_time.extend([a - b for a, b in zip(times[1:], times[:-1])])
```
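The new guard matters when a request records only one timestamp, in which case times[1] raised an IndexError before this fix. A toy reproduction of the guarded metric math (illustrative timestamps only):

```python
def summarize(times):
    # times: wall-clock stamps recorded for one request
    out = {}
    if len(times) > 1:
        out["e2e"] = times[-1] - times[0]   # end-to-end latency
        out["ttft"] = times[1] - times[0]   # time to first token
    return out

print(summarize([0.0, 0.1, 0.3]))  # {'e2e': 0.3, 'ttft': 0.1}
print(summarize([0.0]))            # {} -- previously an IndexError on times[1]
```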

examples/specdec_bench/specdec_bench/models/base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -18,7 +18,7 @@ class Model:
     def __init__(self, model_dir, tokenizer, max_draft_length):
         raise NotImplementedError
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         """
         prompt_ids is list of tokens
         output is list of list of tokens
```

examples/specdec_bench/specdec_bench/models/sglang.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -50,6 +50,7 @@ def __init__(
             speculative_num_steps=kwargs.get("speculative_num_steps", 3),
             speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1),
             speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4),
+            speculative_draft_model_path=kwargs.get("draft_model_dir"),
             torch_compile_max_bs=max_concurrent_requests,
             attention_backend=kwargs.get("attention_backend"),
             enable_torch_compile=kwargs.get("enable_torch_compile", False),
@@ -70,7 +71,7 @@ def __init__(
 
         self.sampling_config = sampling_kwargs
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         timing = []
         output_dict = {}
         self.sampling_config["max_new_tokens"] = max_length
```
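The added line forwards the benchmark's draft_model_dir kwarg to SGLang's speculative_draft_model_path argument; before this, the draft path never reached the engine. A sketch of the kwargs plumbing with the engine constructor stubbed out (make_engine_args is a hypothetical helper, for illustration only):

```python
def make_engine_args(**kwargs):
    # dict.get returns None when draft_model_dir is absent, so runs without
    # a separate draft model are unaffected by the new argument.
    return {
        "speculative_num_draft_tokens": kwargs.get("speculative_num_draft_tokens", 4),
        "speculative_draft_model_path": kwargs.get("draft_model_dir"),
    }

print(make_engine_args(draft_model_dir="/models/eagle-draft"))
print(make_engine_args())  # draft path stays None
```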

examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -43,7 +43,7 @@ def __init__(
         self.model = create_executor(model_path, max_concurrent_requests, kwargs)
         self.sampling_kwargs = sampling_kwargs
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         output_dict = {}
         sampling_config = check_sampling_config(self.sampling_kwargs, max_length, end_id)
         outputs = []
```

examples/specdec_bench/specdec_bench/models/vllm.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -84,12 +84,12 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         output_dict = {}
         self.sampling_config.max_tokens = max_length
         self.sampling_config.stop_token_ids = [end_id]
 
-        outputs, timing, full_tokens = await self.generate(prompt_ids, request_id)
+        outputs, timing, full_tokens = await self.generate(prompt_ids, request_id, turn_id)
 
         reformatted_output_ids = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))]
         start = 0
@@ -114,13 +114,13 @@ async def run(self, prompt_ids, max_length, end_id, request_id):
         ]
         return output_dict
 
-    async def generate(self, prompt_ids, request_id):
+    async def generate(self, prompt_ids, request_id, turn_id):
         timing = []
         timing.append(time.perf_counter())
         outputs = []
         full_tokens = []
         async for output in self.model.generate(
-            request_id=str(request_id),
+            request_id=f"{request_id}.{turn_id}",
             prompt=TokensPrompt(prompt_token_ids=prompt_ids),
             sampling_params=self.sampling_config,
         ):
```
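vLLM treats request_id as a unique handle per generate call, so reusing str(request_id) across the turns of one conversation can collide. Composing the id from request and turn keeps it unique per generation. A minimal check of that uniqueness property (pure Python, no vLLM import):

```python
requests, turns = [0, 1], [0, 1, 2]

old_ids = [str(r) for r in requests for _ in turns]      # one id reused per turn
new_ids = [f"{r}.{t}" for r in requests for t in turns]  # id per (request, turn)

assert len(set(old_ids)) < len(old_ids)    # duplicates across turns
assert len(set(new_ids)) == len(new_ids)   # all unique
print(new_ids)  # ['0.0', '0.1', '0.2', '1.0', '1.1', '1.2']
```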
