Skip to content

Commit 258da17

Browse files
committed
Add the function of concatenating to crops after detection.
1 parent bfb030d commit 258da17

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

deploy/py_infer/src/infer_args.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def get_args():
119119
"--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring."
120120
)
121121
parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.")
122+
parser.add_argument(
123+
"--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection."
124+
)
122125

123126
args = parser.parse_args()
124127
setup_logger(args)

deploy/py_infer/src/parallel/module/detection/det_post_node.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import cv2
12
import numpy as np
23

34
from ....data_process.utils import cv_utils
@@ -10,19 +11,44 @@ def __init__(self, args, msg_queue):
1011
super(DetPostNode, self).__init__(args, msg_queue)
1112
self.text_detector = None
1213
self.task_type = self.args.task_type
14+
self.is_concat = self.args.is_concat
1315

1416
def init_self_args(self):
1517
self.text_detector = TextDetector(self.args)
1618
self.text_detector.init(preprocess=False, model=False, postprocess=True)
1719
super().init_self_args()
1820

21+
def concat_crops(self, crops: list):
22+
"""
23+
Concatenates the list of cropped images horizontally after resizing them to have the same height.
24+
25+
Args:
26+
crops (list): A list of cropped images represented as numpy arrays.
27+
28+
Returns:
29+
numpy.ndarray: A horizontally concatenated image array.
30+
"""
31+
max_height = max(crop.shape[0] for crop in crops)
32+
resized_crops = []
33+
for crop in crops:
34+
h, w, c = crop.shape
35+
new_h = max_height
36+
new_w = int((w / h) * new_h)
37+
38+
resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
39+
resized_crops.append(resized_img)
40+
crops_concated = np.concatenate(resized_crops, axis=1)
41+
return crops_concated
42+
1943
def process(self, input_data):
2044
if input_data.skip:
2145
self.send_to_next_module(input_data)
2246
return
2347

2448
data = input_data.data
2549
boxes = self.text_detector.postprocess(data["pred"], data["shape_list"])
50+
if self.is_concat:
51+
boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
2652

2753
infer_res_list = []
2854
for box in boxes:
@@ -39,6 +65,8 @@ def process(self, input_data):
3965
for box in infer_res_list:
4066
sub_image = cv_utils.crop_box_from_image(image, np.array(box))
4167
sub_image_list.append(sub_image)
68+
if self.is_concat:
69+
sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
4270
input_data.sub_image_list = sub_image_list
4371

4472
input_data.data = None

deploy/py_infer/src/parallel/module/recognition/rec_post_node.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def process(self, input_data):
2828
else:
2929
texts = output["texts"]
3030
confs = output["confs"]
31-
for result, text, conf in zip(input_data.infer_result, texts, confs):
32-
result.append(text)
33-
result.append(conf)
31+
for results, text, conf in zip(input_data.infer_result, texts, confs):
32+
results.append(text)
33+
results.append(conf)
3434

3535
input_data.data = None
3636

0 commit comments

Comments
 (0)