Commit d16af87

[TRTLLM-7158][feat] Introduce sampler options in trtllm bench (#6855)
Signed-off-by: Daniel Campora <[email protected]>
1 parent d1d17db commit d16af87
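
This commit adds a --sampler_options flag to both the latency and throughput subcommands of trtllm-bench. The flag takes a path to a YAML file; its keys are merged over the CLI-derived sampler arguments and the result is passed to SamplingParams. A minimal sketch of the intended use, where the flag name and subcommands come from this diff, the YAML keys (temperature, top_p, top_k) are assumed to be valid SamplingParams fields, and the model and dataset arguments are placeholders:

    # sampler_options.yaml (illustrative; keys must be valid SamplingParams arguments)
    temperature: 0.8
    top_p: 0.95
    top_k: 40

    # Hypothetical invocations (exact CLI shape may differ):
    trtllm-bench --model <model> latency --dataset <dataset> --sampler_options sampler_options.yaml
    trtllm-bench --model <model> throughput --dataset <dataset> --sampler_options sampler_options.yaml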

File tree

3 files changed: +73 -12 lines changed

tensorrt_llm/bench/benchmark/low_latency.py
tensorrt_llm/bench/benchmark/throughput.py
tensorrt_llm/bench/benchmark/utils/general.py

tensorrt_llm/bench/benchmark/low_latency.py

Lines changed: 18 additions & 7 deletions
@@ -25,7 +25,7 @@
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 
 # isort: off
-from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS
+from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, update_sampler_args_with_extra_options, ALL_SUPPORTED_BACKENDS
 # isort: on
 from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
                                            initialize_tokenizer,
@@ -135,6 +135,13 @@
     default=1,
     help="Number of search beams.",
 )
+@optgroup.option("--sampler_options",
+                 type=click.Path(exists=True,
+                                 readable=True,
+                                 path_type=Path,
+                                 resolve_path=True),
+                 default=None,
+                 help="Path to a YAML file that sets sampler options.")
 @optgroup.option(
     "--concurrency",
     type=int,
@@ -326,12 +333,16 @@ def latency_command(
     eos_id = tokenizer.eos_token_id if not ignore_eos else -1
     pad_id = tokenizer.pad_token_id if not ignore_eos else -1
 
-    sampling_params = SamplingParams(
-        end_id=eos_id,
-        pad_id=pad_id,
-        n=beam_width,
-        use_beam_search=beam_width > 1,
-    )
+    sampler_args = {
+        "end_id": eos_id,
+        "pad_id": pad_id,
+        "n": beam_width,
+        "use_beam_search": beam_width > 1
+    }
+    sampler_args = update_sampler_args_with_extra_options(
+        sampler_args, params.pop("sampler_options"))
+    sampling_params = SamplingParams(**sampler_args)
+
     post_proc_params = None  # No detokenization
 
     # Perform warmup if requested.
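
Merge order in the block above matters: the dictionary built from CLI values is the left operand of the union inside update_sampler_args_with_extra_options, so a key present in both places takes its value from the YAML file. A minimal pure-Python sketch of that behaviour (values illustrative, not taken from the diff):

    # dict union: the right-hand operand wins on key conflicts (Python 3.9+)
    sampler_args = {"end_id": 2, "pad_id": 0, "n": 1, "use_beam_search": False}
    yaml_options = {"temperature": 0.8, "end_id": -1}  # as if parsed from --sampler_options

    merged = sampler_args | yaml_options
    assert merged["end_id"] == -1        # YAML value overrides the CLI-derived one
    assert merged["temperature"] == 0.8  # new keys are simply added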

tensorrt_llm/bench/benchmark/throughput.py

Lines changed: 19 additions & 5 deletions
@@ -22,7 +22,8 @@
 from tensorrt_llm import LLM as PyTorchLLM
 from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
-from tensorrt_llm.bench.benchmark.utils.general import generate_warmup_dataset
+from tensorrt_llm.bench.benchmark.utils.general import (
+    generate_warmup_dataset, update_sampler_args_with_extra_options)
 from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
 from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
 from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
@@ -67,6 +68,13 @@
     help=
     "Path to a YAML file that overwrites the parameters specified by trtllm-bench."
 )
+@optgroup.option("--sampler_options",
+                 type=click.Path(exists=True,
+                                 readable=True,
+                                 path_type=Path,
+                                 resolve_path=True),
+                 default=None,
+                 help="Path to a YAML file that sets sampler options.")
 @optgroup.option(
     "--max_batch_size",
     type=int,
@@ -455,10 +463,16 @@ def ignore_trt_only_args(kwargs: dict):
     else:
         llm = LLM(**kwargs)
 
-    sampling_params = SamplingParams(end_id=eos_id,
-                                     pad_id=eos_id,
-                                     n=beam_width,
-                                     use_beam_search=beam_width > 1)
+    sampler_args = {
+        "end_id": eos_id,
+        "pad_id": eos_id,
+        "n": beam_width,
+        "use_beam_search": beam_width > 1
+    }
+    sampler_args = update_sampler_args_with_extra_options(
+        sampler_args, params.pop("sampler_options"))
+    sampling_params = SamplingParams(**sampler_args)
+
     post_proc_params = None  # No detokenization
 
     # Perform warmup if requested.
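
Two details worth noting here: throughput.py seeds pad_id with eos_id (low_latency.py uses the tokenizer's pad token), and values arrive from the YAML file already typed, which matters for fields such as use_beam_search that expect real booleans. A short sketch of how yaml.safe_load maps the scalars (keys are illustrative):

    import yaml  # PyYAML, the same parser the new helper uses

    text = """
    temperature: 0.9
    top_k: 50
    use_beam_search: false
    """
    print(yaml.safe_load(text))
    # -> {'temperature': 0.9, 'top_k': 50, 'use_beam_search': False}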

tensorrt_llm/bench/benchmark/utils/general.py

Lines changed: 36 additions & 0 deletions
@@ -199,3 +199,39 @@ def generate_warmup_dataset(requests, steps) -> List[InferenceRequest]:
     warm_up_dataset = choices(requests, k=steps)
     shuffle(warm_up_dataset)
     return warm_up_dataset
+
+
+def update_sampler_args_with_extra_options(sampler_args: Dict,
+                                           sampler_options: str) -> Dict:
+    """Update sampler arguments with options from a YAML file.
+
+    Args:
+        sampler_args: Base sampler arguments dictionary.
+        sampler_options: Path to YAML file containing additional options.
+
+    Returns:
+        Dict: Merged sampler arguments.
+
+    Raises:
+        FileNotFoundError: If the YAML file doesn't exist.
+        yaml.YAMLError: If the YAML file is malformed.
+        TypeError: If the YAML content is not a dictionary.
+    """
+    if sampler_options is not None:
+        try:
+            with open(sampler_options, 'r') as f:
+                sampler_options_dict = yaml.safe_load(f)
+        except FileNotFoundError:
+            raise FileNotFoundError(
+                f"Sampler options file not found: {sampler_options}")
+        except yaml.YAMLError as e:
+            raise yaml.YAMLError(
+                f"Invalid YAML in sampler options file {sampler_options}: {e}")
+
+        if not isinstance(sampler_options_dict, dict):
+            raise TypeError(
+                f"Sampler options file {sampler_options} must contain a dictionary, "
+                f"got {type(sampler_options_dict)}")
+
+        sampler_args = sampler_args | sampler_options_dict
+    return sampler_args
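
Because of the dict union operator, this helper requires Python 3.9 or newer; when sampler_options is None it returns the base arguments unchanged. A quick exercise of the helper, assuming tensorrt_llm is importable (otherwise the behaviour reduces to yaml.safe_load plus the union shown earlier):

    import tempfile

    from tensorrt_llm.bench.benchmark.utils.general import (
        update_sampler_args_with_extra_options)

    # Write a small options file and merge it over a base dictionary.
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        f.write("temperature: 0.8\nend_id: -1\n")
        path = f.name

    merged = update_sampler_args_with_extra_options({"end_id": 2, "n": 1}, path)
    assert merged == {"end_id": -1, "n": 1, "temperature": 0.8}

    # A file whose top level is not a mapping (e.g. "- 1") raises TypeError;
    # passing None for the path is a no-op.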
