Skip to content

Commit d9ed09c

Browse files
(Ian Stenbit)LukeWood
andcommitted
Batch fill rectangle (keras-team#65)
* temporary push benchmark file * temporary push benchmark file * refactor fill_utils.py * refactor * refactor * batch fill_rectangle * small refactor * docstring fix * Fix issue after rebase Co-authored-by: Luke Wood <[email protected]>
1 parent 7b5fc14 commit d9ed09c

File tree

3 files changed

+88
-60
lines changed

3 files changed

+88
-60
lines changed

keras_cv/layers/preprocessing/cut_mix.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,13 @@ def _cutmix(self, images, labels):
120120
lambda_sample = 1.0 - bbox_area / (image_height * image_width)
121121
lambda_sample = tf.cast(lambda_sample, dtype=tf.float32)
122122

123-
images = tf.map_fn(
124-
lambda x: fill_utils.fill_rectangle(*x),
125-
(
126-
images,
127-
random_center_width,
128-
random_center_height,
129-
cut_width // 2,
130-
cut_height // 2,
131-
tf.gather(images, permutation_order),
132-
),
133-
fn_output_signature=tf.TensorSpec.from_tensor(images[0]),
123+
images = fill_utils.fill_rectangle(
124+
images,
125+
random_center_width,
126+
random_center_height,
127+
cut_width,
128+
cut_height,
129+
tf.gather(images, permutation_order),
134130
)
135131

136132
return images, labels, lambda_sample, permutation_order

keras_cv/layers/preprocessing/random_cutout.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class RandomCutout(layers.Layer):
3131
in the range `[20% of image height, 30% of image height]`.
3232
`height_factor=(32, 64)` results in a height picked in the range
3333
[32, 64]. `height_factor=0.2` results in a height of [0%, 20%] of image
34-
height, and `height_factor=32` results in a height of 32.
34+
height, and `height_factor=32` results in a height between [0, 32].
3535
width_factor: One of:
3636
- a positive float representing a fraction of image width
3737
- an integer representing an absolute width
@@ -40,7 +40,7 @@ class RandomCutout(layers.Layer):
4040
in the range `[20% of image width, 30% of image width]`.
4141
`width_factor=(32, 64)` results in a width picked in the range
4242
[32, 64]. `width_factor=0.2` results in a width of [0%, 20%] of image
43-
width, and `width_factor=32` results in a width of 32.
43+
width, and `width_factor=32` results in a width between [0, 32].
4444
fill_mode: Pixels inside the patches are filled according to the given
4545
mode (one of `{"constant", "gaussian_noise"}`).
4646
- *constant*: Pixels are filled with the same constant value.
@@ -151,19 +151,13 @@ def _random_cutout(self, inputs):
151151
center_x, center_y = self._compute_rectangle_position(inputs)
152152
rectangle_height, rectangle_width = self._compute_rectangle_size(inputs)
153153
rectangle_fill = self._compute_rectangle_fill(inputs)
154-
half_height = tf.cast(tf.math.ceil(rectangle_height / 2), tf.int32)
155-
half_width = tf.cast(tf.math.ceil(rectangle_width / 2), tf.int32)
156-
inputs = tf.map_fn(
157-
lambda x: fill_utils.fill_rectangle(*x),
158-
(
159-
inputs,
160-
center_y,
161-
center_x,
162-
half_width,
163-
half_height,
164-
rectangle_fill,
165-
),
166-
fn_output_signature=tf.TensorSpec.from_tensor(inputs[0]),
154+
inputs = fill_utils.fill_rectangle(
155+
inputs,
156+
center_x,
157+
center_y,
158+
rectangle_width,
159+
rectangle_height,
160+
rectangle_fill,
167161
)
168162
return inputs
169163

@@ -177,14 +171,14 @@ def _compute_rectangle_position(self, inputs):
177171
center_x = tf.random.uniform(
178172
shape=[batch_size],
179173
minval=0,
180-
maxval=image_height,
174+
maxval=image_width,
181175
dtype=tf.int32,
182176
seed=self.seed,
183177
)
184178
center_y = tf.random.uniform(
185179
shape=[batch_size],
186180
minval=0,
187-
maxval=image_width,
181+
maxval=image_height,
188182
dtype=tf.int32,
189183
seed=self.seed,
190184
)

keras_cv/utils/fill_utils.py

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,42 +10,80 @@
1010
# distributed under the License is distributed on an "AS IS" BASIS,
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
13-
# limitations under the License.grep -q Copyright $i
13+
# limitations under the License.
1414
import tensorflow as tf
1515

16+
from keras_cv.utils import bbox
1617

17-
def fill_rectangle(
18-
image, center_width, center_height, half_width, half_height, fill=None
19-
):
20-
"""Fill a rectangle in a given image using the value provided in replace.
18+
19+
def rectangle_masks(mask_shape, corners):
20+
"""Computes positional masks of rectangles in images
2121
2222
Args:
23-
image: the starting image to fill the rectangle on.
24-
center_width: the X center of the rectangle to fill
25-
center_height: the Y center of the rectangle to fill
26-
half_width: 1/2 the width of the resulting rectangle
27-
half_height: 1/2 the height of the resulting rectangle
28-
fill: A tensor with same shape as image. Values at rectangle
29-
position are used as fill.
23+
mask_shape: shape of the masks as [batch_size, height, width].
24+
corners: rectangle coordinates in corners format.
25+
3026
Returns:
31-
image: the modified image with the chosen rectangle filled.
27+
boolean masks with True at rectangle positions.
3228
"""
33-
image_shape = tf.shape(image)
34-
image_height = image_shape[0]
35-
image_width = image_shape[1]
36-
37-
lower_pad = tf.maximum(0, center_height - half_height)
38-
upper_pad = tf.maximum(0, image_height - center_height - half_height)
39-
left_pad = tf.maximum(0, center_width - half_width)
40-
right_pad = tf.maximum(0, image_width - center_width - half_width)
41-
42-
shape = [
43-
image_height - (lower_pad + upper_pad),
44-
image_width - (left_pad + right_pad),
45-
]
46-
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
47-
mask = tf.pad(tf.zeros(shape, dtype=image.dtype), padding_dims, constant_values=1)
48-
mask = tf.expand_dims(mask, -1)
49-
50-
image = tf.where(tf.equal(mask, 0), fill, image)
51-
return image
29+
# add broadcasting axes
30+
corners = corners[..., tf.newaxis, tf.newaxis]
31+
32+
# split coordinates
33+
x0 = corners[:, 0]
34+
y0 = corners[:, 1]
35+
x1 = corners[:, 2]
36+
y1 = corners[:, 3]
37+
38+
# repeat height and width
39+
batch_size, height, width = mask_shape
40+
x0_rep = tf.repeat(x0, height, axis=1)
41+
y0_rep = tf.repeat(y0, width, axis=2)
42+
x1_rep = tf.repeat(x1, height, axis=1)
43+
y1_rep = tf.repeat(y1, width, axis=2)
44+
45+
# range grid
46+
range_row = tf.range(0, height, dtype=corners.dtype)
47+
range_col = tf.range(0, width, dtype=corners.dtype)
48+
range_row = tf.repeat(range_row[tf.newaxis, :, tf.newaxis], batch_size, 0)
49+
range_col = tf.repeat(range_col[tf.newaxis, tf.newaxis, :], batch_size, 0)
50+
51+
# boolean masks
52+
mask_x0 = tf.less_equal(x0_rep, range_col)
53+
mask_y0 = tf.less_equal(y0_rep, range_row)
54+
mask_x1 = tf.less(range_col, x1_rep)
55+
mask_y1 = tf.less(range_row, y1_rep)
56+
57+
masks = mask_x0 & mask_y0 & mask_x1 & mask_y1
58+
59+
return masks
60+
61+
62+
def fill_rectangle(images, center_x, center_y, width, height, fill):
63+
"""Fill rectangles with fill value into images.
64+
65+
Args:
66+
images: Tensor of images to fill rectangles into.
67+
center_x: Tensor of positions of the rectangle centers on the x-axis.
68+
center_y: Tensor f positions of the rectangle centers on the y-axis.
69+
width: Tensor of widths of the rectangles
70+
height: Tensor of heights of the rectangles
71+
fill: Tensor with same shape as images to get rectangle fill from.
72+
Returns:
73+
images with filled rectangles.
74+
"""
75+
images_shape = tf.shape(images)
76+
batch_size = images_shape[0]
77+
images_height = images_shape[1]
78+
images_width = images_shape[2]
79+
80+
xywh = tf.stack([center_x, center_y, width, height], axis=1)
81+
xywh = tf.cast(xywh, tf.float32)
82+
corners = bbox.xywh_to_corners(xywh)
83+
84+
masks_shape = (batch_size, images_height, images_width)
85+
is_patch_mask = rectangle_masks(masks_shape, corners)
86+
is_patch_mask = tf.expand_dims(is_patch_mask, -1)
87+
88+
images = tf.where(tf.equal(is_patch_mask, True), fill, images)
89+
return images

0 commit comments

Comments
 (0)