diff --git a/configs/rec/crnn/README.md b/configs/rec/crnn/README.md
index 950ade180..52c3a0298 100644
--- a/configs/rec/crnn/README.md
+++ b/configs/rec/crnn/README.md
@@ -39,19 +39,21 @@ According to our experiments, the training (following the steps in [Model Traini
-| **Model** | **Context** | **Backbone** | **Train Dataset** | **Model Params** | **Batch size per card** | **Graph train 8P (s/epoch)** | **Graph train 8P (ms/step)** | **Graph train 8P (FPS)** | **Avg Eval Accuracy** | **Recipe** | **Download** |
-| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
-| CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) |
-| CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) |
+| **Model** | **Context** | **Backbone** | **Train Dataset** | **Num Classes** | **Model Params** | **Batch size per card** | **Graph train (s/epoch)** | **Graph train (ms/step)** | **Graph train (FPS)** | **Avg Eval Accuracy** | **Recipe** | **Download** |
+| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
+| CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 37 | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) |
+| CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 37 | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) |
+| CRNN | D910x4-MS2.0-G | ResNet34_vd | MJ+ST | 96 | 24.51 M | 64 | 4292.18 | 76.08 | 3364.72 | 83.50% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34_server.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_server-e0d66c0c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_server-e0d66c0c-55748731.mindir) |
- Detailed accuracy results for each benchmark dataset (IC03, IC13, IC15, IIIT, SVT, SVTP, CUTE):
- | **Model** | **Backbone** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **Average** |
- | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
- | CRNN | VGG7 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% |
- | CRNN | ResNet34_vd | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% |
+ | **Model** | **Backbone** | **Num Classes** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **Average** |
+ | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
+ | CRNN | VGG7 | 37 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% |
+ | CRNN | ResNet34_vd | 37 | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% |
+ | CRNN | ResNet34_vd | 96 | 94.65% | 94.70% | 94.28% | 93.20% | 72.50% | 63.94% | 87.63% | 86.09% | 74.42% | 73.61% | 83.50% |
### Inference Perf.
@@ -70,7 +72,7 @@ The inference performance is tested on Mindspore Lite, please take a look at [Mi
**Notes:**
- Context: Training context denoted as {device}x{pieces}-{MS mode}, where the MindSpore mode can be G (graph mode) or F (pynative mode with ms function). For example, D910x8-MS1.8-G denotes training on 8 Ascend 910 NPUs in graph mode with MindSpore 1.8.
- To reproduce the result on other contexts, please ensure the global batch size is the same.
-- The characters supported by model are lowercase English characters from a to z and numbers from 0 to 9. More explanation on dictionary, please refer to [4. Character Dictionary](#4-character-dictionary).
+- The number of classes of the model is determined by the dictionary used for training. The default dictionary contains the lowercase English letters a to z and the digits 0 to 9. For more details on the dictionary, please refer to [4. Character Dictionary](#4-character-dictionary).
- The models are trained from scratch without any pre-training. For more dataset details of training and evaluation, please refer to [Dataset Download & Dataset Usage](#312-dataset-download) section.
- The input shape of the exported MindIR for both CRNN_VGG7 and CRNN_ResNet34_vd is (1, 3, 32, 100).
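For reference, the mapping from dictionary to `num_classes` can be sketched as follows (a hypothetical helper, not a mindocr API; it assumes one character per line in the dictionary file, with `+1` for the CTC blank as noted by the `num_chars_in_dict+1` comment in the configs):

```python
def num_classes_from_dict(dict_path: str, use_space_char: bool = False) -> int:
    """Count the characters in a dictionary file and add 1 for the CTC blank."""
    with open(dict_path, encoding="utf-8") as f:
        chars = [line.rstrip("\n") for line in f if line.rstrip("\n")]
    if use_space_char and " " not in chars:
        chars.append(" ")  # optionally also recognize the space character
    return len(chars) + 1  # +1 for the CTC blank token

# Default 36-character dict (a-z, 0-9) -> 37; en_dict.txt with use_space_char -> 96.
```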
diff --git a/configs/rec/crnn/README_CN.md b/configs/rec/crnn/README_CN.md
index 77f508807..be16c29f5 100644
--- a/configs/rec/crnn/README_CN.md
+++ b/configs/rec/crnn/README_CN.md
@@ -39,20 +39,22 @@ Table Format:
-| **模型** | **环境配置** | **骨干网络** | **训练集** | **参数量** | **单卡批量** | **图模式8卡训练 (s/epoch)** | **图模式8卡训练 (ms/step)** | **图模式8卡训练 (FPS)** | **平均评估精度** | **配置文件** | **模型权重下载** |
-| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: |
-| CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) |
-| CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) |
+| **模型** | **环境配置** | **骨干网络** | **训练集** | **类别数** | **参数量** | **单卡批量** | **图模式训练 (s/epoch)** | **图模式训练 (ms/step)** | **图模式训练 (FPS)** | **平均评估精度** | **配置文件** | **模型权重下载** |
+| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :------: |
+| CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 37 | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) |
+| CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 37 | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) |
+| CRNN | D910x4-MS2.0-G | ResNet34_vd | MJ+ST | 96 | 24.51 M | 64 | 4292.18 | 76.08 | 3364.72 | 83.50% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34_server.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_server-e0d66c0c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_server-e0d66c0c-55748731.mindir) |
- 在各个基准数据集(IC03,IC13,IC15,IIIT,SVT,SVTP,CUTE)上的准确率:
- | **模型** | **骨干网络** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **平均准确率** |
- | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
- | CRNN | VGG7 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% |
- | CRNN | ResNet34_vd | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% |
+ | **模型** | **骨干网络** | **类别数** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **平均准确率** |
+ | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
+ | CRNN | VGG7 | 37 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% |
+ | CRNN | ResNet34_vd | 37 | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% |
+ | CRNN | ResNet34_vd | 96 | 94.65% | 94.70% | 94.28% | 93.20% | 72.50% | 63.94% | 87.63% | 86.09% | 74.42% | 73.61% | 83.50% |
@@ -72,7 +74,7 @@ Table Format:
**注意:**
- 环境配置:训练的环境配置表示为 {处理器}x{处理器数量}-{MS模式},其中 MindSpore 模式可以是 G-graph 模式或 F-pynative 模式。例如,D910x8-MS1.8-G 表示基于 MindSpore 1.8 版本,使用图模式在8张昇腾910 NPU上进行训练。
- 如需在其他环境配置重现训练结果,请确保全局批量大小与原配置文件保持一致。
-- 模型所能识别的字符都是默认的设置,即所有英文小写字母a至z及数字0至9,详细请看[4. 字符词典](#4-字符词典)
+- 模型的类别数由训练所用的字典决定。默认字典包含小写英文字母a至z及数字0至9,详情请参考[4. 字符词典](#4-字符词典)。
- 模型都是从头开始训练的,无需任何预训练。关于训练和测试数据集的详细介绍,请参考[数据集下载及使用](#312-数据集下载)章节。
- CRNN_VGG7和CRNN_ResNet34_vd的MindIR导出时的输入Shape均为(1, 3, 32, 100)。
diff --git a/configs/rec/crnn/crnn_icdar15.yaml b/configs/rec/crnn/crnn_icdar15.yaml
index 18139f435..358a1f31b 100644
--- a/configs/rec/crnn/crnn_icdar15.yaml
+++ b/configs/rec/crnn/crnn_icdar15.yaml
@@ -96,16 +96,12 @@ train:
character_dict_path: *character_dict_path
use_space_char: *use_space_char
lower: True
- - RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize.
+ - RecResizeNormImg:
image_shape: [32, 100] # H, W
infer_mode: *infer_mode
character_dict_path: *character_dict_path
padding: False # aspect ratio will be preserved if true.
- - NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec.
- bgr_to_rgb: True
- is_hwc: True
- mean : [127.0, 127.0, 127.0]
- std : [127.0, 127.0, 127.0]
+ norm_before_pad: False
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns: ['image', 'text_seq'] #, 'length'] #'img_path']
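This and the following YAML hunks replace `RecResizeImg` + `NormalizeImage` with the new `RecResizeNormImg`, whose `norm_before_pad` flag controls whether padding happens before or after normalization. A minimal numpy sketch of the two orderings (illustrative only; the real logic is in the `rec_transforms.py` hunk below):

```python
import numpy as np

mean, std = 127.0, 127.0
resized = np.full((32, 60, 3), 255.0)  # resized content narrower than the 100-px target

# norm_before_pad=True: normalize, then pad -> the padded region stays exactly 0
a = np.zeros((32, 100, 3))
a[:, :60] = (resized - mean) / std

# norm_before_pad=False: pad with zeros, then normalize -> the padded region becomes ~ -1
b = np.zeros((32, 100, 3))
b[:, :60] = resized
b = (b - mean) / std
```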
diff --git a/configs/rec/crnn/crnn_resnet34.yaml b/configs/rec/crnn/crnn_resnet34.yaml
index 1325467c1..bc37c7ea5 100644
--- a/configs/rec/crnn/crnn_resnet34.yaml
+++ b/configs/rec/crnn/crnn_resnet34.yaml
@@ -80,23 +80,19 @@ train:
shuffle: True
transform_pipeline:
- DecodeImage:
- img_mode: BGR
+ img_mode: RGB
to_float32: False
- RecCTCLabelEncode:
max_text_len: *max_text_len
character_dict_path: *character_dict_path
use_space_char: *use_space_char
lower: True
- - RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize.
+ - RecResizeNormImg:
image_shape: [32, 100] # H, W
infer_mode: *infer_mode
character_dict_path: *character_dict_path
padding: False # aspect ratio will be preserved if true.
- - NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec.
- bgr_to_rgb: True
- is_hwc: True
- mean : [127.0, 127.0, 127.0]
- std : [127.0, 127.0, 127.0]
+ norm_before_pad: False
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns: ['image', 'text_seq'] #, 'length'] #'img_path']
diff --git a/configs/rec/crnn/crnn_resnet34_ch.yaml b/configs/rec/crnn/crnn_resnet34_ch.yaml
index bf954cbae..6465a7cd4 100644
--- a/configs/rec/crnn/crnn_resnet34_ch.yaml
+++ b/configs/rec/crnn/crnn_resnet34_ch.yaml
@@ -84,7 +84,7 @@ train:
max_text_len: *max_text_len
transform_pipeline:
- DecodeImage:
- img_mode: BGR
+ img_mode: RGB
to_float32: False
- RecCTCLabelEncode:
max_text_len: *max_text_len
@@ -94,16 +94,12 @@ train:
- Rotate90IfVertical:
threshold: 2.0
direction: counterclockwise
- - RecResizeImg:
- image_shape: [32, 320]
+ - RecResizeNormImg:
+ image_shape: [32, 320] # H, W
infer_mode: *infer_mode
character_dict_path: *character_dict_path
- padding: True
- - NormalizeImage:
- bgr_to_rgb: True
- is_hwc: True
- mean: [127.0, 127.0, 127.0]
- std: [127.0, 127.0, 127.0]
+ padding: True # aspect ratio will be preserved if true.
+ norm_before_pad: False
- ToCHWImage:
output_columns: ["image", "text_seq"]
net_input_column_index: [0]
diff --git a/configs/rec/crnn/crnn_resnet34_server.yaml b/configs/rec/crnn/crnn_resnet34_server.yaml
new file mode 100644
index 000000000..7518981ea
--- /dev/null
+++ b/configs/rec/crnn/crnn_resnet34_server.yaml
@@ -0,0 +1,150 @@
+system:
+ mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
+ distribute: True
+ amp_level: 'O3'
+ seed: 42
+ log_interval: 100
+ val_while_train: True
+ drop_overflow_update: False
+
+common:
+ character_dict_path: &character_dict_path mindocr/utils/dict/en_dict.txt
+ num_classes: &num_classes 96 # num_chars_in_dict+1, TODO: retrieve it from dict or check correctness
+ max_text_len: &max_text_len 24
+ infer_mode: &infer_mode False
+ use_space_char: &use_space_char True
+ lower: &lower False
+ batch_size: &batch_size 64
+
+model:
+ type: rec
+ transform: null
+ backbone:
+ name: rec_resnet34
+ pretrained: False
+ neck:
+ name: RNNEncoder
+ hidden_size: 256
+ head:
+ name: CTCHead
+ weight_init: crnn_customised
+ bias_init: crnn_customised
+ out_channels: *num_classes
+
+postprocess:
+ name: RecCTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+metric:
+ name: RecMetric
+ main_indicator: acc
+ character_dict_path: *character_dict_path
+ ignore_space: True
+ print_flag: False
+
+loss:
+ name: CTCLoss
+ pred_seq_len: 25 # TODO: retrieve from the network output shape.
+ max_label_len: *max_text_len # this value should be smaller than pred_seq_len
+ batch_size: *batch_size
+
+scheduler:
+ scheduler: warmup_cosine_decay
+ min_lr: 0.000001
+ lr: 0.001
+ num_epochs: 30
+ warmup_epochs: 2
+ decay_epochs: 28
+
+optimizer:
+ opt: adamw
+ filter_bias_and_bn: True
+ momentum: 0.95
+ weight_decay: 0.0001
+ nesterov: False
+
+loss_scaler:
+ type: dynamic
+ loss_scale: 512
+ scale_factor: 2.0
+ scale_window: 1000
+
+train:
+ ckpt_save_dir: './crnn_resnet34_server'
+ pred_cast_fp32: False # let CTCLoss cast internally
+ ema: True # added
+ dataset_sink_mode: False
+ dataset:
+ type: LMDBDataset
+ dataset_root: /path/to/data_lmdb_release/
+ data_dir: training/
+ # label_file: # not required when using LMDBDataset
+ sample_ratio: 1.0
+ shuffle: True
+ transform_pipeline:
+ - DecodeImage:
+ img_mode: RGB
+ to_float32: False
+ - RecCTCLabelEncode:
+ max_text_len: *max_text_len
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ lower: *lower
+ - RecResizeNormImg:
+ image_shape: [32, 100] # H, W
+ infer_mode: *infer_mode
+ character_dict_path: *character_dict_path
+ padding: True # aspect ratio will be preserved if true.
+ norm_before_pad: True
+ - ToCHWImage:
+ # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
+ output_columns: ['image', 'text_seq'] #, 'length'] #'img_path']
+ net_input_column_index: [0] # input indices for network forward func in output_columns
+ label_column_index: [1] # input indices marked as label
+ #keys_for_loss: 4 # num labels for loss func
+
+ loader:
+ shuffle: True
+ batch_size: *batch_size
+ drop_remainder: True
+ max_rowsize: 12
+ num_workers: 8
+
+eval:
+ ckpt_load_path: ./crnn_resnet34_server/best.ckpt
+ dataset_sink_mode: False
+ dataset:
+ type: LMDBDataset
+ dataset_root: /path/to/data_lmdb_release/
+ data_dir: validation/
+ # label_file: # not required when using LMDBDataset
+ sample_ratio: 1.0
+ shuffle: False
+ transform_pipeline:
+ - DecodeImage:
+ img_mode: RGB
+ to_float32: False
+ - RecCTCLabelEncode:
+ max_text_len: *max_text_len
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ lower: *lower
+ - RecResizeNormImg:
+ image_shape: [32, 100] # H, W
+ infer_mode: *infer_mode
+ character_dict_path: *character_dict_path
+ padding: True # aspect ratio will be preserved if true.
+ norm_before_pad: True
+ - ToCHWImage:
+ # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
+ output_columns: ['image', 'text_padded', 'text_length'] # TODO: return text string padded to a fixed length, and a scalar to indicate the length
+ net_input_column_index: [0] # input indices for network forward func in output_columns
+ label_column_index: [1, 2] # input indices marked as label
+
+ loader:
+ shuffle: False # TODO: tbc
+ batch_size: 64
+ drop_remainder: False
+ max_rowsize: 12
+ num_workers: 8
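Assuming the standard mindocr entry point, the new config can presumably be launched with `mpirun -n 4 python tools/train.py --config configs/rec/crnn/crnn_resnet34_server.yaml` (it sets `distribute: True`, and the benchmark row above used 4 cards); `dataset_root` must first point to the real LMDB data.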
diff --git a/configs/rec/crnn/crnn_vgg7.yaml b/configs/rec/crnn/crnn_vgg7.yaml
index 5647e3421..a5a750463 100644
--- a/configs/rec/crnn/crnn_vgg7.yaml
+++ b/configs/rec/crnn/crnn_vgg7.yaml
@@ -81,23 +81,19 @@ train:
shuffle: True
transform_pipeline:
- DecodeImage:
- img_mode: BGR
+ img_mode: RGB
to_float32: False
- RecCTCLabelEncode:
max_text_len: *max_text_len
character_dict_path: *character_dict_path
use_space_char: *use_space_char
lower: True
- - RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize.
+ - RecResizeNormImg:
image_shape: [32, 100] # H, W
infer_mode: *infer_mode
character_dict_path: *character_dict_path
padding: False # aspect ratio will be preserved if true.
- - NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec.
- bgr_to_rgb: True
- is_hwc: True
- mean : [127.0, 127.0, 127.0]
- std : [127.0, 127.0, 127.0]
+ norm_before_pad: False
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns: ['image', 'text_seq'] #, 'length'] #'img_path']
diff --git a/configs/rec/rare/rare_resnet34.yaml b/configs/rec/rare/rare_resnet34.yaml
index d910b7c21..85609b5ea 100644
--- a/configs/rec/rare/rare_resnet34.yaml
+++ b/configs/rec/rare/rare_resnet34.yaml
@@ -83,16 +83,12 @@ train:
character_dict_path: *character_dict_path
use_space_char: *use_space_char
lower: True
- - RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize.
+ - RecResizeNormImg:
image_shape: [32, 100] # H, W
infer_mode: *infer_mode
character_dict_path: *character_dict_path
padding: False # aspect ratio will be preserved if true.
- - NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec.
- bgr_to_rgb: True
- is_hwc: True
- mean: [127.0, 127.0, 127.0]
- std: [127.0, 127.0, 127.0]
+ norm_before_pad: False
- ToCHWImage:
output_columns: ["image", "text_seq"]
net_input_column_index: [0, 1] # input indices for network forward func in output_columns
diff --git a/configs/rec/rare/rare_resnet34_ch.yaml b/configs/rec/rare/rare_resnet34_ch.yaml
index 624c70b3d..5bd8fd705 100644
--- a/configs/rec/rare/rare_resnet34_ch.yaml
+++ b/configs/rec/rare/rare_resnet34_ch.yaml
@@ -93,16 +93,12 @@ train:
- Rotate90IfVertical:
threshold: 2.0
direction: counterclockwise
- - RecResizeImg:
- image_shape: [32, 320]
+ - RecResizeNormImg:
+ image_shape: [32, 320] # H, W
infer_mode: *infer_mode
character_dict_path: *character_dict_path
- padding: True
- - NormalizeImage:
- bgr_to_rgb: True
- is_hwc: True
- mean: [127.0, 127.0, 127.0]
- std: [127.0, 127.0, 127.0]
+ padding: True # aspect ratio will be preserved if true.
+ norm_before_pad: False
- ToCHWImage:
output_columns: ["image", "text_seq"]
net_input_column_index: [0, 1]
diff --git a/mindocr/data/transforms/rec_transforms.py b/mindocr/data/transforms/rec_transforms.py
index 875234bb6..e97f3e824 100644
--- a/mindocr/data/transforms/rec_transforms.py
+++ b/mindocr/data/transforms/rec_transforms.py
@@ -13,6 +13,7 @@
"RecAttnLabelEncode",
"RecMasterLabelEncode",
"RecResizeImg",
+ "RecResizeNormImg",
"RecResizeNormForInfer",
"SVTRRecResizeImg",
"Rotate90IfVertical",
@@ -356,13 +357,25 @@ def str2idx(
# TODO: reorganize the code for different resize transformation in rec task
-def resize_norm_img(img, image_shape, padding=True, interpolation=cv2.INTER_LINEAR):
+def resize_norm_img(
+ img,
+ image_shape,
+ padding=True,
+ norm_before_pad=False,
+ mean=[127.0, 127.0, 127.0],
+ std=[127.0, 127.0, 127.0],
+ interpolation=cv2.INTER_LINEAR,
+):
"""
resize image
Args:
img: shape (H, W, C)
image_shape: image shape after resize, in (H, W)
- padding: if Ture, resize while preserving the H/W ratio, then pad the blank.
+ padding (bool): if True, resize while preserving the H/W ratio, then pad the blank.
+ norm_before_pad (bool): if True, normalize the image array before padding.
+ mean: shape (3), mean value for normalization.
+ std: shape (3), std value for normalization.
+ interpolation: image interpolation mode.
"""
imgH, imgW = image_shape
@@ -380,22 +393,51 @@ def resize_norm_img(img, image_shape, padding=True, interpolation=cv2.INTER_LINE
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- padding_im = np.zeros((imgH, imgW, c), dtype=resized_image.dtype)
- padding_im[:, 0:resized_w, :] = resized_image
valid_ratio = min(1.0, float(resized_w / imgW))
- return padding_im, valid_ratio
+
+ if padding:
+ if norm_before_pad:
+ resized_image = (resized_image - mean) / std
+
+ padded_img = np.zeros((imgH, imgW, c), dtype=resized_image.dtype)
+ padded_img[:, 0:resized_w, :] = resized_image
+
+ if not norm_before_pad:
+ padded_img = (padded_img - mean) / std
+
+ return padded_img, valid_ratio
+ else:
+ resized_image = (resized_image - mean) / std
+ return resized_image, valid_ratio
# TODO: check diff from resize_norm_img
-def resize_norm_img_chinese(img, image_shape):
- """adopted from paddle"""
+def resize_norm_img_chinese(
+ img,
+ image_shape,
+ norm_before_pad=False,
+ mean=[127.0, 127.0, 127.0],
+ std=[127.0, 127.0, 127.0],
+ interpolation=cv2.INTER_LINEAR,
+):
+ """
+ Resize the image while preserving its aspect ratio, normalize it, and pad to the target width.
+ Args:
+ img: shape (H, W, C)
+ image_shape: image shape after resize, in (H, W)
+ norm_before_pad (bool): if True, normalize the image array before padding.
+ mean: shape (3), mean value for normalization.
+ std: shape (3), std value for normalization.
+ interpolation: image interpolation mode.
+
+ """
imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1]
c = img.shape[2]
ratio = w * 1.0 / h
-
+ max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(imgH * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
@@ -403,36 +445,93 @@ def resize_norm_img_chinese(img, image_shape):
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
- padding_im = np.zeros((imgH, imgW, c), dtype=resized_image.dtype)
- padding_im[:, 0:resized_w, :] = resized_image
valid_ratio = min(1.0, float(resized_w / imgW))
- return padding_im, valid_ratio
+ if norm_before_pad:
+ resized_image = (resized_image - mean) / std
-# TODO: remove infer_mode and character_dict_path if they are not necesary
-class RecResizeImg(object):
+ padded_img = np.zeros((imgH, imgW, c), dtype=resized_image.dtype)
+ padded_img[:, 0:resized_w, :] = resized_image
+
+ if not norm_before_pad:
+ padded_img = (padded_img - mean) / std
+
+ return padded_img, valid_ratio
+
+
+class RecResizeNormImg(object):
"""adopted from paddle
- resize, convert from hwc to chw, rescale pixel value to -1 to 1
+ Resize and normalize image, and pad image if needed.
+
+ Args:
+ image_shape: image shape after resize, in (H, W)
+ padding (bool): if True, resize while preserving the H/W ratio, then pad the blank.
+ norm_before_pad (bool): if True, normalize the image array before padding, so that the padded region stays zero (recommended). Default: False.
+ mean: shape (3), mean value for normalization.
+ std: shape (3), std value for normalization.
+ interpolation: image interpolation mode.
"""
- def __init__(self, image_shape, infer_mode=False, character_dict_path=None, padding=True, **kwargs):
+ def __init__(
+ self,
+ image_shape,
+ infer_mode=False,
+ character_dict_path=None,
+ padding=True,
+ norm_before_pad=False,
+ mean=[127.0, 127.0, 127.0],
+ std=[127.0, 127.0, 127.0],
+ **kwargs,
+ ):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
+ self.norm_before_pad = norm_before_pad
+ self.mean = np.array(mean, dtype="float32")
+ self.std = np.array(std, dtype="float32")
def __call__(self, data):
img = data["image"]
if self.infer_mode and self.character_dict_path is not None:
- norm_img, valid_ratio = resize_norm_img_chinese(img, self.image_shape)
+ norm_img, valid_ratio = resize_norm_img_chinese(
+ img, self.image_shape, self.norm_before_pad, self.mean, self.std
+ )
else:
- norm_img, valid_ratio = resize_norm_img(img, self.image_shape, self.padding)
+ norm_img, valid_ratio = resize_norm_img(
+ img,
+ self.image_shape,
+ self.padding,
+ self.norm_before_pad,
+ self.mean,
+ self.std,
+ )
data["image"] = norm_img
data["valid_ratio"] = valid_ratio
- # TODO: data['shape_list'] = ?
return data
+# TODO: remove infer_mode and character_dict_path if they are not necessary
+class RecResizeImg(RecResizeNormImg):
+ """
+ Kept for backward compatibility with code that still uses RecResizeImg; prefer RecResizeNormImg.
+ """
+
+ def __init__(self, image_shape, infer_mode=False, character_dict_path=None, padding=True, **kwargs):
+ super().__init__(
+ image_shape,
+ infer_mode,
+ character_dict_path,
+ padding,
+ norm_before_pad=False,
+ mean=[0.0, 0.0, 0.0],
+ std=[1.0, 1.0, 1.0],
+ )
+
+
class SVTRRecResizeImg(object):
def __init__(self, image_shape, padding=True, **kwargs):
self.image_shape = image_shape
@@ -511,9 +610,7 @@ def __call__(self, data):
# TODO: norm before padding
- data["shape_list"] = np.array(
- [h, w, resize_h / h, resize_w / w], dtype=np.float32
- ) # TODO: reformat, currently align to det
+ data["shape_list"] = [h, w, resize_h / h, resize_w / w] # TODO: reformat, currently align to det
if self.norm_before_pad:
resized_img = self.norm(resized_img)
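A quick smoke test for the new transform (a sketch assuming mindocr is installed with this patch applied; the input shape is arbitrary):

```python
import numpy as np
from mindocr.data.transforms.rec_transforms import RecResizeNormImg

op = RecResizeNormImg(image_shape=[32, 100], padding=True, norm_before_pad=True)
data = op({"image": np.random.randint(0, 256, size=(48, 160, 3), dtype=np.uint8)})
# HWC output before ToCHWImage; the width is clipped here, so valid_ratio is 1.0
print(data["image"].shape, data["valid_ratio"])  # (32, 100, 3) 1.0
```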