Commit d2d24a4

add speculative metrics
1 parent b3ca159 commit d2d24a4

File tree

2 files changed: +112 -20 lines changed

2 files changed

+112
-20
lines changed

tensorrt_llm/bench/dataclasses/reporting.py

Lines changed: 105 additions & 18 deletions
@@ -72,7 +72,7 @@ def register_request_perf_item(self, request_perf_item: PerfItemTuple):
         if request_perf_item.response_is_final:
             self.num_complete = self.num_complete + 1
 
-    def generate_statistics_summary(self) -> None:
+    def generate_statistics_summary(self, num_draft_tokens: int) -> None:
         """Generate summary statistics from internally stored statistics.
 
         Returns:
@@ -90,41 +90,59 @@ def generate_statistics_summary(self) -> None:
 
         intertoken_avg_latencies = []
         output_tokens = []
-        request_acceptance = []
        total_decoding_iterations = 0
         ttft_times = []
         last_queue_time = 0.0
         queue_time_total = 0.0
 
+        num_draft_tokens = []
+        num_accepted_drafts = []
+        draft_acceptance_rate = []
+        acceptance_length = []
+
         for entry in self.requests.values():
             start_time = min(entry.start_timestamp, start_time)
             end_time = max(entry.end_timestamp, end_time)
             last_queue_time = max(entry.start_timestamp, last_queue_time)
-            request_ar = entry.num_generated_tokens / (entry.decode_iteration +
-                                                       1)
 
             request_latencies.append(entry.end_to_end_latency)
             generation_latencies.append(entry.generation_time)
             generation_throughputs.append(entry.generation_token_throughput)
             ttft_times.append(entry.time_to_first_token)
             intertoken_avg_latencies.append(entry.intertoken_latency)
-            request_acceptance.append(request_ar)
+
             output_throughput_per_user.append(entry.output_token_throughput)
             total_decoding_iterations += entry.decode_iteration + 1
 
             output_tokens.append(entry.num_total_output_tokens)
             total_input_tokens += entry.num_input_tokens
 
-        global_acceptance_rate = sum(output_tokens) / total_decoding_iterations
+            # For speculative decoding, track the number of draft tokens and the number of accepted draft tokens per request
+            if num_draft_tokens > 0:
+                num_draft_tokens.append(num_draft_tokens * (entry.decode_iteration + 1))
+                num_accepted_drafts.append(entry.num_generated_tokens - entry.decode_iteration - 1)
+                draft_acceptance_rate.append(float(num_accepted_drafts[-1]) / float(num_draft_tokens[-1]))
+                acceptance_length.append(entry.num_generated_tokens / (entry.decode_iteration +
+                                                                       1))
+
+        global_acceptance_length = sum(acceptance_length) / len(acceptance_length)
         queue_time_total = last_queue_time - start_time
-        percentile_request_accept = PercentileStats.from_iterable(
-            request_acceptance) if request_acceptance else None
+
+        num_draft_tokens_percentiles = PercentileStats.from_iterable(
+            num_draft_tokens) if num_draft_tokens else None
+        num_accepted_drafts_percentiles = PercentileStats.from_iterable(
+            num_accepted_drafts) if num_accepted_drafts else None
+        draft_acceptance_rate_percentiles = PercentileStats.from_iterable(
+            draft_acceptance_rate) if draft_acceptance_rate else None
+        acceptance_length_percentiles = PercentileStats.from_iterable(
+            acceptance_length) if acceptance_length else None
 
         stats = BenchmarkStatistics(
             num_requests=num_requests,
             total_latency_ns=end_time - start_time,
             total_output_tokens=sum(output_tokens),
             total_input_tokens=total_input_tokens,
+            total_draft_tokens=sum(num_draft_tokens),
             acceptance_rate=global_acceptance_rate,
             request_latency_percentiles=PercentileStats.from_iterable(
                 request_latencies),
@@ -139,7 +157,11 @@ def generate_statistics_summary(self) -> None:
                 generation_latencies),
             token_percentiles=PercentileStats.from_iterable(output_tokens),
             issue_rate_ns=queue_time_total / num_requests,
-            acceptance_percentiles=percentile_request_accept,
+            acceptance_length=global_acceptance_length,
+            num_draft_tokens_percentiles=num_draft_tokens_percentiles,
+            num_accepted_drafts_percentiles=num_accepted_drafts_percentiles,
+            draft_acceptance_rate_percentiles=draft_acceptance_rate_percentiles,
+            acceptance_length_percentiles=acceptance_length_percentiles,
         )
 
         return stats
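
Taken together, the hunks above turn speculative bookkeeping into simple per-request arithmetic: drafts offered scale with decode iterations, accepted drafts are whatever was generated beyond one token per iteration, and acceptance length is tokens generated per iteration. A minimal standalone sketch of the same formulas, using hypothetical numbers rather than values from this diff:

# Worked example of the per-request speculative metrics above (hypothetical values).
max_draft_tokens = 3       # draft tokens proposed per decoding iteration
decode_iterations = 10     # entry.decode_iteration + 1
generated_tokens = 25      # entry.num_generated_tokens

drafted = max_draft_tokens * decode_iterations        # 30 draft tokens offered
accepted = generated_tokens - decode_iterations       # 15 draft tokens accepted
draft_acceptance_rate = accepted / drafted            # 0.50
acceptance_length = generated_tokens / decode_iterations   # 2.5 tokens per iteration

Acceptance length counts the tokens committed per decoding iteration (one target token plus any accepted drafts), which is why it equals generated tokens divided by iterations.
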
@@ -171,6 +193,7 @@ def __init__(self,
         self.logger = logger
         self.kwargs = kwargs
         self.streaming = streaming
+
 
     @staticmethod
     def convert_to_ms(ns: float) -> float:
@@ -321,7 +344,9 @@ def get_statistics_dict(self) -> Dict[str, Any]:
             "avg_num_concurrent_requests":
             self.statistics.avg_concurrent_requests,
             "avg_input_length": self.statistics.average_input_length,
-            "avg_output_length": self.statistics.average_output_length
+            "avg_output_length": self.statistics.average_output_length,
+            "total_draft_tokens": self.statistics.total_draft_tokens,
+            "avg_draft_tokens_per_request": 0
         }
 
         # Performance stats
@@ -415,9 +440,18 @@ def get_statistics_dict(self) -> Dict[str, Any]:
             stats_dict["decoding_stats"] = {
                 "mode":
                 decoding_mode,
-                "acceptance_percentiles":
-                self.statistics.acceptance_percentiles.model_dump(
-                    exclude_none=True, by_alias=True, mode='json')
+                "draft_tokens_percentiles":
+                self.statistics.num_draft_tokens_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json') if self.statistics.num_draft_tokens_percentiles else None,
+                "num_accepted_drafts_percentiles":
+                self.statistics.num_accepted_drafts_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json') if self.statistics.num_accepted_drafts_percentiles else None,
+                "draft_acceptance_rate_percentiles":
+                self.statistics.draft_acceptance_rate_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json') if self.statistics.draft_acceptance_rate_percentiles else None,
+                "acceptance_length_percentiles":
+                self.statistics.acceptance_length_percentiles.model_dump(
+                    exclude_none=True, by_alias=True, mode='json') if self.statistics.acceptance_length_percentiles else None
             }
         # Dataset metadata
         stats_dict["dataset"] = self.dataset_metadata.model_dump(by_alias=True,
@@ -503,6 +537,8 @@ def report_statistics(self) -> None:
             f"Number of concurrent requests: {requests['avg_num_concurrent_requests']:.4f}\n"
             f"Average Input Length (tokens): {requests['avg_input_length']:.4f}\n"
             f"Average Output Length (tokens): {requests['avg_output_length']:.4f}\n"
+            f"Total Draft Tokens: {requests['total_draft_tokens']}\n"
+            f"Max Draft Tokens per Request: {requests['avg_draft_tokens_per_request']}\n"
         )
 
         perf_header = (
@@ -557,19 +593,44 @@ def report_statistics(self) -> None:
         decoding_stats = ""
         if decoding is not None:
             decoding = stats_dict["decoding_stats"]
-            acc = decoding["acceptance_percentiles"]
-            acc_stats = "\n".join(
-                f"[AR] {key.upper():<7}: {acc[key]:.2f}" for key in
+
+            num_draft_tokens = decoding["num_draft_tokens_percentiles"]
+            num_draft_tokens_stats = "\n".join(
+                f"[DT] {key.upper():<7}: {num_draft_tokens[key]:.2f}" for key in
+                ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
+
+            num_accepted_draft_tokens = decoding["num_accepted_drafts_percentiles"]
+            num_accepted_draft_tokens_stats = "\n".join(
+                f"[ADT] {key.upper():<7}: {num_accepted_draft_tokens[key]:.2f}" for key in
+                ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
+
+            draft_acceptance_rate = decoding["draft_acceptance_rate_percentiles"]
+            draft_acceptance_rate_stats = "\n".join(
+                f"[DAR] {key.upper():<7}: {draft_acceptance_rate[key]:.2f}" for key in
+                ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
+
+            acceptance_length = decoding["acceptance_length_percentiles"]
+            acceptance_length_stats = "\n".join(
+                f"[AL] {key.upper():<7}: {acceptance_length[key]:.2f}" for key in
                 ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])
 
             decoding_stats = (
                 "===========================================================\n"
                 f"= DECODING STATISTICS ({decoding['mode']})\n"
                 "===========================================================\n"
                 "\n"
-                "-- Acceptance Rate Details --------------------------------\n\n"
+                "-- Number of Draft Tokens Details --------------------------------\n\n"
                 "\n"
-                f"{acc_stats}"
+                f"{num_draft_tokens_stats}"
+                f"\n"
+                "-- Number of Accepted Draft Tokens Details --------------------------------\n\n"
+                f"{num_accepted_draft_tokens_stats}"
+                f"\n"
+                "-- Draft Acceptance Rate Details --------------------------------\n\n"
+                f"{draft_acceptance_rate_stats}"
+                f"\n"
+                "-- Acceptance Length Details --------------------------------\n\n"
+                f"{acceptance_length_stats}"
                 f"\n"
                 "===========================================================\n")
 
@@ -582,3 +643,29 @@ def report_statistics(self) -> None:
             f"{self.dataset_metadata.get_summary_for_print()}")
         self.logger.info(logging_info)
         return self.statistics
+
+    def get_max_draft_len(self) -> int:
+        """Get max_draft_len from rt_cfg or speculative_config."""
+        # Try to get from rt_cfg first
+        if (self.rt_cfg.decoding_config
+                and self.rt_cfg.decoding_config.decoding_mode
+                != SpeculativeDecodingMode.NONE):
+            # For C++ backend, max_draft_len is stored in the engine config
+            # We need to read it from the engine config file
+            if self.rt_cfg.engine_dir:
+                config_path = self.rt_cfg.engine_dir / "config.json"
+                try:
+                    with open(config_path, "r") as config:
+                        engine_config = json.load(config)
+                        build_cfg = engine_config["build_config"]
+                        return build_cfg.get("max_draft_len", 0)
+                except (FileNotFoundError, KeyError, json.JSONDecodeError):
+                    pass
+
+        # Try to get from speculative_config
+        if ("speculative_config" in self.kwargs
+                and self.kwargs["speculative_config"] is not None):
+            return self.kwargs["speculative_config"].max_draft_len or 0
+
+        return 0
+
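
The call site that supplies the new num_draft_tokens argument is not part of this commit; presumably the reporter resolves it through get_max_draft_len() before summarizing, and get_max_draft_len() itself assumes json and SpeculativeDecodingMode are already imported in reporting.py. A hedged sketch of that wiring, in which every name except the two methods defined above is an assumption:

# Hypothetical caller, not shown in this diff: feed the engine's max_draft_len
# into the summary so per-request draft counts can be derived.
def summarize(self):
    num_draft_tokens = self.get_max_draft_len()   # 0 when speculation is disabled
    return self.stats_keeper.generate_statistics_summary(num_draft_tokens)
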

tensorrt_llm/bench/dataclasses/statistics.py

Lines changed: 7 additions & 2 deletions
@@ -121,13 +121,14 @@ class BenchmarkStatistics(BaseModel):
     # Token-related Properties
     total_output_tokens: int
     total_input_tokens: int
+    total_draft_tokens: int = 0
 
     # General Information
     num_requests: int
     issue_rate_ns: float
 
     # Speculative Information
-    acceptance_rate: float
+    acceptance_length: float
 
     # Percentile-related Statistics
     request_latency_percentiles: Optional[PercentileStats] = None
@@ -137,7 +138,11 @@ class BenchmarkStatistics(BaseModel):
     ttft_percentiles: Optional[PercentileStats] = None
     generation_tp_percentiles: Optional[PercentileStats] = None
     generation_latency_percentiles: Optional[PercentileStats] = None
-    acceptance_percentiles: Optional[PercentileStats] = None
+    # Percentile-related Speculative Statistics
+    num_draft_tokens_percentiles: Optional[PercentileStats] = None
+    num_accepted_drafts_percentiles: Optional[PercentileStats] = None
+    draft_acceptance_rate_percentiles: Optional[PercentileStats] = None
+    acceptance_length_percentiles: Optional[PercentileStats] = None
 
     @computed_field
     def sum_per_request_latencies_ns(self) -> float:
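
Every new speculative percentile field is Optional with a None default, so non-speculative runs serialize unchanged: model_dump(exclude_none=True) drops the unset fields, which is exactly what the decoding_stats block in reporting.py relies on. A small illustrative sketch with stand-in models (not the repo's actual classes):

from typing import Optional
from pydantic import BaseModel

class PercentileStatsSketch(BaseModel):
    # Stand-in for PercentileStats; only a few fields, for illustration.
    average: float
    p50: float
    p99: float

class SpeculativeStatsSketch(BaseModel):
    acceptance_length: float
    draft_acceptance_rate_percentiles: Optional[PercentileStatsSketch] = None

stats = SpeculativeStatsSketch(acceptance_length=1.0)
print(stats.model_dump(exclude_none=True))   # {'acceptance_length': 1.0}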
