@@ -175,11 +175,17 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size):
         "inputs": input_tensors,
         "enabled_precisions": {precision_to_dtype(precision)},
         "truncate_long_and_double": params.get("truncate", False),
+        "use_python_runtime": params.get("use_python_runtime", False),
     }

     if precision == "int8":
         compile_settings.update({"calib": params.get("calibration_cache")})

+    if params.get("enable_cuda_graph", False):
+        logging.warning(
+            "TorchScript backend doesn't support CUDA Graphs. `--enable_cuda_graph` will be ignored."
+        )
+
     start_compile = timeit.default_timer()
     model = torchtrt.compile(model, ir="ts", **compile_settings)
     end_compile = timeit.default_timer()
@@ -217,19 +223,34 @@ def run_hf_dynamo(model, input_tensors, params, precision, batch_size):
         inputs=input_tensors,
         enabled_precisions={precision_to_dtype(precision)},
         truncate_double=params.get("truncate", False),
+        use_python_runtime=params.get("use_python_runtime", False),
     )
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
-    record_llm_perf(
-        trt_model,
-        "Dynamo",
-        input_tensors,
-        precision,
-        osl,
-        batch_size,
-        iters,
-        compile_time_s,
-    )
+
+    if params.get("enable_cuda_graph", False):
+        with torchtrt.runtime.enable_cudagraphs(trt_model) as cudagraphs_module:
+            record_llm_perf(
+                cudagraphs_module,
+                "Dynamo",
+                input_tensors,
+                precision,
+                osl,
+                batch_size,
+                iters,
+                compile_time_s,
+            )
+    else:
+        record_llm_perf(
+            trt_model,
+            "Dynamo",
+            input_tensors,
+            precision,
+            osl,
+            batch_size,
+            iters,
+            compile_time_s,
+        )


 @run_with_try_except
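Note on the pattern in this hunk: `torchtrt.runtime.enable_cudagraphs` is a context manager that wraps a compiled TensorRT module and yields a CUDA-graph-backed wrapper; the first call inside the block captures the graph and later calls replay it, removing per-iteration kernel-launch overhead. A minimal standalone sketch of the same pattern (the toy model, shapes, and input are illustrative, not taken from this script):

    import torch
    import torch_tensorrt as torchtrt

    # Illustrative module and input; any dynamo-compiled module is wrapped the same way.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval().cuda()
    x = torch.randn(8, 64).cuda()
    trt_model = torchtrt.compile(model, ir="dynamo", inputs=[x])

    # Inside the block, calls to cudagraphs_module are captured once and then replayed.
    with torchtrt.runtime.enable_cudagraphs(trt_model) as cudagraphs_module:
        for _ in range(10):
            out = cudagraphs_module(x)

Benchmarking the wrapper rather than `trt_model` directly is what lets the new flag measure CUDA-graph replay rather than regular kernel launches.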
@@ -262,14 +283,27 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         ),
         cache_built_engines=params.get("cache_built_engines", False),
         reuse_cached_engines=params.get("reuse_cached_engines", False),
+        use_python_runtime=params.get("use_python_runtime", False),
     )
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
     iters = params.get("iterations", 20)

-    record_perf(
-        model, "Dynamo", input_tensors, precision, iters, batch_size, compile_time_s
-    )
+    if params.get("enable_cuda_graph", False):
+        with torchtrt.runtime.enable_cudagraphs(model) as cudagraphs_module:
+            record_perf(
+                cudagraphs_module,
+                "Dynamo",
+                input_tensors,
+                precision,
+                iters,
+                batch_size,
+                compile_time_s,
+            )
+    else:
+        record_perf(
+            model, "Dynamo", input_tensors, precision, iters, batch_size, compile_time_s
+        )


 @run_with_try_except
@@ -292,6 +326,7 @@ def run_torch_compile(model, input_tensors, params, precision, batch_size):
         "enabled_precisions": {precision_to_dtype(precision)},
         "truncate": params.get("truncate", False),
         "min_block_size": params.get("min_block_size", 1),
+        "use_python_runtime": params.get("use_python_runtime", False),
     }
     start_compile = timeit.default_timer()
     model = torch.compile(model, backend="tensorrt", dynamic=None, options=compile_spec)
@@ -300,15 +335,27 @@ def run_torch_compile(model, input_tensors, params, precision, batch_size):
     compile_time_s = end_compile - start_compile
     iters = params.get("iterations", 20)

-    record_perf(
-        model,
-        "torch_compile",
-        input_tensors,
-        precision,
-        iters,
-        batch_size,
-        compile_time_s,
-    )
+    if params.get("enable_cuda_graph", False):
+        with torchtrt.runtime.enable_cudagraphs(model) as cudagraphs_module:
+            record_perf(
+                cudagraphs_module,
+                "torch_compile",
+                input_tensors,
+                precision,
+                iters,
+                batch_size,
+                compile_time_s,
+            )
+    else:
+        record_perf(
+            model,
+            "torch_compile",
+            input_tensors,
+            precision,
+            iters,
+            batch_size,
+            compile_time_s,
+        )


 @run_with_try_except
@@ -320,9 +367,13 @@ def run_hf_inductor(model, input_tensors, params, precision, batch_size):
     # Mark dynamic shapes for input sequence
     input_seq = input_tensors[0]
     torch._dynamo.mark_dynamic(input_seq, 1, min=1, max=osl)
+    mode = "max-autotune"
+    if params.get("enable_cuda_graph", False):
+        mode = "reduce-overhead"
+
     start_compile = timeit.default_timer()
     # Compile the model
-    model = torch.compile(model, backend="inductor", dynamic=None, mode="max-autotune")
+    model = torch.compile(model, backend="inductor", dynamic=None, mode=mode)
     model(input_seq)
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
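For the inductor baselines there is no explicit wrapper to apply: `torch.compile` handles CUDA graphs internally when the `reduce-overhead` mode is selected, so the flag is translated into the compile mode instead, as in this hunk and the next. A minimal sketch of the equivalent standalone call (toy model and input are illustrative):

    import torch

    model = torch.nn.Linear(64, 64).eval().cuda()
    x = torch.randn(8, 64).cuda()

    # "reduce-overhead" is the torch.compile mode documented to use CUDA graphs
    # to cut per-call launch overhead; without the flag the script keeps "max-autotune".
    compiled = torch.compile(model, backend="inductor", dynamic=None, mode="reduce-overhead")
    compiled(x)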
@@ -356,15 +407,25 @@ def run_inductor(model, input_tensors, params, precision, batch_size):
     if params["is_text_llm"]:
         return run_hf_inductor(model, input_tensors, params, precision, batch_size)

+    mode = "max-autotune"
+    if params.get("enable_cuda_graph", False):
+        mode = "reduce-overhead"
+
     start_compile = timeit.default_timer()
-    model = torch.compile(model, backend="inductor", dynamic=None, mode="max-autotune")
+    model = torch.compile(model, backend="inductor", dynamic=None, mode=mode)
     model(*input_tensors)
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
     iters = params.get("iterations", 20)

     record_perf(
-        model, "inductor", input_tensors, precision, iters, batch_size, compile_time_s
+        model,
+        "inductor",
+        input_tensors,
+        precision,
+        iters,
+        batch_size,
+        compile_time_s,
     )

@@ -587,6 +648,16 @@ def run(
         action="store_true",
         help="Boolean flag to determine if the user provided model is a TRT engine or not",
     )
+    arg_parser.add_argument(
+        "--use_python_runtime",
+        action="store_true",
+        help="Use the Python runtime instead of the default C++ runtime",
+    )
+    arg_parser.add_argument(
+        "--enable_cuda_graph",
+        action="store_true",
+        help="Enable CUDA Graphs at runtime (disabled by default)",
+    )
     arg_parser.add_argument(
         "--report",
         type=str,
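With both flags registered, a benchmark run exercising them might look like the following; the script name and the `--backends` option are assumptions based on the function names in this diff, while `--use_python_runtime`, `--enable_cuda_graph`, and `--report` are the options visible above:

    python perf_run.py --backends dynamo --use_python_runtime --enable_cuda_graph --report report.csv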