
Commit 3570ad7

kris1025lancelly authored and committed
[TRTLLM-6685][feat] Add speculative metrics for trt llm bench (NVIDIA#6476)
Signed-off-by: linquanh <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 33905cd commit 3570ad7

File tree

2 files changed: +123 -31 lines changed


tensorrt_llm/bench/dataclasses/reporting.py

Lines changed: 117 additions & 29 deletions
@@ -72,7 +72,7 @@ def register_request_perf_item(self, request_perf_item: PerfItemTuple):
         if request_perf_item.response_is_final:
             self.num_complete = self.num_complete + 1
 
-    def generate_statistics_summary(self) -> None:
+    def generate_statistics_summary(self, max_draft_tokens: int) -> None:
         """Generate summary statistics from internally stored statistics.
 
         Returns:
@@ -90,42 +90,62 @@ def generate_statistics_summary(self) -> None:
         intertoken_avg_latencies = []
         output_tokens = []
-        request_acceptance = []
         total_decoding_iterations = 0
         ttft_times = []
         last_queue_time = 0.0
         queue_time_total = 0.0
 
+        num_draft_tokens = []
+        num_accepted_draft_tokens = []
+        draft_acceptance_rate = []
+        acceptance_length = []
+
         for entry in self.requests.values():
             start_time = min(entry.start_timestamp, start_time)
             end_time = max(entry.end_timestamp, end_time)
             last_queue_time = max(entry.start_timestamp, last_queue_time)
-            request_ar = entry.num_generated_tokens / (entry.decode_iteration +
-                                                       1)
 
             request_latencies.append(entry.end_to_end_latency)
             generation_latencies.append(entry.generation_time)
             generation_throughputs.append(entry.generation_token_throughput)
             ttft_times.append(entry.time_to_first_token)
             intertoken_avg_latencies.append(entry.intertoken_latency)
-            request_acceptance.append(request_ar)
             output_throughput_per_user.append(entry.output_token_throughput)
             total_decoding_iterations += entry.decode_iteration + 1
 
             output_tokens.append(entry.num_total_output_tokens)
             total_input_tokens += entry.num_input_tokens
 
-        global_acceptance_rate = sum(output_tokens) / total_decoding_iterations
+            # For speculative decoding, track the number of draft tokens and
+            # the number of accepted draft tokens per request.
+            if max_draft_tokens > 0:
+                num_draft_tokens.append(max_draft_tokens *
+                                        (entry.decode_iteration + 1))
+                num_accepted_draft_tokens.append(entry.num_total_output_tokens -
+                                                 entry.decode_iteration - 1)
+                draft_acceptance_rate.append(
+                    float(num_accepted_draft_tokens[-1]) /
+                    float(num_draft_tokens[-1]))
+                acceptance_length.append(entry.num_total_output_tokens /
+                                         (entry.decode_iteration + 1))
+
+        global_acceptance_length = sum(
+            output_tokens) / total_decoding_iterations
         queue_time_total = last_queue_time - start_time
-        percentile_request_accept = PercentileStats.from_iterable(
-            request_acceptance) if request_acceptance else None
+
+        num_draft_tokens_percentiles = PercentileStats.from_iterable(
+            num_draft_tokens) if num_draft_tokens else None
+        num_accepted_draft_tokens_percentiles = PercentileStats.from_iterable(
+            num_accepted_draft_tokens) if num_accepted_draft_tokens else None
+        draft_acceptance_rate_percentiles = PercentileStats.from_iterable(
+            draft_acceptance_rate) if draft_acceptance_rate else None
+        acceptance_length_percentiles = PercentileStats.from_iterable(
+            acceptance_length) if acceptance_length else None
 
         stats = BenchmarkStatistics(
             num_requests=num_requests,
             total_latency_ns=end_time - start_time,
             total_output_tokens=sum(output_tokens),
             total_input_tokens=total_input_tokens,
-            acceptance_rate=global_acceptance_rate,
             request_latency_percentiles=PercentileStats.from_iterable(
                 request_latencies),
             tpot_percentiles=PercentileStats.from_iterable(
@@ -139,7 +159,12 @@ def generate_statistics_summary(self) -> None:
                 generation_latencies),
             token_percentiles=PercentileStats.from_iterable(output_tokens),
             issue_rate_ns=queue_time_total / num_requests,
-            acceptance_percentiles=percentile_request_accept,
+            acceptance_length=global_acceptance_length,
+            num_draft_tokens_percentiles=num_draft_tokens_percentiles,
+            num_accepted_draft_tokens_percentiles=
+            num_accepted_draft_tokens_percentiles,
+            draft_acceptance_rate_percentiles=draft_acceptance_rate_percentiles,
+            acceptance_length_percentiles=acceptance_length_percentiles,
         )
 
         return stats
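
Taken together, the per-request arithmetic in generate_statistics_summary above reduces to a few closed-form expressions. The following standalone sketch (a hypothetical helper, not part of this patch) shows how the four new speculative metrics follow from decode_iteration and num_total_output_tokens, assuming every decode iteration proposes max_draft_tokens draft tokens:

# Hypothetical standalone sketch of the per-request speculative metrics
# computed in generate_statistics_summary(); not part of the patch itself.
def speculative_metrics(num_total_output_tokens: int, decode_iteration: int,
                        max_draft_tokens: int) -> dict:
    iterations = decode_iteration + 1
    # Every iteration proposes max_draft_tokens draft tokens.
    drafted = max_draft_tokens * iterations
    # Each iteration emits exactly one non-draft token, so the rest of the
    # output must have come from accepted draft tokens.
    accepted = num_total_output_tokens - iterations
    return {
        "num_draft_tokens": drafted,
        "num_accepted_draft_tokens": accepted,
        "draft_acceptance_rate": accepted / drafted,
        # Average number of output tokens produced per decode iteration.
        "acceptance_length": num_total_output_tokens / iterations,
    }

# Example: 4 draft tokens per step, 10 iterations, 25 output tokens
# -> 40 drafted, 15 accepted, rate 0.375, acceptance length 2.5.
print(speculative_metrics(25, 9, 4))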
@@ -164,12 +189,13 @@ def __init__(self,
             logger (Logger): A logger for logging.
             streaming (bool, optional): Streaming benchmark used. Defaults to False.
         """
-        self.raw_statistics = statistics
-        self.statistics = statistics.generate_statistics_summary()
         self.dataset_metadata = dataset_metadata
         self.rt_cfg = rt_cfg
         self.logger = logger
         self.kwargs = kwargs
+        self.raw_statistics = statistics
+        self.statistics = statistics.generate_statistics_summary(
+            self.get_max_draft_len())
         self.streaming = streaming
 
     @staticmethod
@@ -415,9 +441,22 @@ def get_statistics_dict(self) -> Dict[str, Any]:
             stats_dict["decoding_stats"] = {
                 "mode":
                 decoding_mode,
-                "acceptance_percentiles":
-                self.statistics.acceptance_percentiles.model_dump(
+                "num_draft_tokens_percentiles":
+                self.statistics.num_draft_tokens_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json')
+                if self.statistics.num_draft_tokens_percentiles else None,
+                "num_accepted_draft_tokens_percentiles":
+                self.statistics.num_accepted_draft_tokens_percentiles.
+                model_dump(exclude_none=True, by_alias=True, mode='json') if
+                self.statistics.num_accepted_draft_tokens_percentiles else None,
+                "draft_acceptance_rate_percentiles":
+                self.statistics.draft_acceptance_rate_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json')
+                if self.statistics.draft_acceptance_rate_percentiles else None,
+                "acceptance_length_percentiles":
+                self.statistics.acceptance_length_percentiles.model_dump(
                     exclude_none=True, by_alias=True, mode='json')
+                if self.statistics.acceptance_length_percentiles else None
             }
         # Dataset metadata
         stats_dict["dataset"] = self.dataset_metadata.model_dump(by_alias=True,
@@ -557,21 +596,61 @@ def report_statistics(self) -> None:
         decoding_stats = ""
         if decoding is not None:
             decoding = stats_dict["decoding_stats"]
-            acc = decoding["acceptance_percentiles"]
-            acc_stats = "\n".join(
-                f"[AR] {key.upper():<7}: {acc[key]:.2f}" for key in
-                ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
-
-            decoding_stats = (
-                "===========================================================\n"
-                f"= DECODING STATISTICS ({decoding['mode']})\n"
-                "===========================================================\n"
-                "\n"
-                "-- Acceptance Rate Details --------------------------------\n\n"
-                "\n"
-                f"{acc_stats}"
-                f"\n"
-                "===========================================================\n")
+            if self.get_max_draft_len() > 0:
+                num_draft_tokens = decoding["num_draft_tokens_percentiles"]
+                num_draft_tokens_stats = "\n".join(
+                    f"[DT] {key.upper():<7}: {num_draft_tokens[key]:.2f}"
+                    for key in [
+                        "minimum", "maximum", "average", "p50", "p90", "p95",
+                        "p99"
+                    ])
+
+                num_accepted_draft_tokens = decoding[
+                    "num_accepted_draft_tokens_percentiles"]
+                num_accepted_draft_tokens_stats = "\n".join(
+                    f"[ADT] {key.upper():<7}: {num_accepted_draft_tokens[key]:.2f}"
+                    for key in [
+                        "minimum", "maximum", "average", "p50", "p90", "p95",
+                        "p99"
+                    ])
+
+                draft_acceptance_rate = decoding[
+                    "draft_acceptance_rate_percentiles"]
+                draft_acceptance_rate_stats = "\n".join(
+                    f"[DAR] {key.upper():<7}: {draft_acceptance_rate[key]:.2f}"
+                    for key in [
+                        "minimum", "maximum", "average", "p50", "p90", "p95",
+                        "p99"
+                    ])
+
+                acceptance_length = decoding["acceptance_length_percentiles"]
+                acceptance_length_stats = "\n".join(
+                    f"[AL] {key.upper():<7}: {acceptance_length[key]:.2f}"
+                    for key in [
+                        "minimum", "maximum", "average", "p50", "p90", "p95",
+                        "p99"
+                    ])
+
+                decoding_stats = (
+                    "===========================================================\n"
+                    f"= DECODING STATISTICS ({decoding['mode']})\n"
+                    "===========================================================\n"
+                    "\n"
+                    "-- Number of Draft Tokens Details --------------------------------\n\n"
+                    "\n"
+                    f"{num_draft_tokens_stats}"
+                    f"\n"
+                    "-- Number of Accepted Draft Tokens Details --------------------------------\n\n"
+                    f"{num_accepted_draft_tokens_stats}"
+                    f"\n"
+                    "-- Draft Acceptance Rate Details --------------------------------\n\n"
+                    f"{draft_acceptance_rate_stats}"
+                    f"\n"
+                    "-- Acceptance Length Details --------------------------------\n\n"
+                    f"{acceptance_length_stats}"
+                    f"\n"
+                    "===========================================================\n"
+                )
 
         logging_info = (f"{backend_info}"
                         f"{request_info}"
@@ -582,3 +661,12 @@ def report_statistics(self) -> None:
                         f"{self.dataset_metadata.get_summary_for_print()}")
         self.logger.info(logging_info)
         return self.statistics
+
+    def get_max_draft_len(self) -> int:
+        """Get max_draft_len from speculative_config."""
+        # Try to get from speculative_config
+        if ("speculative_config" in self.kwargs
+                and self.kwargs["speculative_config"] is not None):
+            return self.kwargs["speculative_config"].max_draft_len or 0
+
+        return 0
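
For reference, a minimal sketch of how the new get_max_draft_len() lookup behaves. SpecConfig here is a hypothetical stand-in for whatever speculative_config object the benchmark passes through kwargs, not the real TensorRT-LLM class:

# Minimal sketch of the get_max_draft_len() fallback behaviour. SpecConfig is
# a hypothetical stand-in for the speculative_config object carried in kwargs.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecConfig:
    max_draft_len: Optional[int] = None


def get_max_draft_len(kwargs: dict) -> int:
    if ("speculative_config" in kwargs
            and kwargs["speculative_config"] is not None):
        return kwargs["speculative_config"].max_draft_len or 0
    return 0


print(get_max_draft_len({}))                                     # 0
print(get_max_draft_len({"speculative_config": SpecConfig()}))   # 0 (None falls back)
print(get_max_draft_len({"speculative_config": SpecConfig(4)}))  # 4

A max_draft_len of 0 keeps generate_statistics_summary() on the non-speculative path, so the new percentile fields stay None and the decoding report block is skipped.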

tensorrt_llm/bench/dataclasses/statistics.py

Lines changed: 6 additions & 2 deletions
@@ -127,7 +127,7 @@ class BenchmarkStatistics(BaseModel):
     issue_rate_ns: float
 
     # Speculative Information
-    acceptance_rate: float
+    acceptance_length: float
 
     # Percentile-related Statistics
     request_latency_percentiles: Optional[PercentileStats] = None
@@ -137,7 +137,11 @@ class BenchmarkStatistics(BaseModel):
     ttft_percentiles: Optional[PercentileStats] = None
     generation_tp_percentiles: Optional[PercentileStats] = None
     generation_latency_percentiles: Optional[PercentileStats] = None
-    acceptance_percentiles: Optional[PercentileStats] = None
+    # Percentile-related Speculative Statistics
+    num_draft_tokens_percentiles: Optional[PercentileStats] = None
+    num_accepted_draft_tokens_percentiles: Optional[PercentileStats] = None
+    draft_acceptance_rate_percentiles: Optional[PercentileStats] = None
+    acceptance_length_percentiles: Optional[PercentileStats] = None
 
     @computed_field
     def sum_per_request_latencies_ns(self) -> float:
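
The model change above only renames one field and adds optional percentile fields. A minimal sketch (assuming pydantic v2 semantics, with plain dicts standing in for PercentileStats) shows how Optional fields default to None and can be dropped with exclude_none=True when serializing:

# Minimal sketch, assuming pydantic v2; dicts stand in for PercentileStats.
from typing import Optional
from pydantic import BaseModel


class SpecStatsSketch(BaseModel):
    acceptance_length: float
    num_draft_tokens_percentiles: Optional[dict] = None
    draft_acceptance_rate_percentiles: Optional[dict] = None


stats = SpecStatsSketch(acceptance_length=2.5)
# exclude_none=True drops the speculative fields that were never populated.
print(stats.model_dump(exclude_none=True))  # {'acceptance_length': 2.5}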
