Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 98efdc5

Browse files
author
DEKHTIARJonathan
committed
Adding Inference Template
1 parent bd74e79 commit 98efdc5

File tree

2 files changed

+165
-0
lines changed

2 files changed

+165
-0
lines changed

tftrt/examples/template/infer.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# =============================================================================
17+
18+
import os
19+
import sys
20+
21+
import numpy as np
22+
23+
import tensorflow as tf
24+
25+
# Allow import of top level python files
26+
import inspect
27+
28+
# Resolve the directory holding this file, then put its parent at the front of
# the module search path so the project imports that follow can be resolved
# from there.
currentdir = os.path.abspath(inspect.getfile(inspect.currentframe()))
currentdir = os.path.dirname(currentdir)
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
33+
34+
from benchmark_args import BaseCommandLineAPI
35+
from benchmark_runner import BaseBenchmarkRunner
36+
37+
38+
class CommandLineAPI(BaseCommandLineAPI):
    """Command-line interface of this inference template.

    Subclass of `BaseCommandLineAPI`. Model-specific flags are meant to be
    registered in `__init__`, after the base class has been initialised.
    """

    def __init__(self):
        super().__init__()

        # Template example of a model-specific flag; uncomment and adapt:
        #
        # self._parser.add_argument(
        #     "--sequence_length",
        #     type=int,
        #     default=128,
        #     help="Input data sequence length."
        # )
49+
50+
51+
class BenchmarkRunner(BaseBenchmarkRunner):
    """Template benchmark runner; fill in the methods below for your model.

    Subclass of `BaseBenchmarkRunner`. How and when the base class invokes
    these methods is defined in `benchmark_runner` and is not visible in this
    file.
    """

    def get_dataset_batches(self):
        """Build the dataset that feeds the benchmark.

        Returns:
            A 2-tuple `(dataset, bypass_data_to_eval)`:
            - dataset: a TF Dataset object
            - bypass_data_to_eval: any object type that will be passed
              unmodified to `evaluate_model()`. If not necessary: `None`

        Raises:
            NotImplementedError: as long as this template has not been filled
                in with a real dataset.

        Note: script arguments can be accessed using `self._args.attr`
        """

        # seq = generate_a_sequence(self._args.sequence_length)

        # - https://www.tensorflow.org/guide/data_performance
        # - https://www.tensorflow.org/guide/data
        # dataset = tf.data....
        dataset = None  # TEMPLATE: replace with a real `tf.data.Dataset`.

        if dataset is None:
            # The unfilled template used to return an undefined name and
            # crash with a NameError; fail with an actionable message instead.
            raise NotImplementedError(
                "get_dataset_batches() is a template placeholder: build a "
                "`tf.data.Dataset` and assign it to `dataset`."
            )

        return dataset, None

    def preprocess_model_inputs(self, data_batch):
        """Prepare the `data_batch` generated from the dataset.

        Args:
            data_batch: presumably one element yielded by the dataset built in
                `get_dataset_batches()` (confirm against the base runner).

        Returns:
            x: input of the model
            y: data to be used for model evaluation (`None` in this template)

        Note: script arguments can be accessed using `self._args.attr`
        """

        x = data_batch
        return x, None

    def postprocess_model_outputs(self, predictions, expected):
        """Post-process, if needed, the predictions and expected tensors.

        At the minimum, this function transforms all TF Tensors into numpy
        arrays. Most models will not need to modify this function.

        Args:
            predictions: model outputs; must expose `.numpy()`.
            expected: evaluation data exposing `.numpy()`, or `None`.

        Returns:
            A 2-tuple `(predictions, expected)` converted with `.numpy()`;
            `expected` stays `None` when no evaluation data was given.

        Note: script arguments can be accessed using `self._args.attr`
        """

        # NOTE : DO NOT MODIFY FOR NOW => We do not measure accuracy right now

        # `preprocess_model_inputs()` in this template returns `None` as the
        # evaluation data, so a missing `expected` must not crash on
        # `None.numpy()`.
        expected_np = expected.numpy() if expected is not None else None

        return predictions.numpy(), expected_np

    def evaluate_model(self, predictions, expected, bypass_data_to_eval):
        """Evaluate result predictions for entire dataset.

        This computes overall accuracy, mAP, etc.

        Returns:
            A 2-tuple `(metric_value, metric_units)`, where `metric_units` is
            a string naming the metric. This template performs no evaluation
            and returns `None` as the metric value.

        Note: script arguments can be accessed using `self._args.attr`
        """

        # NOTE: PLEASE ONLY MODIFY THE NAME OF THE ACCURACY METRIC

        return None, "<ACCURACY METRIC NAME>"
112+
113+
114+
if __name__ == "__main__":
    # Script entry point: parse the command-line flags, then build the runner
    # from them and launch the benchmark.
    args = CommandLineAPI().parse_args()
    BenchmarkRunner(args).execute_benchmark()
121+
122+
################ TO BE REMOVED - HIGH LEVEL CONCEPT #####################
#
# The snippet below is illustrative pseudo-code describing what a benchmark
# run does conceptually. It is kept as a comment on purpose: `load_my_model`,
# `get_dataset_batches` and `preprocess_model_inputs` do not exist as
# module-level functions in this file, so executing it raises a NameError.
#
#     import time
#
#     model_fn = load_my_model("/path/to/my/model")
#
#     dataset, _ = get_dataset_batches()  # dataset, None
#
#     ds_iter = iter(dataset)
#
#     for idx, batch in enumerate(ds_iter):
#         print(f"Batch ID: {idx + 1} - Data: {batch}")
#
#         # - IF NEEDED - This transforms the inputs - Most cases it doesn't
#         # do anything; let's say transforming a list into a dict() or
#         # reverse. It returns (x, y): model input and evaluation data.
#         x, _ = preprocess_model_inputs(batch)
#
#         start_t = time.time()
#         outputs = model_fn(x)
#         print(f"Inference Time: {(time.time() - start_t)*1000:.1f}ms")  # 0.001
#
#         ## post my outputs to "measure accuracy"
#         ## note: we skip that
#
#     print("Success")
#     sys.exit(0)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/bin/bash

# Show the NVIDIA GPU / driver status before running anything.
nvidia-smi

# Trace every command as it is executed.
set -x

# Absolute path of the directory that contains this script.
BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"

# NOTE: the `/path/to/...` values and `<BATCH_SIZE>` are template placeholders
# and must be replaced before running. The placeholder is quoted so that bash
# does not interpret `<` and `>` as redirections.
#
# `--total_max_samples`, `--use_synthetic_data` and `--num_iterations` are set
# because we will be running synthetic benchmarks.
#
# Any extra arguments given to this script are forwarded to infer.py; "$@"
# keeps arguments containing spaces intact.
python "${BASE_DIR}/infer.py" \
    --data_dir=/path/to/script \
    --input_saved_model_dir=/path/to/saved_model \
    --batch_size="<BATCH_SIZE>" \
    --output_tensors_name="logits,probs" \
    --total_max_samples=1000 \
    --use_synthetic_data \
    --num_iterations=1000 \
    "$@"

0 commit comments

Comments
 (0)