 import sys
 import tempfile
 from pathlib import Path
-from typing import Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union

 import pytest
 import yaml
@@ -434,6 +434,148 @@ def test_qwen_e2e_cpprunner_large_new_tokens(model_name, model_path, llm_venv,
     ), f"Found zero length in sequence_lengths tensor: {seq_lengths}"


+# TODO: replace trtllm_bench_prolog with this class.
+class BenchRunner:
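+    """Drive trtllm-bench end to end: generate a small synthetic dataset,
+    optionally build an engine, then run the throughput benchmark,
+    optionally under mpirun + trtllm-llmapi-launch."""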
+
+    def __init__(self,
+                 llm_root: str,
+                 llm_venv: Any,
+                 model_subdir: str,
+                 model_name: str,
+                 streaming: bool,
+                 tp_size: int,
+                 use_pytorch_backend: bool = False,
+                 skip_engine_build: bool = False,
+                 quant: Optional[str] = None,
+                 extra_llm_api_options: Optional[str] = None,
+                 use_mpirun: bool = False):
+
+        llm_models = llm_models_root()
+        assert llm_models is not None
+        self.llm_root = llm_root
+        self.llm_venv = llm_venv
+        self.model_path = Path(llm_models, model_subdir).absolute()
+        self.model_name = model_name
+        self.quant = quant
+        self.streaming = streaming
+        self.skip_engine_build = skip_engine_build
+        self.use_pytorch_backend = use_pytorch_backend
+        self.use_mpirun = use_mpirun
+        self.tp_size = tp_size
+        self.quant_name = self.quant if self.quant is not None else "FP16"
+        self.extra_llm_api_options = extra_llm_api_options
+
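+        # Keep only the generated path: the TemporaryDirectory object is
+        # discarded (removing the directory), and prepare_dataset() recreates it.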
+        self.work_dir = Path(tempfile.TemporaryDirectory().name)
+
+        self.dataset_path = os.path.join(self.work_dir, "data.txt")
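+        # Wrap commands with mpirun + trtllm-llmapi-launch so trtllm-bench
+        # runs across tp_size MPI ranks.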
+        if self.use_mpirun:
+            self.mpirun_cmd = f"mpirun --allow-run-as-root -n {self.tp_size} trtllm-llmapi-launch"
+        else:
+            self.mpirun_cmd = ""
+        self.engine_path = None
+
+    def __call__(self):
+        self.prepare_dataset()
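+        # No engine is built for the PyTorch backend (or when
+        # skip_engine_build is set); only the TRT flow needs one.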
+        if not (self.skip_engine_build or self.use_pytorch_backend):
+            self.build_engine()
+        self.run_bench()
+
+    def prepare_dataset(self):
+        dataset_tool = Path(self.llm_root, "benchmarks", "cpp",
+                            "prepare_dataset.py")
+
+        # Generate a small dataset to run a test.
+        self.work_dir.mkdir(parents=True)
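+        # token-norm-dist with zero stdev yields fixed 128-token inputs and
+        # outputs for each of the 10 requests.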
+        command = [
+            f"{dataset_tool.resolve()}",
+            "--stdout",
+            "--tokenizer",
+            f"{self.model_path}",
+            "token-norm-dist",
+            "--input-mean",
+            "128",
+            "--output-mean",
+            "128",
+            "--input-stdev",
+            "0",
+            "--output-stdev",
+            "0",
+            "--num-requests",
+            "10",
+        ]
+        print(f"Running command: {' '.join(command)}")
+        dataset_output = self.llm_venv.run_cmd(
+            command,
+            caller=check_output,
+        )
+        # Grab the stdout and write it to a dataset file for passing to suite.
+        with open(self.dataset_path, "w") as dataset:
+            dataset.write(dataset_output)
+
+    def build_engine(self):
+        if self.skip_engine_build:
+            return
+
+        build_cmd = \
+            f"{self.mpirun_cmd} " \
+            f"trtllm-bench " \
+            f"--model {self.model_name} " \
+            f"--model_path {self.model_path} " \
+            f"--workspace {self.work_dir} " \
+            f"build --tp_size {self.tp_size}"
+
+        if self.quant is not None:
+            build_cmd = f"{build_cmd} --quantization {self.quant}"
+
+        build_cmd = f"{build_cmd} --dataset {self.dataset_path}"
+        build_output = check_output(build_cmd,
+                                    shell=True,
+                                    env=self.llm_venv._new_env)
+
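+        # Scan the build log from the end; the last "ENGINE SAVED:" line
+        # carries the path of the freshly built engine.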
+        for line in build_output.split("\n")[::-1]:
+            if line.startswith("ENGINE SAVED:"):
+                self.engine_path = Path(line.split(":")[1])
+                break
+
+    def run_bench(self):
+        streaming = "--streaming" if self.streaming else ""
+        benchmark_cmd = \
+            f"{self.mpirun_cmd} " \
+            f"trtllm-bench --model {self.model_name} --model_path {self.model_path} " \
+            f"throughput " \
+            f"--tp {self.tp_size} "
+        if self.engine_path:
+            benchmark_cmd += f"--engine_dir {self.engine_path} "
+        benchmark_cmd += f" --dataset {self.dataset_path} {streaming}"
+
+        if self.use_pytorch_backend:
+            benchmark_cmd += " --backend pytorch"
+
+        if self.extra_llm_api_options:
+            benchmark_cmd += f" --extra_llm_api_options {self.extra_llm_api_options}"
+        check_call(benchmark_cmd, shell=True, env=self.llm_venv._new_env)
+
+
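+# Launches two MPI ranks (tp_size=2) through trtllm-llmapi-launch for both
+# the TRT and PyTorch backends.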
+@pytest.mark.parametrize("model_name", ["meta-llama/Meta-Llama-3-8B-Instruct"],
+                         ids=["llama3-8b"])
+@pytest.mark.parametrize("model_subdir",
+                         ["llama-models-v3/llama-v3-8b-instruct-hf"],
+                         ids=["llama-v3"])
+@pytest.mark.parametrize("use_pytorch_backend", [True, False],
+                         ids=["pytorch_backend", "trt_backend"])
+def test_trtllm_bench_llmapi_launch(llm_root, llm_venv, model_name,
+                                    model_subdir, use_pytorch_backend):
+    runner = BenchRunner(llm_root=llm_root,
+                         llm_venv=llm_venv,
+                         model_name=model_name,
+                         model_subdir=model_subdir,
+                         streaming=False,
+                         use_pytorch_backend=use_pytorch_backend,
+                         use_mpirun=True,
+                         tp_size=2)
+    runner()
+
+
 def trtllm_bench_prolog(
     llm_root,
     llm_venv,
@@ -664,14 +806,14 @@ def test_trtllm_bench_mgmn(llm_root, llm_venv):
     model_name = "meta-llama/Llama-3.1-8B"
     llama_model_dir = Path(
         llm_models_root()) / "llama-3.1-model/Llama-3.1-8B-Instruct"
-    dataset_path = trtllm_bench_prolog(llm_root,
-                                       llm_venv,
-                                       engine_dir=None,
-                                       model_subdir=llama_model_dir,
-                                       model_name=model_name,
-                                       quant=None,
-                                       streaming=False,
-                                       skip_engine_build=True)
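+    # trtllm_bench_prolog returns a 3-tuple; only the dataset path is needed here.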
+    _, _, dataset_path = trtllm_bench_prolog(llm_root,
+                                             llm_venv,
+                                             engine_dir=None,
+                                             model_subdir=llama_model_dir,
+                                             model_name=model_name,
+                                             quant=None,
+                                             streaming=False,
+                                             skip_engine_build=True)

     benchmark_cmd = \
         f"mpirun -n 2 trtllm-llmapi-launch trtllm-bench --model {model_name} " \
@@ -685,7 +827,10 @@ def test_trtllm_bench_mgmn(llm_root, llm_venv):
                                          dir="./",
                                          delete=True,
                                          delete_on_close=True) as running_log:
-        check_call(benchmark_cmd, shell=True, stdout=running_log)
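+        # running_log is expected to collect the benchmark output consumed by
+        # _check_mem_usage below.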
+        check_call(benchmark_cmd,
+                   shell=True,
+                   running_log=running_log,
+                   env=llm_venv._new_env)
         _check_mem_usage(running_log, [30, 0, 0, 0])

