
[TRTLLM-6685][feat] Add speculative metrics for trt llm bench #6476

Merged

146 changes: 117 additions & 29 deletions tensorrt_llm/bench/dataclasses/reporting.py
@@ -72,7 +72,7 @@ def register_request_perf_item(self, request_perf_item: PerfItemTuple):
if request_perf_item.response_is_final:
self.num_complete = self.num_complete + 1

def generate_statistics_summary(self) -> None:
def generate_statistics_summary(self, max_draft_tokens: int) -> None:
"""Generate summary statistics from internally stored statistics.

Returns:
@@ -90,42 +90,62 @@ def generate_statistics_summary(self) -> None:

intertoken_avg_latencies = []
output_tokens = []
request_acceptance = []
total_decoding_iterations = 0
ttft_times = []
last_queue_time = 0.0
queue_time_total = 0.0

num_draft_tokens = []
num_accepted_draft_tokens = []
draft_acceptance_rate = []
acceptance_length = []

for entry in self.requests.values():
start_time = min(entry.start_timestamp, start_time)
end_time = max(entry.end_timestamp, end_time)
last_queue_time = max(entry.start_timestamp, last_queue_time)
request_ar = entry.num_generated_tokens / (entry.decode_iteration +
1)

request_latencies.append(entry.end_to_end_latency)
generation_latencies.append(entry.generation_time)
generation_throughputs.append(entry.generation_token_throughput)
ttft_times.append(entry.time_to_first_token)
intertoken_avg_latencies.append(entry.intertoken_latency)
request_acceptance.append(request_ar)
output_throughput_per_user.append(entry.output_token_throughput)
total_decoding_iterations += entry.decode_iteration + 1

output_tokens.append(entry.num_total_output_tokens)
total_input_tokens += entry.num_input_tokens

global_acceptance_rate = sum(output_tokens) / total_decoding_iterations
# For speculative decoding, we need to track the number of draft tokens per request and the number of accepted draft tokens per request
if max_draft_tokens > 0:
num_draft_tokens.append(max_draft_tokens *
(entry.decode_iteration + 1))
num_accepted_draft_tokens.append(entry.num_total_output_tokens -
entry.decode_iteration - 1)
draft_acceptance_rate.append(
float(num_accepted_draft_tokens[-1]) /
float(num_draft_tokens[-1]))
acceptance_length.append(entry.num_total_output_tokens /
(entry.decode_iteration + 1))

global_acceptance_length = sum(
output_tokens) / total_decoding_iterations
queue_time_total = last_queue_time - start_time
percentile_request_accept = PercentileStats.from_iterable(
request_acceptance) if request_acceptance else None

num_draft_tokens_percentiles = PercentileStats.from_iterable(
num_draft_tokens) if num_draft_tokens else None
num_accepted_draft_tokens_percentiles = PercentileStats.from_iterable(
num_accepted_draft_tokens) if num_accepted_draft_tokens else None
draft_acceptance_rate_percentiles = PercentileStats.from_iterable(
draft_acceptance_rate) if draft_acceptance_rate else None
acceptance_length_percentiles = PercentileStats.from_iterable(
acceptance_length) if acceptance_length else None

stats = BenchmarkStatistics(
num_requests=num_requests,
total_latency_ns=end_time - start_time,
total_output_tokens=sum(output_tokens),
total_input_tokens=total_input_tokens,
acceptance_rate=global_acceptance_rate,
request_latency_percentiles=PercentileStats.from_iterable(
request_latencies),
tpot_percentiles=PercentileStats.from_iterable(
@@ -139,7 +159,12 @@ def generate_statistics_summary(self) -> None:
generation_latencies),
token_percentiles=PercentileStats.from_iterable(output_tokens),
issue_rate_ns=queue_time_total / num_requests,
acceptance_percentiles=percentile_request_accept,
acceptance_length=global_acceptance_length,
num_draft_tokens_percentiles=num_draft_tokens_percentiles,
num_accepted_draft_tokens_percentiles=
num_accepted_draft_tokens_percentiles,
draft_acceptance_rate_percentiles=draft_acceptance_rate_percentiles,
acceptance_length_percentiles=acceptance_length_percentiles,
)

return stats
@@ -164,12 +189,13 @@ def __init__(self,
logger (Logger): A logger for logging.
streaming (bool, optional): Streaming benchmark used. Defaults to False.
"""
self.raw_statistics = statistics
self.statistics = statistics.generate_statistics_summary()
self.dataset_metadata = dataset_metadata
self.rt_cfg = rt_cfg
self.logger = logger
self.kwargs = kwargs
self.raw_statistics = statistics
self.statistics = statistics.generate_statistics_summary(
self.get_max_draft_len())
self.streaming = streaming

@staticmethod
@@ -415,9 +441,22 @@ def get_statistics_dict(self) -> Dict[str, Any]:
stats_dict["decoding_stats"] = {
"mode":
decoding_mode,
"acceptance_percentiles":
self.statistics.acceptance_percentiles.model_dump(
"num_draft_tokens_percentiles":
self.statistics.num_draft_tokens_percentiles.model_dump(
exclude_none=True, by_alias=True, mode='json')
if self.statistics.num_draft_tokens_percentiles else None,
"num_accepted_draft_tokens_percentiles":
self.statistics.num_accepted_draft_tokens_percentiles.
model_dump(exclude_none=True, by_alias=True, mode='json') if
self.statistics.num_accepted_draft_tokens_percentiles else None,
"draft_acceptance_rate_percentiles":
self.statistics.draft_acceptance_rate_percentiles.model_dump(
exclude_none=True, by_alias=True, mode='json')
if self.statistics.draft_acceptance_rate_percentiles else None,
"acceptance_length_percentiles":
self.statistics.acceptance_length_percentiles.model_dump(
exclude_none=True, by_alias=True, mode='json')
if self.statistics.acceptance_length_percentiles else None
}
# Dataset metadata
stats_dict["dataset"] = self.dataset_metadata.model_dump(by_alias=True,
@@ -557,21 +596,61 @@ def report_statistics(self) -> None:
decoding_stats = ""
if decoding is not None:
decoding = stats_dict["decoding_stats"]
acc = decoding["acceptance_percentiles"]
acc_stats = "\n".join(
f"[AR] {key.upper():<7}: {acc[key]:.2f}" for key in
["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])

decoding_stats = (
"===========================================================\n"
f"= DECODING STATISTICS ({decoding['mode']})\n"
"===========================================================\n"
"\n"
"-- Acceptance Rate Details --------------------------------\n\n"
"\n"
f"{acc_stats}"
f"\n"
"===========================================================\n")
if self.get_max_draft_len() > 0:
num_draft_tokens = decoding["num_draft_tokens_percentiles"]
num_draft_tokens_stats = "\n".join(
f"[DT] {key.upper():<7}: {num_draft_tokens[key]:.2f}"
for key in [
"minimum", "maximum", "average", "p50", "p90", "p95",
"p99"
])

num_accepted_draft_tokens = decoding[
"num_accepted_draft_tokens_percentiles"]
num_accepted_draft_tokens_stats = "\n".join(
f"[ADT] {key.upper():<7}: {num_accepted_draft_tokens[key]:.2f}"
for key in [
"minimum", "maximum", "average", "p50", "p90", "p95",
"p99"
])

draft_acceptance_rate = decoding[
"draft_acceptance_rate_percentiles"]
draft_acceptance_rate_stats = "\n".join(
f"[DAR] {key.upper():<7}: {draft_acceptance_rate[key]:.2f}"
for key in [
"minimum", "maximum", "average", "p50", "p90", "p95",
"p99"
])

acceptance_length = decoding["acceptance_length_percentiles"]
acceptance_length_stats = "\n".join(
f"[AL] {key.upper():<7}: {acceptance_length[key]:.2f}"
for key in [
"minimum", "maximum", "average", "p50", "p90", "p95",
"p99"
])

decoding_stats = (
"===========================================================\n"
f"= DECODING STATISTICS ({decoding['mode']})\n"
"===========================================================\n"
"\n"
"-- Number of Draft Tokens Details --------------------------------\n\n"
"\n"
f"{num_draft_tokens_stats}"
f"\n"
"-- Number of Accepted Draft Tokens Details --------------------------------\n\n"
f"{num_accepted_draft_tokens_stats}"
f"\n"
"-- Draft Acceptance Rate Details --------------------------------\n\n"
f"{draft_acceptance_rate_stats}"
f"\n"
"-- Acceptance Length Details --------------------------------\n\n"
f"{acceptance_length_stats}"
f"\n"
"===========================================================\n"
)

logging_info = (f"{backend_info}"
f"{request_info}"
@@ -582,3 +661,12 @@
f"{self.dataset_metadata.get_summary_for_print()}")
self.logger.info(logging_info)
return self.statistics

def get_max_draft_len(self) -> int:
"""Get max_draft_len from speculative_config."""
# Try to get from speculative_config
if ("speculative_config" in self.kwargs
and self.kwargs["speculative_config"] is not None):
return self.kwargs["speculative_config"].max_draft_len or 0

return 0
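
A side note on the arithmetic in generate_statistics_summary above: each decode iteration commits exactly one target-model token, so any output tokens beyond the iteration count must have come from accepted draft tokens. The following is a minimal sketch of the per-request metrics with made-up numbers; only the formulas mirror the diff, the values are hypothetical.

# Hypothetical request: 180 output tokens over 60 decode iterations,
# with max_draft_tokens = 4 proposed per iteration.
max_draft_tokens = 4
decode_iterations = 60             # entry.decode_iteration + 1 in the PR
num_total_output_tokens = 180      # entry.num_total_output_tokens in the PR

num_draft_tokens = max_draft_tokens * decode_iterations                  # 240 drafts proposed
num_accepted_draft_tokens = num_total_output_tokens - decode_iterations  # 120 drafts accepted
draft_acceptance_rate = num_accepted_draft_tokens / num_draft_tokens     # 0.50
acceptance_length = num_total_output_tokens / decode_iterations          # 3.00 tokens per iteration

print(draft_acceptance_rate, acceptance_length)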
8 changes: 6 additions & 2 deletions tensorrt_llm/bench/dataclasses/statistics.py
@@ -127,7 +127,7 @@ class BenchmarkStatistics(BaseModel):
issue_rate_ns: float

# Speculative Information
acceptance_rate: float
acceptance_length: float

# Percentile-related Statistics
request_latency_percentiles: Optional[PercentileStats] = None
@@ -137,7 +137,11 @@ class BenchmarkStatistics(BaseModel):
ttft_percentiles: Optional[PercentileStats] = None
generation_tp_percentiles: Optional[PercentileStats] = None
generation_latency_percentiles: Optional[PercentileStats] = None
acceptance_percentiles: Optional[PercentileStats] = None
# Percentile-related Speculative Statistics
num_draft_tokens_percentiles: Optional[PercentileStats] = None
num_accepted_draft_tokens_percentiles: Optional[PercentileStats] = None
draft_acceptance_rate_percentiles: Optional[PercentileStats] = None
acceptance_length_percentiles: Optional[PercentileStats] = None

@computed_field
def sum_per_request_latencies_ns(self) -> float:
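
For context, a hedged sketch of how a consumer might read the new speculative percentile blocks out of the statistics dict built by get_statistics_dict. The key names come from the diff above; the file name and the assumption that the dict was serialized to JSON are illustrative only.

# Hypothetical reader of a serialized benchmark report; key names follow the
# decoding_stats block added in reporting.py, the path and values are made up.
import json

with open("benchmark_report.json") as f:   # assumed dump of get_statistics_dict()
    report = json.load(f)

decoding = report.get("decoding_stats")
if decoding:
    for metric in ("num_draft_tokens_percentiles",
                   "num_accepted_draft_tokens_percentiles",
                   "draft_acceptance_rate_percentiles",
                   "acceptance_length_percentiles"):
        percentiles = decoding.get(metric)
        if percentiles:
            # "average" and "p99" are among the keys report_statistics() prints.
            print(f"{metric}: avg={percentiles['average']:.2f} p99={percentiles['p99']:.2f}")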