
Commit 778e52f

No public description
PiperOrigin-RevId: 639040818
1 parent bb19812 commit 778e52f

2 files changed: 109 additions & 61 deletions

2 files changed

+109
-61
lines changed

official/vision/serving/detection.py

Lines changed: 63 additions & 32 deletions
@@ -41,27 +41,37 @@ def _padded_size(self):
     return self._input_image_size
 
   def _build_model(self):
-
     nms_versions_supporting_dynamic_batch_size = {'batched', 'v2', 'v3'}
     nms_version = self.params.task.model.detection_generator.nms_version
-    if (self._batch_size is None and
-        nms_version not in nms_versions_supporting_dynamic_batch_size):
-      logging.info('nms_version is set to `batched` because `%s` '
-                   'does not support with dynamic batch size.', nms_version)
+    if (
+        self._batch_size is None
+        and nms_version not in nms_versions_supporting_dynamic_batch_size
+    ):
+      logging.info(
+          'nms_version is set to `batched` because `%s` '
+          'does not support with dynamic batch size.',
+          nms_version,
+      )
       self.params.task.model.detection_generator.nms_version = 'batched'
 
-    input_specs = tf_keras.layers.InputSpec(shape=[
-        self._batch_size, *self._padded_size, 3])
+    input_specs = tf_keras.layers.InputSpec(
+        shape=[self._batch_size, *self._padded_size, 3]
+    )
 
     if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
       model = factory.build_maskrcnn(
-          input_specs=input_specs, model_config=self.params.task.model)
+          input_specs=input_specs, model_config=self.params.task.model
+      )
     elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
       model = factory.build_retinanet(
-          input_specs=input_specs, model_config=self.params.task.model)
+          input_specs=input_specs, model_config=self.params.task.model
+      )
     else:
-      raise ValueError('Detection module not implemented for {} model.'.format(
-          type(self.params.task.model)))
+      raise ValueError(
+          'Detection module not implemented for {} model.'.format(
+              type(self.params.task.model)
+          )
+      )
 
     return model

@@ -73,7 +83,8 @@ def _build_anchor_boxes(self):
         max_level=model_params.max_level,
         num_scales=model_params.anchor.num_scales,
         aspect_ratios=model_params.anchor.aspect_ratios,
-        anchor_size=model_params.anchor.anchor_size)
+        anchor_size=model_params.anchor.anchor_size,
+    )
     return input_anchor(image_size=self._padded_size)
 
   def _build_inputs(self, image):
@@ -85,7 +96,8 @@ def _build_inputs(self, image):
 
     # Normalizes image with mean and std pixel values.
     image = preprocess_ops.normalize_image(
-        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
+        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB
+    )
 
     image, image_info = preprocess_ops.resize_and_crop_image(
         image,
@@ -131,20 +143,24 @@ def preprocess(
 
     Args:
      images: The images tensor.
+
    Returns:
      images: The images tensor cast to float.
      anchor_boxes: Dict mapping anchor levels to anchor boxes.
      image_info: Tensor containing the details of the image resizing.
-
    """
    model_params = self.params.task.model
    with tf.device('cpu:0'):
      # Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
-      images_spec = tf.TensorSpec(shape=self._padded_size + [3],
-                                  dtype=tf.float32)
+      images_spec = tf.TensorSpec(
+          shape=self._padded_size + [3], dtype=tf.float32
+      )
 
-      num_anchors = model_params.anchor.num_scales * len(
-          model_params.anchor.aspect_ratios) * 4
+      num_anchors = (
+          model_params.anchor.num_scales
+          * len(model_params.anchor.aspect_ratios)
+          * 4
+      )
      anchor_shapes = []
      for level in range(model_params.min_level, model_params.max_level + 1):
        anchor_level_spec = tf.TensorSpec(
@@ -153,7 +169,8 @@ def preprocess(
                math.ceil(self._padded_size[1] / 2**level),
                num_anchors,
            ],
-            dtype=tf.float32)
+            dtype=tf.float32,
+        )
        anchor_shapes.append((str(level), anchor_level_spec))
 
      image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
@@ -163,9 +180,14 @@ def preprocess(
          tf.map_fn(
              self._build_inputs,
              elems=images,
-              fn_output_signature=(images_spec, dict(anchor_shapes),
-                                   image_info_spec),
-              parallel_iterations=32))
+              fn_output_signature=(
+                  images_spec,
+                  dict(anchor_shapes),
+                  image_info_spec,
+              ),
+              parallel_iterations=32,
+          ),
+      )
 
    return images, anchor_boxes, image_info
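Note: the preprocess() hunks above are formatting-only, but the pattern they touch — tf.map_fn with an explicit fn_output_signature so per-image preprocessing traces under a dynamic batch size — is easy to miss in diff form. A minimal, self-contained sketch (not from this commit; the toy per-image function and shapes are assumptions):

import tensorflow as tf

# Toy stand-in for _build_inputs (assumption): returns a float image plus a
# small info tensor, i.e. a structure different from the uint8 input.
def per_image_fn(image):
  image = tf.cast(image, tf.float32) / 255.0  # toy normalization
  info = tf.cast(tf.shape(image)[:2], tf.float32)  # [height, width]
  return image, info

images = tf.zeros([4, 8, 8, 3], dtype=tf.uint8)

# fn_output_signature declares the per-example output structure; map_fn needs
# it whenever the output dtype/structure differs from elems.
image_batch, info_batch = tf.map_fn(
    per_image_fn,
    elems=images,
    fn_output_signature=(
        tf.TensorSpec(shape=[8, 8, 3], dtype=tf.float32),
        tf.TensorSpec(shape=[2], dtype=tf.float32),
    ),
    parallel_iterations=32,
)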

@@ -174,6 +196,7 @@ def serve(self, images: tf.Tensor):
 
    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]
+
    Returns:
      Tensor holding detection output logits.
    """
@@ -190,10 +213,15 @@ def serve(self, images: tf.Tensor):
      # [desired_height, desired_width], [y_scale, x_scale],
      # [y_offset, x_offset]]. When input_type is tflite, input image is
      # supposed to be preprocessed already.
-      image_info = tf.convert_to_tensor([[
-          self._input_image_size, self._input_image_size, [1.0, 1.0], [0, 0]
-      ]],
-                                        dtype=tf.float32)
+      image_info = tf.convert_to_tensor(
+          [[
+              self._input_image_size,
+              self._input_image_size,
+              [1.0, 1.0],
+              [0, 0],
+          ]],
+          dtype=tf.float32,
+      )
      input_image_shape = image_info[:, 1, :]
 
    # To overcome keras.Model extra limitation to save a model with layers that
@@ -226,20 +254,23 @@ def serve(self, images: tf.Tensor):
        # point outputs.
        if export_config.cast_num_detections_to_float:
          detections['num_detections'] = tf.cast(
-              detections['num_detections'], dtype=tf.float32)
+              detections['num_detections'], dtype=tf.float32
+          )
        if export_config.cast_detection_classes_to_float:
          detections['detection_classes'] = tf.cast(
-              detections['detection_classes'], dtype=tf.float32)
+              detections['detection_classes'], dtype=tf.float32
+          )
 
      final_outputs = {
          'detection_boxes': detections['detection_boxes'],
          'detection_scores': detections['detection_scores'],
          'detection_classes': detections['detection_classes'],
-          'num_detections': detections['num_detections']
+          'num_detections': detections['num_detections'],
      }
      if 'detection_outer_boxes' in detections:
-        final_outputs['detection_outer_boxes'] = (
-            detections['detection_outer_boxes'])
+        final_outputs['detection_outer_boxes'] = detections[
+            'detection_outer_boxes'
+        ]
    else:
      # For RetinaNet model, apply export_config.
      if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
@@ -250,7 +281,7 @@ def serve(self, images: tf.Tensor):
      detections = self._normalize_coordinates(detections, keys, image_info)
      final_outputs = {
          'decoded_boxes': detections['decoded_boxes'],
-          'decoded_box_scores': detections['decoded_box_scores']
+          'decoded_box_scores': detections['decoded_box_scores'],
      }
 
    if 'detection_masks' in detections.keys():
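The detection.py changes above are formatting and docstring tweaks; serve() still produces the same output keys. For orientation, a hedged sketch of driving the exported detection SavedModel (the path, input size, and signature input name are assumptions, not part of this commit):

import tensorflow as tf

loaded = tf.saved_model.load('/tmp/detection_export')  # assumed export path
serve_fn = loaded.signatures['serving_default']
images = tf.zeros([1, 640, 640, 3], dtype=tf.uint8)  # assumed input size

# The keyword must match the input_name chosen at export time (assumption).
outputs = serve_fn(inputs=images)

# Keys per the diff: detection_boxes, detection_scores, detection_classes,
# num_detections (cast to float when export_config requests it), plus
# detection_outer_boxes / detection_masks when the model produces them.
print(sorted(outputs.keys()))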

official/vision/serving/export_saved_model_lib.py

Lines changed: 46 additions & 29 deletions
@@ -15,7 +15,7 @@
 r"""Vision models export utility function for serving/inference."""
 
 import os
-from typing import Optional, List, Union, Text, Dict
+from typing import Dict, List, Optional, Union
 
 from absl import logging
 import tensorflow as tf, tf_keras
@@ -45,7 +45,7 @@ def export_inference_graph(
    log_model_flops_and_params: bool = False,
    checkpoint: Optional[tf.train.Checkpoint] = None,
    input_name: Optional[str] = None,
-    function_keys: Optional[Union[List[Text], Dict[Text, Text]]] = None,
+    function_keys: Optional[Union[List[str], Dict[str, str]]] = None,
    add_tpu_function_alias: Optional[bool] = False,
 ):
  """Exports inference graph for the model specified in the exp config.
@@ -83,57 +83,68 @@ def export_inference_graph(
 
  if export_checkpoint_subdir:
    output_checkpoint_directory = os.path.join(
-        export_dir, export_checkpoint_subdir)
+        export_dir, export_checkpoint_subdir
+    )
  else:
    output_checkpoint_directory = None
 
  if export_saved_model_subdir:
    output_saved_model_directory = os.path.join(
-        export_dir, export_saved_model_subdir)
+        export_dir, export_saved_model_subdir
+    )
  else:
    output_saved_model_directory = export_dir
 
  # TODO(arashwan): Offers a direct path to use ExportModule with Task objects.
  if not export_module:
-    if isinstance(params.task,
-                  configs.image_classification.ImageClassificationTask):
+    if isinstance(
+        params.task, configs.image_classification.ImageClassificationTask
+    ):
      export_module = image_classification.ClassificationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
-          input_name=input_name)
+          input_name=input_name,
+      )
    elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
-        params.task, configs.maskrcnn.MaskRCNNTask):
+        params.task, configs.maskrcnn.MaskRCNNTask
+    ):
      export_module = detection.DetectionModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
-          input_name=input_name)
-    elif isinstance(params.task,
-                    configs.semantic_segmentation.SemanticSegmentationTask):
+          input_name=input_name,
+      )
+    elif isinstance(
+        params.task, configs.semantic_segmentation.SemanticSegmentationTask
+    ):
      export_module = semantic_segmentation.SegmentationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
-          input_name=input_name)
-    elif isinstance(params.task,
-                    configs.video_classification.VideoClassificationTask):
+          input_name=input_name,
+      )
+    elif isinstance(
+        params.task, configs.video_classification.VideoClassificationTask
+    ):
      export_module = video_classification.VideoClassificationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
-          input_name=input_name)
+          input_name=input_name,
+      )
    else:
-      raise ValueError('Export module not implemented for {} task.'.format(
-          type(params.task)))
+      raise ValueError(
+          'Export module not implemented for {} task.'.format(type(params.task))
+      )
 
  if add_tpu_function_alias:
    if input_type == 'image_tensor':
@@ -160,7 +171,8 @@ def export_inference_graph(
      checkpoint=checkpoint,
      checkpoint_path=checkpoint_path,
      timestamped=False,
-      save_options=save_options)
+      save_options=save_options,
+  )
 
  if output_checkpoint_directory:
    ckpt = tf.train.Checkpoint(model=export_module.model)
@@ -171,16 +183,16 @@ def export_inference_graph(
  inputs_kwargs = None
  if isinstance(
      params.task,
-      (configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
+      (configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask),
+  ):
    # We need to create inputs_kwargs argument to specify the input shapes for
    # subclass model that overrides model.call to take multiple inputs,
    # e.g., RetinaNet model.
    inputs_kwargs = {
-        'images':
-            tf.TensorSpec([1] + input_image_size + [num_channels],
-                          tf.float32),
-        'image_shape':
-            tf.TensorSpec([1, 2], tf.float32)
+        'images': tf.TensorSpec(
+            [1] + input_image_size + [num_channels], tf.float32
+        ),
+        'image_shape': tf.TensorSpec([1, 2], tf.float32),
    }
    dummy_inputs = {
        k: tf.ones(v.shape.as_list(), tf.float32)
@@ -191,9 +203,14 @@ def export_inference_graph(
  else:
    logging.info(
        'Logging model flops and params not implemented for %s task.',
-        type(params.task))
+        type(params.task),
+    )
    return
-  train_utils.try_count_flops(export_module.model, inputs_kwargs,
-                              os.path.join(export_dir, 'model_flops.txt'))
-  train_utils.write_model_params(export_module.model,
-                                 os.path.join(export_dir, 'model_params.txt'))
+  train_utils.try_count_flops(
+      export_module.model,
+      inputs_kwargs,
+      os.path.join(export_dir, 'model_flops.txt'),
+  )
+  train_utils.write_model_params(
+      export_module.model, os.path.join(export_dir, 'model_params.txt')
+  )
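For context, a hedged invocation sketch of export_inference_graph as it reads after this commit; the experiment name, checkpoint path, and export directory are illustrative assumptions:

from official.core import exp_factory
from official.vision.serving import export_saved_model_lib

# Assumed registered experiment name for a RetinaNet config.
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')

export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[640, 640],
    params=params,
    checkpoint_path='/tmp/ckpt/best_ckpt',  # assumed
    export_dir='/tmp/detection_export',  # assumed
    log_model_flops_and_params=True,
    add_tpu_function_alias=False,
)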
