
Commit 1253881
Author: DEKHTIARJonathan
AutoTuner Mechanism Added: performance autotuning, adding a concrete-function execution path.
1 parent 98efdc5, commit 1253881

File tree: 5 files changed, +264 additions, -80 deletions

tftrt/examples/benchmark_args.py
Lines changed: 16 additions & 0 deletions

@@ -272,6 +272,13 @@ def __init__(self):
             "performance analysis."
         )
 
+        self._add_bool_argument(
+            name="tf_profile_verbose",
+            default=False,
+            required=False,
+            help="If set to True, will add extra information to the TF Profile."
+        )
+
         self._add_bool_argument(
             name="debug",
             default=False,
@@ -378,6 +385,15 @@ def _validate_args(self, args):
                 "calibration."
             )
 
+        if (
+            args.tf_profile_verbose and
+            args.tf_profile_export_path is None
+        ):
+            raise ValueError(
+                "`--tf_profile_verbose` can only be set if "
+                "`--tf_profile_export_path=/path/to/export` is defined."
+            )
+
     def _post_process_args(self, args):
         return args
 
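The new flag is only meaningful when a profile is actually written to disk, which is exactly what the `_validate_args` check above enforces. As a standalone illustration (plain `argparse` here, not the repository's argument-parser helper; only the two flag names come from the diff, everything else is hypothetical), the same dependent-flag rule looks like this:

# Illustrative sketch only: reproduces the dependent-flag check from
# _validate_args() with plain argparse. Only the two flag names come from
# the diff above; the parser and sample values are made up for the example.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tf_profile_export_path", type=str, default=None)
parser.add_argument("--tf_profile_verbose", action="store_true", default=False)

args = parser.parse_args(
    ["--tf_profile_verbose", "--tf_profile_export_path", "/tmp/tf_profile"]
)

# Verbose profiling only adds detail to an exported profile, hence the coupling.
if args.tf_profile_verbose and args.tf_profile_export_path is None:
    raise ValueError(
        "`--tf_profile_verbose` can only be set if "
        "`--tf_profile_export_path=/path/to/export` is defined."
    )

print("Arguments are consistent:", vars(args))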

tftrt/examples/benchmark_autotuner.py (new file)
Lines changed: 111 additions & 0 deletions

#!/usr/bin/env python
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# -*- coding: utf-8 -*-

import time
import numpy as np
import tensorflow as tf

from benchmark_utils import force_gpu_resync


class _TFFunctionAutoTuner(object):

    def __init__(self, funcs, calls_per_func, skip_n_first):
        if not isinstance(funcs, (tuple, list)):
            raise ValueError("Argument `funcs` must be a list or tuple.")

        if any([not callable(fn) for fn in funcs]):
            raise ValueError("One of the functions passed is not callable.")

        self._fns = funcs
        self._calls_per_func = calls_per_func
        self._skip_n_first = skip_n_first

        self._call_counter = 0
        self._timings = [[] for _ in range(len(self._fns))]

        self._best_fn = self._autotune

    def _autotune(self, *arg, **kwargs):
        fn_id = self._call_counter // self._calls_per_func
        try:
            start_t = time.time()
            output = self._fns[fn_id](*arg, **kwargs)
            self._timings[fn_id].append(time.time() - start_t)
        except IndexError:
            print("\n[DEBUG] AutoTuning is over... Collecting timing statistics:")
            perf_data = []
            for idx, fn_stat in enumerate(self._timings):
                perf_data.append(np.mean(fn_stat[self._skip_n_first:]))
                print(f"\t- [DEBUG] Function ID: {idx} - "
                      f"Name: {self._fns[idx].__name__:40s} - "
                      f"Average Exec Time: {perf_data[-1]}")

            best_fn_id = np.argmin(perf_data)
            print(f"[DEBUG] Selecting function ID: {best_fn_id}. "
                  f"Setting exec path to: `{self._fns[best_fn_id].__name__}`\n")

            self._best_fn = self._fns[best_fn_id]
            return self._best_fn(*arg, **kwargs)

        self._call_counter += 1
        return output

    def __call__(self, *arg, **kwargs):
        return self._best_fn(*arg, **kwargs)


def _force_using_concrete_function(func):
    # `context` needs to be a closure of type list or dict for persistence
    context = []

    def _wrapper(*args, **kwargs):
        try:
            return context[0](*args, **kwargs)
        except IndexError:
            print("[INFO] Building the concrete function")
            context.append(func.get_concrete_function(*args, **kwargs))
            return context[0](*args, **kwargs)

    return _wrapper


def auto_tf_func_tuner(
    calls_per_func=45,
    skip_n_first=30,
    use_xla=False,
    use_synthetic_data=False
):

    def wrapper(func):

        @force_gpu_resync
        def eager_function(*args, **kwargs):
            return func(*args, **kwargs)

        @force_gpu_resync
        @tf.function(jit_compile=use_xla)
        def tf_function(*args, **kwargs):
            return func(*args, **kwargs)

        @force_gpu_resync
        @_force_using_concrete_function
        @tf.function(jit_compile=use_xla)
        def tf_concrete_function(*args, **kwargs):
            return func(*args, **kwargs)

        eager_function.__name__ = f"{func.__name__}_eager"
        tf_function.__name__ = f"{func.__name__}_tf_function"
        tf_concrete_function.__name__ = f"{func.__name__}_tf_concrete_function"

        funcs2autotune = [eager_function, tf_function]
        if use_synthetic_data:
            print("[INFO] Allowing direct concrete_function call with "
                  "synthetic data loader.")
            funcs2autotune.append(tf_concrete_function)

        return _TFFunctionAutoTuner(
            funcs2autotune,
            calls_per_func=calls_per_func,
            skip_n_first=skip_n_first
        )

    return wrapper
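In short, `auto_tf_func_tuner` returns a decorator. The wrapped inference function is exposed in up to three variants: eager, `tf.function(jit_compile=use_xla)`, and a cached concrete function (the last one only with the synthetic data loader, since `_force_using_concrete_function` builds the concrete function once from the first call's signature and then reuses it). `_TFFunctionAutoTuner` calls each variant `calls_per_func` times, drops the first `skip_n_first` timings as warmup, and permanently rebinds `self._best_fn` to the fastest variant, so the measurement happens on live benchmark traffic rather than in a separate calibration phase. The toy sketch below (hypothetical functions, no TensorFlow, not part of the commit) mirrors just that selection loop:

# Toy illustration of the autotuner's selection logic; the candidate
# functions are made up and do not appear anywhere in the commit.
import time
import numpy as np


def slow_identity(x):
    time.sleep(0.002)  # deliberately slower candidate
    return x


def fast_identity(x):
    return x


candidates = [slow_identity, fast_identity]
calls_per_func, skip_n_first = 10, 3
timings = [[] for _ in candidates]

# Exercise each candidate `calls_per_func` times, recording wall-clock time.
for fn_id, fn in enumerate(candidates):
    for _ in range(calls_per_func):
        start_t = time.time()
        fn(42)
        timings[fn_id].append(time.time() - start_t)

# Drop the first timings as warmup before averaging, like the autotuner does.
mean_times = [np.mean(t[skip_n_first:]) for t in timings]
best_fn = candidates[int(np.argmin(mean_times))]
print(f"Selected execution path: {best_fn.__name__}")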

tftrt/examples/benchmark_runner.py
Lines changed: 97 additions & 53 deletions

@@ -15,8 +15,9 @@
 
 from distutils.util import strtobool
 
+from benchmark_autotuner import auto_tf_func_tuner
+
 from benchmark_utils import DataAggregator
-from benchmark_utils import force_gpu_resync
 from benchmark_utils import print_dict
 from benchmark_utils import timed_section
 
@@ -383,16 +384,14 @@ def execute_benchmark(self):
         dataset, bypass_data_to_eval = self.get_dataset_batches()
 
         if self._args.use_synthetic_data:
-            old_ds = dataset
             try:
-                dataset = SyntheticDataset(old_ds, device="/gpu:0")
+                dataset = SyntheticDataset(dataset, device="/gpu:0")
                 self._debug_print(
                     "Model dataset has been replaced by a synthetic data "
                     "loader to minimize data loading jitter."
                 )
 
             except Exception as e:
-                dataset = old_ds
                 print(
                     f"[ERROR] Impossible to transform the dataset into a "
                     f"synthetic dataset. Performance numbers will be "
@@ -401,8 +400,10 @@ def execute_benchmark(self):
         else:
             dataset = ensure_dataset_on_gpu(dataset, device="GPU:0")
 
-        @force_gpu_resync
-        @tf.function(jit_compile=self._args.use_xla)
+        @auto_tf_func_tuner(
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
+        )
         def infer_batch(x):
             if isinstance(x, (tuple, list)):
                 model_out = graph_func(*x)
@@ -439,72 +440,112 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
             )
 
         if self._args.tf_profile_export_path:
-            profiling_ctx = tf.profiler.experimental.Profile(
-                self._args.tf_profile_export_path
-            )
+            def start_profiling():
+                if self._args.tf_profile_verbose:
+                    profiler_opts = tf.profiler.experimental.ProfilerOptions(
+                        # Adjust TraceMe levels:
+                        # - 1: critical
+                        # - 2: info [default]
+                        # - 3: verbose
+                        host_tracer_level=2,
+                        # Enables python function call tracing
+                        # - 0: disable [default]
+                        # - 1: enable
+                        python_tracer_level=1,
+                        # Adjust device (TPU/GPU) tracer level:
+                        # - 0: disable
+                        # - 1: enable [default]
+                        device_tracer_level=1,
+                        # start profiling after 30 sec.
+                        # - Skip tf.function building
+                        # - Skip autotuning
+                        delay_ms=30000
+                    )
+                    print("[INFO] Using verbose TF Profiler.")
+                else:
+                    profiler_opts = None
+
+                profiling_ctx = tf.profiler.experimental.start(
+                    self._args.tf_profile_export_path,
+                    options=profiler_opts
+                )
+
+            stop_profiling = tf.profiler.experimental.stop
+
             tracing_ctx = tf.profiler.experimental.Trace
+
         else:
+            start_profiling = stop_profiling = lambda *a, **kw: None
             profiling_ctx = contextlib.nullcontext()
             tracing_ctx = lambda *a, **kw: contextlib.nullcontext()
 
         step_idx = 0
         ds_iter = iter(dataset)
 
-        dequeue_batch_fn = get_dequeue_batch_fn(ds_iter)
+        dequeue_batch_fn = get_dequeue_batch_fn(
+            ds_iter,
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
+        )
+
         force_data_on_gpu_fn = get_force_data_on_gpu_fn(
             device="/gpu:0",
-            use_xla=self._args.use_xla
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
         )
 
-        with profiling_ctx:
-
-            while True:
-
-                step_idx += 1
+        while True:
 
-                if (self._args.num_iterations is not None and
-                        step_idx > self._args.num_iterations):
-                    break
-
-                with tracing_ctx('Inference Step', step_num=step_idx, _r=1):
+            step_idx += 1
 
-                    with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
-                        try:
-                            start_time = time.time()
-                            data_batch = dequeue_batch_fn()
-                            dequeue_times.append(time.time() - start_time)
-                        except (StopIteration, OutOfRangeError):
-                            print("[Exiting] Reached end of dataset ...")
-                            break
+            if step_idx == self._args.num_warmup_iterations - 5:
+                start_profiling()
 
-                    with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
-                        x, y = self.preprocess_model_inputs(data_batch)
+            if (
+                self._args.num_iterations is not None and
+                step_idx > self._args.num_iterations
+            ):
+                break
 
-                    with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
-                        start_time = time.time()
-                        x = force_data_on_gpu_fn(x)
-                        memcopy_times.append(time.time() - start_time)
+            with tracing_ctx('', step_num=step_idx, _r=1):
 
-                    with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
+                with tracing_ctx('Input Dequeueing'):
+                    try:
                         start_time = time.time()
-                        y_pred = infer_batch(x)
-                        iter_times.append(time.time() - start_time)
-
-                    if not self._args.debug_performance:
-                        log_step(
-                            step_idx,
-                            display_every=self._args.display_every,
-                            iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
-                            memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
-                            dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
-                        )
-                    else:
-                        print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
-                        print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
-                        print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
+                        data_batch = dequeue_batch_fn()
+                        dequeue_times.append(time.time() - start_time)
+                    except (StopIteration, OutOfRangeError):
+                        print("[Exiting] Reached end of dataset ...")
+                        break
+
+                with tracing_ctx('Inputs Preprocessing'):
+                    x, y = self.preprocess_model_inputs(data_batch)
+
+                with tracing_ctx('Inputs MemcpyHtoD'):
+                    start_time = time.time()
+                    x = force_data_on_gpu_fn(x)
+                    memcopy_times.append(time.time() - start_time)
+
+                with tracing_ctx('GPU Inference'):
+                    start_time = time.time()
+                    y_pred = infer_batch(x)
+                    iter_times.append(time.time() - start_time)
+
+                if not self._args.debug_performance:
+                    log_step(
+                        step_idx,
+                        display_every=self._args.display_every,
+                        iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
+                        memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
+                        dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
                    )
+                else:
+                    print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
+                    print(f"{'Data MemCopyHtoD Time':18s}: {memcopy_times[-1]:08.4f}s")
+                    print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
 
-            if not self._args.use_synthetic_data:
-                data_aggregator.aggregate_data(y_pred, y)
+                if not self._args.use_synthetic_data:
+                    data_aggregator.aggregate_data(y_pred, y)
 
         if (
             not self._args.debug_performance and
@@ -518,6 +559,9 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
                 dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
             )
 
+        if step_idx >= 100:
+            stop_profiling()
+
         with timed_section("Metric Computation"):
 
             metrics = dict()