Skip to content

Commit 2f9eb86

Browse files
authored
Port Faster R-CNN to Keras3 (#2458)
* Base structure for faster rcnn till rpn head * Add export for Faster RNN * add init file * initalize faster rcnn at model level * code fix fo roi align * Forward Pass code for Faster R-CNN * Faster RCNN Base code for Keras3(Draft-1) * Add local batch size * Add parameters to RPN Head * Make FPN more customizable with parameters and remove redudant code * Compute output shape for ROI Generator * Faster RCNN functional model with required import corrections * add clip boxes to forward pass * add prediction decoder and use "yxyx" as default internal bounding box format * feature pryamid correction * change ops.divide to ops.divide_no_nan * use from logits=True for Non Max supression * include box convertions for both rois and ground truth boxes * Change number of detections in decoder * Use categoricalcrossentropy to avoid -1 class error + added get_config for model saving * add basic test cases + linting * Add seed generator for sampling in RPN label encoding and ROI sampling layers * Use only spatial dimension for ops.nn.avg_pool + use ops.convert_to_tensor for list type + linting * Convert list to tensor using keras ops * Remove seed number from seed generator * Remove print and add proper comments * - Use stddev(0.01) as per paper across RPN and R-CNN Heads - Maxpool2d as per torch implementation in FPN - Add prediction decoder * - Fixes slice for multi backend - Slice for tensorflow can use [-1, -1, -1] for shape but not jax and torch, they should have explicit shape * - Add compute metrics method * Correct test cases and add missing args * Fix lint issues * - Fix lint and remove hard coded params to make it user friendly. * - Generate ROI's while decoding for predictions - Liniting + Test Cases * - Add faster rcnn to build method * - Test only for Keras3 * - Correct test case - Add copyright * - Correct the test cases decorator to skip for Keras2 * - Skip Legacy test cases - Fix ROI Align ops for torch backend * - Remove unecessary import in legacy code to fix lint * - Correct pytest complexity - Make bounding box test utils use 256,256 image size * - FIx Image Shape to 512, 512 default which will not break other test cases * - Lower image sizes for test cases - Add build method for fpn * - fix keras to 3.3.3 version * - Generate api - Correct YOLOv8 preset test case * - Lint fix * - Increase the atol, rtol for YOLOv8 Detector forward pass
1 parent 3d417ea commit 2f9eb86

File tree

27 files changed

+2085
-85
lines changed

27 files changed

+2085
-85
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ jobs:
9494
keras_cv/src/models/classification \
9595
keras_cv/src/models/object_detection/retinanet \
9696
keras_cv/src/models/object_detection/yolo_v8 \
97+
keras_cv/src/models/object_detection/faster_rcnn \
9798
keras_cv/src/models/object_detection_3d \
9899
keras_cv/src/models/segmentation \
99100
--durations 0

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,23 @@ then
3636
pip install -r requirements-tensorflow-cuda.txt --progress-bar off --timeout 1000
3737
pip install keras-nlp-nightly --no-deps
3838
pip install tensorflow-text~=2.16.0
39+
pip install keras~=3.3.3
3940

4041
elif [ "$KERAS_BACKEND" == "jax" ]
4142
then
4243
echo "JAX backend detected."
4344
pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000
4445
pip install keras-nlp-nightly --no-deps
4546
pip install tensorflow-text~=2.16.0
47+
pip install keras~=3.3.3
4648

4749
elif [ "$KERAS_BACKEND" == "torch" ]
4850
then
4951
echo "PyTorch backend detected."
5052
pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000
5153
pip install keras-nlp-nightly --no-deps
5254
pip install tensorflow-text~=2.16.0
55+
pip install keras~=3.3.3
5356
fi
5457

5558
pip install --no-deps -e "." --progress-bar off
@@ -67,6 +70,7 @@ then
6770
keras_cv/src/models/classification \
6871
keras_cv/src/models/object_detection/retinanet \
6972
keras_cv/src/models/object_detection/yolo_v8 \
73+
keras_cv/src/models/object_detection/faster_rcnn \
7074
keras_cv/src/models/object_detection_3d \
7175
keras_cv/src/models/segmentation \
7276
keras_cv/src/models/feature_extractor/clip \
@@ -82,6 +86,7 @@ else
8286
keras_cv/src/models/classification \
8387
keras_cv/src/models/object_detection/retinanet \
8488
keras_cv/src/models/object_detection/yolo_v8 \
89+
keras_cv/src/models/object_detection/faster_rcnn \
8590
keras_cv/src/models/object_detection_3d \
8691
keras_cv/src/models/segmentation \
8792
keras_cv/src/models/feature_extractor/clip \

keras_cv/api/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from keras_cv.api.models import classification
8+
from keras_cv.api.models import faster_rcnn
89
from keras_cv.api.models import feature_extractor
910
from keras_cv.api.models import object_detection
1011
from keras_cv.api.models import retinanet
@@ -205,6 +206,9 @@
205206
from keras_cv.src.models.classification.image_classifier import ImageClassifier
206207
from keras_cv.src.models.classification.video_classifier import VideoClassifier
207208
from keras_cv.src.models.feature_extractor.clip.clip_model import CLIP
209+
from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
210+
FasterRCNN,
211+
)
208212
from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet
209213
from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_backbone import (
210214
YOLOV8Backbone,
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""DO NOT EDIT.
2+
3+
This file was autogenerated. Do not edit it by hand,
4+
since your modifications would be overwritten.
5+
"""
6+
7+
from keras_cv.src.models.object_detection.faster_rcnn.feature_pyramid import (
8+
FeaturePyramid,
9+
)
10+
from keras_cv.src.models.object_detection.faster_rcnn.rcnn_head import RCNNHead
11+
from keras_cv.src.models.object_detection.faster_rcnn.rpn_head import RPNHead

keras_cv/api/models/object_detection/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
8+
FasterRCNN,
9+
)
710
from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet
811
from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_detector import (
912
YOLOV8Detector,

keras_cv/src/bounding_box/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def _clip_boxes(boxes, box_format, image_shape):
141141

142142
if isinstance(image_shape, list) or isinstance(image_shape, tuple):
143143
height, width, _ = image_shape
144-
max_length = [height, width, height, width]
144+
max_length = ops.stack([height, width, height, width], axis=-1)
145145
else:
146146
image_shape = ops.cast(image_shape, dtype=boxes.dtype)
147147
height = image_shape[0]

keras_cv/src/layers/object_detection/roi_align.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x):
6969
features,
7070
[batch_size * num_boxes, output_size * 2, output_size * 2, num_filters],
7171
)
72-
features = ops.nn.average_pool(
73-
features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID"
74-
)
72+
features = ops.nn.average_pool(features, (2, 2), (2, 2), "VALID")
7573
features = ops.reshape(
7674
features, [batch_size, num_boxes, output_size, output_size, num_filters]
7775
)
@@ -242,6 +240,11 @@ def multilevel_crop_and_resize(
242240
for i in range(len(feature_widths) - 1):
243241
level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i])
244242
batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]
243+
244+
level_dim_offsets = ops.convert_to_tensor(level_dim_offsets)
245+
feature_widths = ops.convert_to_tensor(feature_widths)
246+
feature_heights = ops.convert_to_tensor(feature_heights)
247+
245248
level_dim_offsets = (
246249
ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets
247250
)
@@ -259,7 +262,9 @@ def multilevel_crop_and_resize(
259262
# following the FPN paper to divide by 224.
260263
levels = ops.cast(
261264
ops.floor_divide(
262-
ops.log(ops.divide(areas_sqrt, 224.0)),
265+
ops.log(
266+
ops.divide_no_nan(areas_sqrt, ops.convert_to_tensor(224.0))
267+
),
263268
ops.log(2.0),
264269
)
265270
+ 4.0,
@@ -292,12 +297,18 @@ def multilevel_crop_and_resize(
292297
ops.concatenate(
293298
[
294299
ops.expand_dims(
295-
[[ops.cast(max_feature_height, "float32")]] / level_strides
300+
ops.convert_to_tensor(
301+
[[ops.cast(max_feature_height, "float32")]]
302+
)
303+
/ level_strides
296304
- 1,
297305
axis=-1,
298306
),
299307
ops.expand_dims(
300-
[[ops.cast(max_feature_width, "float32")]] / level_strides
308+
ops.convert_to_tensor(
309+
[[ops.cast(max_feature_width, "float32")]]
310+
)
311+
/ level_strides
301312
- 1,
302313
axis=-1,
303314
),
@@ -357,7 +368,7 @@ def multilevel_crop_and_resize(
357368
# TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get
358369
# similar performance.
359370
features_per_box = ops.reshape(
360-
ops.take(features_r2, indices),
371+
ops.take(features_r2, indices, axis=0),
361372
[
362373
batch_size,
363374
num_boxes,
@@ -378,7 +389,7 @@ def multilevel_crop_and_resize(
378389
# performance as this is mostly a duplicate of
379390
# https://github.com/tensorflow/models/blob/master/official/legacy/detection/ops/spatial_transform_ops.py#L324
380391
@keras.utils.register_keras_serializable(package="keras_cv")
381-
class _ROIAligner(keras.layers.Layer):
392+
class ROIAligner(keras.layers.Layer):
382393
"""Performs ROIAlign for the second stage processing."""
383394

384395
def __init__(
@@ -397,13 +408,11 @@ def __init__(
397408
sample_offset: A `float` in [0, 1] of the subpixel sample offset.
398409
**kwargs: Additional keyword arguments passed to Layer.
399410
"""
400-
# assert_tf_keras("keras_cv.layers._ROIAligner")
401-
self._config_dict = {
402-
"bounding_box_format": bounding_box_format,
403-
"crop_size": target_size,
404-
"sample_offset": sample_offset,
405-
}
406411
super().__init__(**kwargs)
412+
self.bounding_box_format = bounding_box_format
413+
self.target_size = target_size
414+
self.sample_offset = sample_offset
415+
self.built = True
407416

408417
def call(
409418
self,
@@ -427,16 +436,22 @@ def call(
427436
"""
428437
boxes = bounding_box.convert_format(
429438
boxes,
430-
source=self._config_dict["bounding_box_format"],
439+
source=self.bounding_box_format,
431440
target="yxyx",
432441
)
433442
roi_features = multilevel_crop_and_resize(
434443
features,
435444
boxes,
436-
output_size=self._config_dict["crop_size"],
437-
sample_offset=self._config_dict["sample_offset"],
445+
output_size=self.target_size,
446+
sample_offset=self.sample_offset,
438447
)
439448
return roi_features
440449

441450
def get_config(self):
442-
return self._config_dict
451+
config = super().get_config()
452+
config["bounding_box_format"] = self.bounding_box_format
453+
config["target_size"] = self.target_size
454+
config["sample_offset"] = self.sample_offset
455+
456+
def compute_output_shape(self, input_shape):
457+
return (None, None, self.target_size, self.target_size, 256)

keras_cv/src/layers/object_detection/roi_generator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class ROIGenerator(keras.layers.Layer):
6868
applying NMS in inference mode. When RPN is run on multiple
6969
feature maps / levels (as in FPN) this number is per
7070
feature map / level.
71+
nms_from_logits: bool. True means input score is logits, False means confidence.
7172
7273
Example:
7374
```python
@@ -90,6 +91,7 @@ def __init__(
9091
nms_score_threshold_test: float = 0.0,
9192
nms_iou_threshold_test: float = 0.7,
9293
post_nms_topk_test: int = 1000,
94+
nms_from_logits: bool = False,
9395
**kwargs,
9496
):
9597
super().__init__(**kwargs)
@@ -102,6 +104,7 @@ def __init__(
102104
self.nms_score_threshold_test = nms_score_threshold_test
103105
self.nms_iou_threshold_test = nms_iou_threshold_test
104106
self.post_nms_topk_test = post_nms_topk_test
107+
self.nms_from_logits = nms_from_logits
105108
self.built = True
106109

107110
def call(
@@ -158,7 +161,7 @@ def per_level_gen(boxes, scores):
158161
# TODO(tanzhenyu): consider supporting soft / batched nms for accl
159162
boxes = NonMaxSuppression(
160163
bounding_box_format=self.bounding_box_format,
161-
from_logits=False,
164+
from_logits=self.nms_from_logits,
162165
iou_threshold=nms_iou_threshold,
163166
confidence_threshold=nms_score_threshold,
164167
max_detections=level_post_nms_topk,
@@ -191,6 +194,9 @@ def per_level_gen(boxes, scores):
191194

192195
return rois, roi_scores
193196

197+
def compute_output_shape(self, input_shape):
198+
return (None, None, 4), (None, None, 1)
199+
194200
def get_config(self):
195201
config = {
196202
"bounding_box_format": self.bounding_box_format,

keras_cv/src/layers/object_detection/roi_sampler.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
@keras.utils.register_keras_serializable(package="keras_cv")
25-
class _ROISampler(keras.layers.Layer):
25+
class ROISampler(keras.layers.Layer):
2626
"""
2727
Sample ROIs for loss related calculation.
2828
@@ -41,9 +41,10 @@ class _ROISampler(keras.layers.Layer):
4141
if its range is [0, num_classes).
4242
4343
Args:
44-
bounding_box_format: The format of bounding boxes to generate. Refer
44+
roi_bounding_box_format: The format of roi bounding boxes. Refer
4545
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
4646
for more details on supported bounding box formats.
47+
gt_bounding_box_format: The format of ground truth bounding boxes.
4748
roi_matcher: a `BoxMatcher` object that matches proposals with ground
4849
truth boxes. The positive match must be 1 and negative match must be -1.
4950
Such assumption is not being validated here.
@@ -59,7 +60,8 @@ class _ROISampler(keras.layers.Layer):
5960

6061
def __init__(
6162
self,
62-
bounding_box_format: str,
63+
roi_bounding_box_format: str,
64+
gt_bounding_box_format: str,
6365
roi_matcher: box_matcher.BoxMatcher,
6466
positive_fraction: float = 0.25,
6567
background_class: int = 0,
@@ -68,12 +70,14 @@ def __init__(
6870
**kwargs,
6971
):
7072
super().__init__(**kwargs)
71-
self.bounding_box_format = bounding_box_format
73+
self.roi_bounding_box_format = roi_bounding_box_format
74+
self.gt_bounding_box_format = gt_bounding_box_format
7275
self.roi_matcher = roi_matcher
7376
self.positive_fraction = positive_fraction
7477
self.background_class = background_class
7578
self.num_sampled_rois = num_sampled_rois
7679
self.append_gt_boxes = append_gt_boxes
80+
self.seed_generator = keras.random.SeedGenerator()
7781
self.built = True
7882
# for debugging.
7983
self._positives = keras.metrics.Mean()
@@ -97,6 +101,12 @@ def call(
97101
sampled_gt_classes: [batch_size, num_sampled_rois, 1]
98102
sampled_class_weights: [batch_size, num_sampled_rois, 1]
99103
"""
104+
rois = bounding_box.convert_format(
105+
rois, source=self.roi_bounding_box_format, target="yxyx"
106+
)
107+
gt_boxes = bounding_box.convert_format(
108+
gt_boxes, source=self.gt_bounding_box_format, target="yxyx"
109+
)
100110
if self.append_gt_boxes:
101111
# num_rois += num_gt
102112
rois = ops.concatenate([rois, gt_boxes], axis=1)
@@ -110,12 +120,6 @@ def call(
110120
"num_rois must be less than `num_sampled_rois` "
111121
f"({self.num_sampled_rois}), got {num_rois}"
112122
)
113-
rois = bounding_box.convert_format(
114-
rois, source=self.bounding_box_format, target="yxyx"
115-
)
116-
gt_boxes = bounding_box.convert_format(
117-
gt_boxes, source=self.bounding_box_format, target="yxyx"
118-
)
119123
# [batch_size, num_rois, num_gt]
120124
similarity_mat = iou.compute_iou(
121125
rois, gt_boxes, bounding_box_format="yxyx", use_masking=True
@@ -171,6 +175,7 @@ def call(
171175
negative_matches,
172176
self.num_sampled_rois,
173177
self.positive_fraction,
178+
seed=self.seed_generator,
174179
)
175180
# [batch_size, num_sampled_rois] in the range of [0, num_rois)
176181
sampled_indicators, sampled_indices = ops.top_k(
@@ -204,16 +209,15 @@ def call(
204209
)
205210

206211
def get_config(self):
207-
config = {
208-
"bounding_box_format": self.bounding_box_format,
209-
"positive_fraction": self.positive_fraction,
210-
"background_class": self.background_class,
211-
"num_sampled_rois": self.num_sampled_rois,
212-
"append_gt_boxes": self.append_gt_boxes,
213-
"roi_matcher": self.roi_matcher.get_config(),
214-
}
215-
base_config = super().get_config()
216-
return dict(list(base_config.items()) + list(config.items()))
212+
config = super().get_config()
213+
config["roi_bounding_box_format"] = self.roi_bounding_box_format
214+
config["gt_bounding_box_format"] = self.gt_bounding_box_format
215+
config["positive_fraction"] = self.positive_fraction
216+
config["background_class"] = self.background_class
217+
config["num_sampled_rois"] = self.num_sampled_rois
218+
config["append_gt_boxes"] = self.append_gt_boxes
219+
config["roi_matcher"] = self.roi_matcher.get_config()
220+
return config
217221

218222
@classmethod
219223
def from_config(cls, config, custom_objects=None):

0 commit comments

Comments
 (0)