Commit b8ba61d

add speculative metrics
Signed-off-by: linquanh <[email protected]>
1 parent 2fe9cc0 commit b8ba61d

File tree: 2 files changed (+115, -32 lines)


tensorrt_llm/bench/dataclasses/reporting.py

Lines changed: 109 additions & 30 deletions
@@ -72,7 +72,7 @@ def register_request_perf_item(self, request_perf_item: PerfItemTuple):
         if request_perf_item.response_is_final:
             self.num_complete = self.num_complete + 1
 
-    def generate_statistics_summary(self) -> None:
+    def generate_statistics_summary(self, max_draft_tokens: int) -> None:
         """Generate summary statistics from internally stored statistics.
 
         Returns:
@@ -90,42 +90,57 @@ def generate_statistics_summary(self) -> None:
 
         intertoken_avg_latencies = []
         output_tokens = []
-        request_acceptance = []
         total_decoding_iterations = 0
         ttft_times = []
         last_queue_time = 0.0
         queue_time_total = 0.0
 
+        num_draft_tokens = []
+        num_accepted_draft_tokens = []
+        draft_acceptance_rate = []
+        acceptance_length = []
+
         for entry in self.requests.values():
             start_time = min(entry.start_timestamp, start_time)
             end_time = max(entry.end_timestamp, end_time)
             last_queue_time = max(entry.start_timestamp, last_queue_time)
-            request_ar = entry.num_generated_tokens / (entry.decode_iteration +
-                                                       1)
 
             request_latencies.append(entry.end_to_end_latency)
             generation_latencies.append(entry.generation_time)
             generation_throughputs.append(entry.generation_token_throughput)
             ttft_times.append(entry.time_to_first_token)
             intertoken_avg_latencies.append(entry.intertoken_latency)
-            request_acceptance.append(request_ar)
             output_throughput_per_user.append(entry.output_token_throughput)
             total_decoding_iterations += entry.decode_iteration + 1
 
             output_tokens.append(entry.num_total_output_tokens)
             total_input_tokens += entry.num_input_tokens
 
-        global_acceptance_rate = sum(output_tokens) / total_decoding_iterations
+            # For speculative decoding, we need to track the number of draft tokens per request and the number of accepted draft tokens per request
+            if max_draft_tokens > 0:
+                num_draft_tokens.append(max_draft_tokens * (entry.decode_iteration + 1))
+                num_accepted_draft_tokens.append(entry.num_generated_tokens - entry.decode_iteration - 1)
+                draft_acceptance_rate.append(float(num_accepted_draft_tokens[-1]) / float(num_draft_tokens[-1]))
+                acceptance_length.append(entry.num_generated_tokens / (entry.decode_iteration +
+                                                                       1))
+
+        global_acceptance_length = sum(output_tokens) / total_decoding_iterations
         queue_time_total = last_queue_time - start_time
-        percentile_request_accept = PercentileStats.from_iterable(
-            request_acceptance) if request_acceptance else None
+
+        num_draft_tokens_percentiles = PercentileStats.from_iterable(
+            num_draft_tokens) if num_draft_tokens else None
+        num_accepted_draft_tokens_percentiles = PercentileStats.from_iterable(
+            num_accepted_draft_tokens) if num_accepted_draft_tokens else None
+        draft_acceptance_rate_percentiles = PercentileStats.from_iterable(
+            draft_acceptance_rate) if draft_acceptance_rate else None
+        acceptance_length_percentiles = PercentileStats.from_iterable(
+            acceptance_length) if acceptance_length else None
 
         stats = BenchmarkStatistics(
             num_requests=num_requests,
             total_latency_ns=end_time - start_time,
             total_output_tokens=sum(output_tokens),
             total_input_tokens=total_input_tokens,
-            acceptance_rate=global_acceptance_rate,
             request_latency_percentiles=PercentileStats.from_iterable(
                 request_latencies),
             tpot_percentiles=PercentileStats.from_iterable(
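
For intuition, here is a minimal standalone sketch of how the four per-request metrics above relate. The request values are hypothetical; the formulas mirror the loop body in this hunk, which assumes each decoding iteration emits one target token plus any accepted draft tokens.

# Hypothetical request: 4 decoding iterations (decode_iteration == 3),
# 10 generated tokens, speculative config with max_draft_tokens == 5.
max_draft_tokens = 5
decode_iteration = 3
num_generated_tokens = 10

num_draft = max_draft_tokens * (decode_iteration + 1)        # 20 tokens drafted
num_accepted = num_generated_tokens - decode_iteration - 1   # 6 draft tokens accepted
draft_acceptance_rate = num_accepted / num_draft             # 0.30
acceptance_length = num_generated_tokens / (decode_iteration + 1)  # 2.5 tokens per iteration

assert (num_accepted, draft_acceptance_rate, acceptance_length) == (6, 0.3, 2.5)
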
@@ -139,7 +154,11 @@ def generate_statistics_summary(self) -> None:
                 generation_latencies),
             token_percentiles=PercentileStats.from_iterable(output_tokens),
             issue_rate_ns=queue_time_total / num_requests,
-            acceptance_percentiles=percentile_request_accept,
+            acceptance_length=global_acceptance_length,
+            num_draft_tokens_percentiles=num_draft_tokens_percentiles,
+            num_accepted_draft_tokens_percentiles=num_accepted_draft_tokens_percentiles,
+            draft_acceptance_rate_percentiles=draft_acceptance_rate_percentiles,
+            acceptance_length_percentiles=acceptance_length_percentiles,
         )
 
         return stats
@@ -164,12 +183,12 @@ def __init__(self,
             logger (Logger): A logger for logging.
             streaming (bool, optional): Streaming benchmark used. Defaults to False.
         """
-        self.raw_statistics = statistics
-        self.statistics = statistics.generate_statistics_summary()
         self.dataset_metadata = dataset_metadata
         self.rt_cfg = rt_cfg
         self.logger = logger
         self.kwargs = kwargs
+        self.raw_statistics = statistics
+        self.statistics = statistics.generate_statistics_summary(self.get_max_draft_len())
         self.streaming = streaming
 
     @staticmethod
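
The two moved assignments presumably have to run after self.rt_cfg and self.kwargs are populated, since the summary call now goes through get_max_draft_len() (added at the end of this file), which inspects both. A minimal sketch of that dependency, using the names from this diff:

# Both attributes below are read by get_max_draft_len(), so they must be
# assigned before generate_statistics_summary() is invoked with its result.
self.rt_cfg = rt_cfg    # get_max_draft_len() checks rt_cfg.decoding_config / engine_dir
self.kwargs = kwargs    # and falls back to kwargs["speculative_config"]
self.statistics = statistics.generate_statistics_summary(self.get_max_draft_len())
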
@@ -415,9 +434,18 @@ def get_statistics_dict(self) -> Dict[str, Any]:
             stats_dict["decoding_stats"] = {
                 "mode":
                 decoding_mode,
-                "acceptance_percentiles":
-                self.statistics.acceptance_percentiles.model_dump(
-                    exclude_none=True, by_alias=True, mode='json')
+                "num_draft_tokens_percentiles":
+                self.statistics.num_draft_tokens_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json') if self.statistics.num_draft_tokens_percentiles else None,
+                "num_accepted_draft_tokens_percentiles":
+                self.statistics.num_accepted_draft_tokens_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json') if self.statistics.num_accepted_draft_tokens_percentiles else None,
+                "draft_acceptance_rate_percentiles":
+                self.statistics.draft_acceptance_rate_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json') if self.statistics.draft_acceptance_rate_percentiles else None,
+                "acceptance_length_percentiles":
+                self.statistics.acceptance_length_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json') if self.statistics.acceptance_length_percentiles else None
             }
         # Dataset metadata
         stats_dict["dataset"] = self.dataset_metadata.model_dump(by_alias=True,
@@ -557,21 +585,46 @@ def report_statistics(self) -> None:
         decoding_stats = ""
         if decoding is not None:
             decoding = stats_dict["decoding_stats"]
-            acc = decoding["acceptance_percentiles"]
-            acc_stats = "\n".join(
-                f"[AR] {key.upper():<7}: {acc[key]:.2f}" for key in
-                ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
-
-            decoding_stats = (
-                "===========================================================\n"
-                f"= DECODING STATISTICS ({decoding['mode']})\n"
-                "===========================================================\n"
-                "\n"
-                "-- Acceptance Rate Details --------------------------------\n\n"
-                "\n"
-                f"{acc_stats}"
-                f"\n"
-                "===========================================================\n")
+            if self.get_max_draft_len() > 0:
+                num_draft_tokens = decoding["num_draft_tokens_percentiles"]
+                num_draft_tokens_stats = "\n".join(
+                    f"[DT] {key.upper():<7}: {num_draft_tokens[key]:.2f}" for key in
+                    ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
+
+                num_accepted_draft_tokens = decoding["num_accepted_draft_tokens_percentiles"]
+                num_accepted_draft_tokens_stats = "\n".join(
+                    f"[ADT] {key.upper():<7}: {num_accepted_draft_tokens[key]:.2f}" for key in
+                    ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
+
+                draft_acceptance_rate = decoding["draft_acceptance_rate_percentiles"]
+                draft_acceptance_rate_stats = "\n".join(
+                    f"[DAR] {key.upper():<7}: {draft_acceptance_rate[key]:.2f}" for key in
+                    ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
+
+                acceptance_length = decoding["acceptance_length_percentiles"]
+                acceptance_length_stats = "\n".join(
+                    f"[AL] {key.upper():<7}: {acceptance_length[key]:.2f}" for key in
+                    ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
+
+                decoding_stats = (
+                    "===========================================================\n"
+                    f"= DECODING STATISTICS ({decoding['mode']})\n"
+                    "===========================================================\n"
+                    "\n"
+                    "-- Number of Draft Tokens Details --------------------------------\n\n"
+                    "\n"
+                    f"{num_draft_tokens_stats}"
+                    f"\n"
+                    "-- Number of Accepted Draft Tokens Details --------------------------------\n\n"
+                    f"{num_accepted_draft_tokens_stats}"
+                    f"\n"
+                    "-- Draft Acceptance Rate Details --------------------------------\n\n"
+                    f"{draft_acceptance_rate_stats}"
+                    f"\n"
+                    "-- Acceptance Length Details --------------------------------\n\n"
+                    f"{acceptance_length_stats}"
+                    f"\n"
+                    "===========================================================\n")
 
         logging_info = (f"{backend_info}"
                         f"{request_info}"
@@ -582,3 +635,29 @@ def report_statistics(self) -> None:
                         f"{self.dataset_metadata.get_summary_for_print()}")
         self.logger.info(logging_info)
         return self.statistics
+
+    def get_max_draft_len(self) -> int:
+        """Get max_draft_len from rt_cfg or speculative_config."""
+        # Try to get from rt_cfg first
+        if (self.rt_cfg.decoding_config
+                and self.rt_cfg.decoding_config.decoding_mode
+                != SpeculativeDecodingMode.NONE):
+            # For C++ backend, max_draft_len is stored in the engine config
+            # We need to read it from the engine config file
+            if self.rt_cfg.engine_dir:
+                config_path = self.rt_cfg.engine_dir / "config.json"
+                try:
+                    with open(config_path, "r") as config:
+                        engine_config = json.load(config)
+                    build_cfg = engine_config["build_config"]
+                    return build_cfg.get("max_draft_len", 0)
+                except (FileNotFoundError, KeyError, json.JSONDecodeError):
+                    pass
+
+        # Try to get from speculative_config
+        if ("speculative_config" in self.kwargs
+                and self.kwargs["speculative_config"] is not None):
+            return self.kwargs["speculative_config"].max_draft_len or 0
+
+        return 0
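
For the engine-dir branch of the helper above, the expected layout is a config.json whose build_config section carries max_draft_len. A self-contained sketch of that layout and the lookup, with a placeholder value (the real file is produced by the engine build, not written by hand):

import json
import tempfile
from pathlib import Path

# Write an illustrative config.json containing only the field the helper reads.
engine_dir = Path(tempfile.mkdtemp())
(engine_dir / "config.json").write_text(
    json.dumps({"build_config": {"max_draft_len": 5}}))

with open(engine_dir / "config.json", "r") as config:
    engine_config = json.load(config)
print(engine_config["build_config"].get("max_draft_len", 0))  # -> 5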

tensorrt_llm/bench/dataclasses/statistics.py

Lines changed: 6 additions & 2 deletions
@@ -127,7 +127,7 @@ class BenchmarkStatistics(BaseModel):
     issue_rate_ns: float
 
     # Speculative Information
-    acceptance_rate: float
+    acceptance_length: float
 
     # Percentile-related Statistics
     request_latency_percentiles: Optional[PercentileStats] = None
@@ -137,7 +137,11 @@ class BenchmarkStatistics(BaseModel):
     ttft_percentiles: Optional[PercentileStats] = None
     generation_tp_percentiles: Optional[PercentileStats] = None
     generation_latency_percentiles: Optional[PercentileStats] = None
-    acceptance_percentiles: Optional[PercentileStats] = None
+    # Percentile-related Speculative Statistics
+    num_draft_tokens_percentiles: Optional[PercentileStats] = None
+    num_accepted_draft_tokens_percentiles: Optional[PercentileStats] = None
+    draft_acceptance_rate_percentiles: Optional[PercentileStats] = None
+    acceptance_length_percentiles: Optional[PercentileStats] = None
 
     @computed_field
     def sum_per_request_latencies_ns(self) -> float:
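
Since all four new percentile fields are Optional with a default of None, a non-speculative benchmark can simply leave them unset, which is why the reporting code above guards each model_dump call with an explicit None check. The sketch below only illustrates the Optional-with-None-default pattern on a toy pydantic model, not the real BenchmarkStatistics:

from typing import Optional

from pydantic import BaseModel


class SpecStatsSketch(BaseModel):
    """Toy stand-in for the speculative fields on BenchmarkStatistics."""
    acceptance_length: float
    draft_acceptance_rate_percentiles: Optional[dict] = None


# Non-speculative run: the optional field stays None and is dropped on dump.
print(SpecStatsSketch(acceptance_length=1.0).model_dump(exclude_none=True))
# {'acceptance_length': 1.0}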
