Commit 7fca8f2

Add AR calculation to benchmark_serving

Signed-off-by: Zero Zeng <[email protected]>

1 parent: 48ddc3d

File tree: 2 files changed (+70 / -9 lines)

tensorrt_llm/serve/scripts/backend_request_func.py

Lines changed: 22 additions & 0 deletions
@@ -44,6 +44,7 @@ class RequestFuncOutput:
     tpot: float = 0.0  # avg next-token latencies
     prompt_len: int = 0
     error: str = ""
+    decode_iteration: int = 0  # Number of decoding iterations
 
 
 async def async_request_trt_llm(
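
For orientation, here is a minimal sketch of RequestFuncOutput limited to the fields this diff touches or reads; only decode_iteration is new, the other fields and their defaults are inferred from the surrounding context lines, and the real dataclass contains more than is shown here.

    from dataclasses import dataclass, field

    @dataclass
    class RequestFuncOutput:
        # Sketch only: the fields referenced in this diff, not the full class.
        generated_text: str = ""
        success: bool = False
        latency: float = 0.0
        ttft: float = 0.0                               # time to first token
        itl: list[float] = field(default_factory=list)  # inter-token latencies
        tpot: float = 0.0                               # avg next-token latencies
        prompt_len: int = 0
        output_tokens: int = 0
        error: str = ""
        decode_iteration: int = 0                       # decoding iterations observed for this request
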
@@ -77,6 +78,7 @@ async def async_request_trt_llm(
     ttft = 0.0
     st = time.perf_counter()
     most_recent_timestamp = st
+    decode_iteration_count = 0  # Track decoding iterations
     try:
         async with request_session.post(url=api_url, json=payload) as response:
             if response.status == 200:
@@ -102,16 +104,22 @@ async def async_request_trt_llm(
                         else:
                             output.itl.append(timestamp - most_recent_timestamp)
 
+                        # Increment decode iteration for each chunk
+                        decode_iteration_count += 1
                         most_recent_timestamp = timestamp
 
                     output.latency = most_recent_timestamp - st
+                    output.decode_iteration = decode_iteration_count
                 else:
                     content = await response.content.read()
                     data = json.loads(content.decode())
                     output.ttft = -1
                     output.itl = []
                     output.generated_text = data["text_output"]
                     output.latency = time.perf_counter() - st
+                    # For non-streaming, estimate decode_iteration as number of output tokens
+                    output.decode_iteration = len(output.generated_text.split(
+                    )) if output.generated_text else 1
 
             else:
                 output.error = response.reason or ""
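
In the streaming branch above, every received chunk counts as one decoding iteration. The non-streaming branch has no per-chunk signal, so it falls back to a whitespace split of the generated text, which only approximates the real token count. A small illustration of the gap (the tokenizer name is an arbitrary example, not something this script loads):

    from transformers import AutoTokenizer

    text = "TensorRT-LLM now reports a per-request AR."
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example tokenizer

    whitespace_count = len(text.split())       # the fallback used above
    token_count = len(tokenizer.encode(text))  # what a real tokenizer reports

    print(whitespace_count, token_count)       # the two counts usually differ
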
@@ -170,6 +178,7 @@ async def async_request_openai_completions(
     generated_text = ""
     st = time.perf_counter()
     most_recent_timestamp = st
+    decode_iteration_count = 0  # Track decoding iterations
     try:
         async with request_session.post(url=api_url,
                                         json=payload,
@@ -206,6 +215,9 @@ async def async_request_openai_completions(
                                     output.itl.append(timestamp -
                                                       most_recent_timestamp)
 
+                                # Increment decode iteration for each chunk with text
+                                if text is not None:
+                                    decode_iteration_count += 1
                                 most_recent_timestamp = timestamp
                                 generated_text += text or ""
                             elif usage := data.get("usage"):
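
The counting pattern added here is easier to see in isolation. A self-contained sketch with made-up chunk contents (the real code pulls the text deltas out of SSE chunks delivered by aiohttp); when the server packs several tokens into one chunk, the count tracks decoder iterations rather than tokens, which is what the AR metric divides by:

    import asyncio

    async def fake_completion_chunks():
        # Stand-in for the streamed "text" deltas parsed from each chunk.
        for piece in ["Hel", "lo", ",", " wor", "ld"]:
            yield piece

    async def consume():
        generated_text = ""
        decode_iteration_count = 0
        async for text in fake_completion_chunks():
            if text is not None:          # same guard as the added lines above
                decode_iteration_count += 1
            generated_text += text or ""
        return generated_text, decode_iteration_count

    print(asyncio.run(consume()))         # ('Hello, world', 5)
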
@@ -220,6 +232,7 @@ async def async_request_openai_completions(
                             "This response will be marked as failed!")
                     output.generated_text = generated_text
                     output.latency = most_recent_timestamp - st
+                    output.decode_iteration = decode_iteration_count
                 else:
                     content = await response.content.read()
                     data = json.loads(content.decode())
@@ -230,6 +243,8 @@ async def async_request_openai_completions(
                     output.ttft = -1
                     output.itl = []
                     output.output_tokens = data["usage"]["completion_tokens"]
+                    # For non-streaming, estimate decode_iteration as number of output tokens
+                    output.decode_iteration = output.output_tokens if output.output_tokens > 0 else 1
             else:
                 output.error = response.reason or ""
                 output.success = False
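
For non-streaming OpenAI-style responses, the fallback reuses the server-reported completion_tokens, with a floor of 1 when the server reports zero. A minimal sketch with a made-up response body that contains only the field read here:

    import json

    content = b'{"usage": {"completion_tokens": 5}}'  # made-up response body
    data = json.loads(content.decode())

    output_tokens = data["usage"]["completion_tokens"]
    decode_iteration = output_tokens if output_tokens > 0 else 1
    print(output_tokens, decode_iteration)  # 5 5
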
@@ -306,6 +321,7 @@ async def async_request_openai_chat_completions(
     ttft = 0.0
     st = time.perf_counter()
     most_recent_timestamp = st
+    decode_iteration_count = 0  # Track decoding iterations
     try:
         async with request_session.post(url=api_url,
                                         json=payload,
@@ -336,6 +352,9 @@ async def async_request_openai_chat_completions(
                                     output.itl.append(timestamp -
                                                       most_recent_timestamp)
 
+                                # Increment decode iteration for each chunk with content
+                                if content is not None:
+                                    decode_iteration_count += 1
                                 generated_text += content or ""
                             elif usage := data.get("usage"):
                                 output.output_tokens = usage.get(
@@ -345,6 +364,7 @@ async def async_request_openai_chat_completions(
 
                     output.generated_text = generated_text
                     output.latency = most_recent_timestamp - st
+                    output.decode_iteration = decode_iteration_count
                 else:
                     content = await response.content.read()
                     data = json.loads(content.decode())
@@ -354,6 +374,8 @@ async def async_request_openai_chat_completions(
                     output.itl = []
                     output.latency = time.perf_counter() - st
                     output.ttft = -1
+                    # For non-streaming, estimate decode_iteration as number of output tokens
+                    output.decode_iteration = output.output_tokens if output.output_tokens > 0 else 1
 
             else:
                 output.error = response.reason or ""

tensorrt_llm/serve/scripts/benchmark_serving.py

Lines changed: 48 additions & 9 deletions
@@ -79,6 +79,11 @@ class BenchmarkMetrics:
     std_e2el_ms: float
     percentiles_e2el_ms: list[tuple[float, float]]
     tput_user: list[float]
+    # Request accuracy rate metrics
+    mean_request_ar: float
+    median_request_ar: float
+    std_request_ar: float
+    percentiles_request_ar: list[tuple[float, float]]
 
 
 async def get_request(
@@ -131,7 +136,7 @@ def calculate_metrics(
     selected_percentile_metrics: list[str],
     selected_percentiles: list[float],
     goodput_config_dict: dict[str, float],
-) -> tuple[BenchmarkMetrics, list[int]]:
+) -> tuple[BenchmarkMetrics, list[int], list[float]]:
     actual_output_lens: list[int] = []
     total_input = 0
     completed = 0
@@ -142,6 +147,7 @@ def calculate_metrics(
     ttfts: list[float] = []
     e2els: list[float] = []
     tput_user: list[float] = []
+    request_ars: list[float] = []  # Request accuracy rates
     for i in range(len(outputs)):
         if outputs[i].success:
             output_len = outputs[i].output_tokens
@@ -167,9 +173,24 @@ def calculate_metrics(
             ttfts.append(outputs[i].ttft)
             e2els.append(outputs[i].latency)
             tput_user.append(output_len / (outputs[i].latency))
+
+            # Calculate request accuracy rate (num_generated_tokens / (decode_iteration + 1))
+            decode_iter = outputs[i].decode_iteration
+            if decode_iter >= 0:
+                # For generated tokens, we use output_len - 1 (excluding the first token if needed)
+                # But according to the reference, it should be num_generated_tokens
+                num_generated_tokens = max(0, output_len -
+                                           1) if output_len > 1 else output_len
+                request_ar = num_generated_tokens / (
+                    decode_iter + 1) if decode_iter >= 0 else 0.0
+                request_ars.append(request_ar)
+            else:
+                request_ars.append(0.0)
+
             completed += 1
         else:
             actual_output_lens.append(0)
+            request_ars.append(0.0)
 
     if goodput_config_dict:
         valid_metrics = []
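
Per request, the AR is the number of generated tokens divided by (decode_iteration + 1), with the first token excluded whenever more than one token was produced. A worked example with made-up numbers; the helper merely restates the arithmetic above and is not part of the benchmark:

    def request_ar(output_len: int, decode_iter: int) -> float:
        # Restates the per-request calculation added above.
        if decode_iter < 0:
            return 0.0
        num_generated_tokens = max(0, output_len - 1) if output_len > 1 else output_len
        return num_generated_tokens / (decode_iter + 1)

    print(request_ar(128, 127))  # ~0.99: roughly one new token per decoding iteration
    print(request_ar(128, 31))   # ~3.97: several tokens per iteration, as with speculative decoding
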
@@ -228,8 +249,13 @@ def calculate_metrics(
         percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
                              for p in selected_percentiles],
         tput_user=np.mean(tput_user or 0),
+        mean_request_ar=np.mean(request_ars or 0),
+        median_request_ar=np.median(request_ars or 0),
+        std_request_ar=np.std(request_ars or 0),
+        percentiles_request_ar=[(p, np.percentile(request_ars or 0, p))
+                                for p in selected_percentiles],
     )
-    return metrics, actual_output_lens
+    return metrics, actual_output_lens, request_ars
 
 
 async def benchmark(
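
The new summary fields aggregate the per-request list with NumPy exactly the way the latency metrics are aggregated, and percentiles_request_ar holds one (percentile, value) pair per selected percentile. A short illustration with made-up values:

    import numpy as np

    request_ars = [0.98, 1.02, 2.71, 3.24, 0.0]  # made-up per-request ARs
    selected_percentiles = [50.0, 90.0, 99.0]

    print(np.mean(request_ars), np.median(request_ars), np.std(request_ars))
    print([(p, np.percentile(request_ars, p)) for p in selected_percentiles])
    # one (percentile, value) pair per selected percentile
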
@@ -403,7 +429,7 @@ async def limited_request_func(request_func_input, streaming, pbar,
     # Close the session
     await session.close()
 
-    metrics, actual_output_lens = calculate_metrics(
+    metrics, actual_output_lens, request_ars = calculate_metrics(
         input_requests=input_requests,
         outputs=outputs,
         dur_s=benchmark_duration,
@@ -431,6 +457,10 @@ async def limited_request_func(request_func_input, streaming, pbar,
                                     metrics.total_token_throughput))
     print("{:<40} {:<10.2f}".format("User throughput (tok/s):",
                                     metrics.tput_user))
+    print("{:<40} {:<10.4f}".format("Mean Request AR:",
+                                    metrics.mean_request_ar))
+    print("{:<40} {:<10.4f}".format("Median Request AR:",
+                                    metrics.median_request_ar))
 
     result = {
         "duration": benchmark_duration,
@@ -443,12 +473,16 @@ async def limited_request_func(request_func_input, streaming, pbar,
         "output_throughput": metrics.output_throughput,
         "total_token_throughput": metrics.total_token_throughput,
         "user_throughput": metrics.tput_user,
+        "mean_request_ar": metrics.mean_request_ar,
+        "median_request_ar": metrics.median_request_ar,
         "input_lens": [output.prompt_len for output in outputs],
         "output_lens": actual_output_lens,
         "ttfts": [output.ttft for output in outputs],
         "itls": [output.itl for output in outputs],
         "generated_texts": [output.generated_text for output in outputs],
         "errors": [output.error for output in outputs],
+        "request_ars": request_ars,
+        "decode_iterations": [output.decode_iteration for output in outputs],
     }
 
     def process_one_metric(
@@ -534,11 +568,15 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
     metrics = [
         "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
         "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
-        "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
+        "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms",
+        "mean_request_ar", "median_request_ar", "std_request_ar"
     ]
     # These raw data might be useful, but they are rather big. They can be added
     # later if needed
-    ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
+    ignored_metrics = [
+        "ttfts", "itls", "generated_texts", "errors", "request_ars",
+        "decode_iterations"
+    ]
     pt_records = convert_to_pytorch_benchmark_format(
         args=args,
         metrics={k: [results[k]]
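
The PyTorch benchmark export keeps the new scalar summaries and drops the bulky per-request lists, mirroring how the existing raw fields are handled. A rough sketch of that selection with made-up values (convert_to_pytorch_benchmark_format itself is outside this diff, so the wrapping into single-element lists is shown only schematically):

    results = {
        "mean_request_ar": 2.87, "median_request_ar": 2.90, "std_request_ar": 0.41,
        "request_ars": [2.81, 3.02, 2.78],     # per-request, potentially huge
        "decode_iterations": [34, 29, 41],
    }
    metrics = ["mean_request_ar", "median_request_ar", "std_request_ar"]
    ignored_metrics = ["request_ars", "decode_iterations"]

    exported = {k: [results[k]] for k in metrics if k in results}
    skipped = [k for k in ignored_metrics if k in results]
    print(exported)  # scalar summaries, wrapped in single-element lists
    print(skipped)   # raw per-request data left out of the export
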
@@ -762,7 +800,8 @@ def main(args: argparse.Namespace):
         # Remove fields with too many data points
         for field in [
                 "input_lens", "output_lens", "ttfts", "itls",
-                "generated_texts", "errors"
+                "generated_texts", "errors", "request_ars",
+                "decode_iterations"
         ]:
             if field in result_json:
                 del result_json[field]
@@ -963,11 +1002,11 @@ def main(args: argparse.Namespace):
     parser.add_argument(
         "--percentile-metrics",
         type=str,
-        default="ttft,tpot,itl",
+        default="ttft,tpot,itl,request_ar",
        help="Comma-separated list of selected metrics to report percentils. "
        "This argument specifies the metrics to report percentiles. "
-        "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
-        "Default value is \"ttft,tpot,itl\".")
+        "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\", \"request_ar\". "
+        "Default value is \"ttft,tpot,itl,request_ar\".")
     parser.add_argument(
         "--metric-percentiles",
         type=str,
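
With request_ar in the default for --percentile-metrics, AR percentiles are reported alongside TTFT/TPOT/ITL without extra flags. A sketch of how the two comma-separated arguments are typically turned into the lists calculate_metrics expects; the split logic is an assumption, since that part of the script is not shown in this diff:

    percentile_metrics = "ttft,tpot,itl,request_ar"  # new default value
    metric_percentiles = "99"                        # example value

    selected_percentile_metrics = percentile_metrics.split(",")
    selected_percentiles = [float(p) for p in metric_percentiles.split(",")]
    print(selected_percentile_metrics)  # ['ttft', 'tpot', 'itl', 'request_ar']
    print(selected_percentiles)         # [99.0]
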
