diff --git a/configs/rec/crnn/README.md b/configs/rec/crnn/README.md index 950ade180..52c3a0298 100644 --- a/configs/rec/crnn/README.md +++ b/configs/rec/crnn/README.md @@ -39,19 +39,21 @@ According to our experiments, the training (following the steps in [Model Traini
-| **Model** | **Context** | **Backbone** | **Train Dataset** | **Model Params** | **Batch size per card** | **Graph train 8P (s/epoch)** | **Graph train 8P (ms/step)** | **Graph train 8P (FPS)** | **Avg Eval Accuracy** | **Recipe** | **Download** | -| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | -| CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) | -| CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) | +| **Model** | **Context** | **Backbone** | **Train Dataset** | **Num Classes** | **Model Params** | **Batch size per card** | **Graph train (s/epoch)** | **Graph train (ms/step)** | **Graph train (FPS)** | **Avg Eval Accuracy** | **Recipe** | **Download** | +| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | +| CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 37 | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) | +| CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 37 | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) | +| CRNN | D910x4-MS2.0-G | ResNet34_vd | MJ+ST | 96 | 24.51 M | 64 | 4292.18 | 76.08 | 3364.72 | 83.50% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34_server.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_server-e0d66c0c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_server-e0d66c0c-55748731.mindir) |
- Detailed accuracy results for each benchmark dataset (IC03, IC13, IC15, IIIT, SVT, SVTP, CUTE):
- | **Model** | **Backbone** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **Average** | - | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | - | CRNN | VGG7 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% | - | CRNN | ResNet34_vd | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% | + | **Model** | **Backbone** | **Num Classes** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **Average** | + | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | + | CRNN | VGG7 | 37 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% | + | CRNN | ResNet34_vd | 37 | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% | + | CRNN | ResNet34_vd | 96 | 94.65% | 94.70% | 94.28% | 93.20% | 72.50% | 63.94% | 87.63% | 86.09% | 74.42% | 73.61% | 83.50% |
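The **Num Classes** column above follows directly from the character dictionary: a CTC head outputs one logit per dictionary character plus one blank label, which is what the `num_classes: &num_classes 96  # num_chars_in_dict+1` comment in the new server config states. A minimal sketch of that relationship, assuming a plain one-character-per-line dictionary file; `ctc_num_classes` is a hypothetical helper for illustration, not a MindOCR API:

```python
def ctc_num_classes(dict_path: str, use_space_char: bool = False) -> int:
    """Output size of a CTC head: number of dictionary entries plus one blank."""
    with open(dict_path, encoding="utf-8") as f:
        chars = [line.rstrip("\n") for line in f if line.rstrip("\n")]
    if use_space_char:
        chars.append(" ")  # some configs add the space char on top of the dict file
    return len(chars) + 1  # +1 for the CTC blank label

# The default 36-character dictionary (a-z, 0-9) gives 37 classes;
# per the server config comment, en_dict.txt yields 96.
```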
### Inference Perf. @@ -70,7 +72,7 @@ The inference performance is tested on Mindspore Lite, please take a look at [Mi **Notes:** - Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G-graph mode or F-pynative mode with ms function. For example, D910x8-MS1.8-G is for training on 8 pieces of Ascend 910 NPU using graph mode based on MindSpore version 1.8. - To reproduce the result on other contexts, please ensure the global batch size is the same. -- The characters supported by model are lowercase English characters from a to z and numbers from 0 to 9. More explanation on dictionary, please refer to [4. Character Dictionary](#4-character-dictionary). +- The number of classes of the model is determined by the dictionary used for training. The default dictionary contains the lowercase English letters a to z and the digits 0 to 9. For more details on the dictionary, please refer to [4. Character Dictionary](#4-character-dictionary). - The models are trained from scratch without any pre-training. For more dataset details of training and evaluation, please refer to [Dataset Download & Dataset Usage](#312-dataset-download) section. - The input Shapes of MindIR of CRNN_VGG7 and CRNN_ResNet34_vd are both (1, 3, 32, 100). diff --git a/configs/rec/crnn/README_CN.md b/configs/rec/crnn/README_CN.md index 77f508807..be16c29f5 100644 --- a/configs/rec/crnn/README_CN.md +++ b/configs/rec/crnn/README_CN.md @@ -39,20 +39,22 @@ Table Format:
-| **模型** | **环境配置** | **骨干网络** | **训练集** | **参数量** | **单卡批量** | **图模式8卡训练 (s/epoch)** | **图模式8卡训练 (ms/step)** | **图模式8卡训练 (FPS)** | **平均评估精度** | **配置文件** | **模型权重下载** | -| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | -| CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) | -| CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) | +| **模型** | **环境配置** | **骨干网络** | **训练集** | **类别数** | **参数量** | **单卡批量** | **图模式训练 (s/epoch)** | **图模式训练 (ms/step)** | **图模式训练 (FPS)** | **平均评估精度** | **配置文件** | **模型权重下载** | +| :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | +| CRNN | D910x8-MS1.8-G | VGG7 | MJ+ST | 37 | 8.72 M | 16 | 2488.82 | 22.06 | 5802.71 | 82.03% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_vgg7.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c-573dbd61.mindir) | +| CRNN | D910x8-MS1.8-G | ResNet34_vd | MJ+ST | 37 | 24.48 M | 64 | 2157.18 | 76.48 | 6694.84 | 84.45% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07-eb10a0c9.mindir) | +| CRNN | D910x4-MS2.0-G | ResNet34_vd | MJ+ST | 96 | 24.51 M | 64 | 4292.18 | 76.08 | 3364.72 | 83.50% | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/crnn/crnn_resnet34_server.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_server-e0d66c0c.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34_server-e0d66c0c-55748731.mindir) |
- 在各个基准数据集(IC03,IC13,IC15,IIIT,SVT,SVTP,CUTE)上的准确率:
- | **模型** | **骨干网络** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **平均准确率** | - | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | - | CRNN | VGG7 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% | - | CRNN | ResNet34_vd | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% | + | **模型** | **骨干网络** | **类别数** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **平均准确率** | + | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | + | CRNN | VGG7 | 37 | 94.53% | 94.00% | 92.18% | 90.74% | 71.95% | 66.06% | 84.10% | 83.93% | 73.33% | 69.44% | 82.03% | + | CRNN | ResNet34_vd | 37 | 94.42% | 94.23% | 93.35% | 92.02% | 75.92% | 70.15% | 87.73% | 86.40% | 76.28% | 73.96% | 84.45% | + | CRNN | ResNet34_vd | 96 | 94.65% | 94.70% | 94.28% | 93.20% | 72.50% | 63.94% | 87.63% | 86.09% | 74.42% | 73.61% | 83.50% |
@@ -72,7 +74,7 @@ Table Format: **注意:** - 环境配置:训练的环境配置表示为 {处理器}x{处理器数量}-{MS模式},其中 Mindspore 模式可以是 G-graph 模式或 F-pynative 模式。例如,D910x8-MS1.8-G 用于使用图形模式在8张昇腾910 NPU上依赖Mindspore1.8版本进行训练。 - 如需在其他环境配置重现训练结果,请确保全局批量大小与原配置文件保持一致。 -- 模型所能识别的字符都是默认的设置,即所有英文小写字母a至z及数字0至9,详细请看[4. 字符词典](#4-字符词典) +- 模型的类别数由训练所用的字典决定。默认字典包含小写英文字母a至z及数字0至9,详细请看[4. 字符词典](#4-字符词典) - 模型都是从头开始训练的,无需任何预训练。关于训练和测试数据集的详细介绍,请参考[数据集下载及使用](#312-数据集下载)章节。 - CRNN_VGG7和CRNN_ResNet34_vd的MindIR导出时的输入Shape均为(1, 3, 32, 100)。 diff --git a/configs/rec/crnn/crnn_icdar15.yaml b/configs/rec/crnn/crnn_icdar15.yaml index 18139f435..358a1f31b 100644 --- a/configs/rec/crnn/crnn_icdar15.yaml +++ b/configs/rec/crnn/crnn_icdar15.yaml @@ -96,16 +96,12 @@ train: character_dict_path: *character_dict_path use_space_char: *use_space_char lower: True - - RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize. + - RecResizeNormImg: image_shape: [32, 100] # H, W infer_mode: *infer_mode character_dict_path: *character_dict_path padding: False # aspect ratio will be preserved if true. - - NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec. - bgr_to_rgb: True - is_hwc: True - mean : [127.0, 127.0, 127.0] - std : [127.0, 127.0, 127.0] + norm_before_pad: False - ToCHWImage: # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visaulize output_columns: ['image', 'text_seq'] #, 'length'] #'img_path'] diff --git a/configs/rec/crnn/crnn_resnet34.yaml b/configs/rec/crnn/crnn_resnet34.yaml index 1325467c1..bc37c7ea5 100644 --- a/configs/rec/crnn/crnn_resnet34.yaml +++ b/configs/rec/crnn/crnn_resnet34.yaml @@ -80,23 +80,19 @@ train: shuffle: True transform_pipeline: - DecodeImage: - img_mode: BGR + img_mode: RGB to_float32: False - RecCTCLabelEncode: max_text_len: *max_text_len character_dict_path: *character_dict_path use_space_char: *use_space_char lower: True - - RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize. + - RecResizeNormImg: image_shape: [32, 100] # H, W infer_mode: *infer_mode character_dict_path: *character_dict_path padding: False # aspect ratio will be preserved if true. - - NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec.
- bgr_to_rgb: True - is_hwc: True - mean : [127.0, 127.0, 127.0] - std : [127.0, 127.0, 127.0] + norm_before_pad: False - ToCHWImage: # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visaulize output_columns: ['image', 'text_seq'] #, 'length'] #'img_path'] diff --git a/configs/rec/crnn/crnn_resnet34_ch.yaml b/configs/rec/crnn/crnn_resnet34_ch.yaml index bf954cbae..6465a7cd4 100644 --- a/configs/rec/crnn/crnn_resnet34_ch.yaml +++ b/configs/rec/crnn/crnn_resnet34_ch.yaml @@ -84,7 +84,7 @@ train: max_text_len: *max_text_len transform_pipeline: - DecodeImage: - img_mode: BGR + img_mode: RGB to_float32: False - RecCTCLabelEncode: max_text_len: *max_text_len @@ -94,16 +94,12 @@ train: - Rotate90IfVertical: threshold: 2.0 direction: counterclockwise - - RecResizeImg: - image_shape: [32, 320] + - RecResizeNormImg: + image_shape: [32, 320] # H, W infer_mode: *infer_mode character_dict_path: *character_dict_path - padding: True - - NormalizeImage: - bgr_to_rgb: True - is_hwc: True - mean: [127.0, 127.0, 127.0] - std: [127.0, 127.0, 127.0] + padding: True # aspect ratio will be preserved if true. + norm_before_pad: False - ToCHWImage: output_columns: ["image", "text_seq"] net_input_column_index: [0] diff --git a/configs/rec/crnn/crnn_resnet34_server.yaml b/configs/rec/crnn/crnn_resnet34_server.yaml new file mode 100644 index 000000000..7518981ea --- /dev/null +++ b/configs/rec/crnn/crnn_resnet34_server.yaml @@ -0,0 +1,150 @@ +system: + mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore + distribute: True + amp_level: 'O3' + seed: 42 + log_interval: 100 + val_while_train: True + drop_overflow_update: False + +common: + character_dict_path: &character_dict_path mindocr/utils/dict/en_dict.txt + num_classes: &num_classes 96 # num_chars_in_dict+1, TODO: retrieve it from dict or check correctness + max_text_len: &max_text_len 24 + infer_mode: &infer_mode False + use_space_char: &use_space_char True + lower: &lower False + batch_size: &batch_size 64 + +model: + type: rec + transform: null + backbone: + name: rec_resnet34 + pretrained: False + neck: + name: RNNEncoder + hidden_size: 256 + head: + name: CTCHead + weight_init: crnn_customised + bias_init: crnn_customised + out_channels: *num_classes + +postprocess: + name: RecCTCLabelDecode + character_dict_path: *character_dict_path + use_space_char: *use_space_char + +metric: + name: RecMetric + main_indicator: acc + character_dict_path: *character_dict_path + ignore_space: True + print_flag: False + +loss: + name: CTCLoss + pred_seq_len: 25 # TODO: retrieve from the network output shape.
+ max_label_len: *max_text_len # this value should be smaller than pred_seq_len + batch_size: *batch_size + +scheduler: + scheduler: warmup_cosine_decay + min_lr: 0.000001 + lr: 0.001 + num_epochs: 30 + warmup_epochs: 2 + decay_epochs: 28 + +optimizer: + opt: adamw + filter_bias_and_bn: True + momentum: 0.95 + weight_decay: 0.0001 + nesterov: False + +loss_scaler: + type: dynamic + loss_scale: 512 + scale_factor: 2.0 + scale_window: 1000 + +train: + ckpt_save_dir: './crnn_resnet34_server' + pred_cast_fp32: False # let CTCLoss cast internally + ema: True + dataset_sink_mode: False + dataset: + type: LMDBDataset + dataset_root: /path/to/data_lmdb_release/ + data_dir: training/ + # label_file: # not required when using LMDBDataset + sample_ratio: 1.0 + shuffle: True + transform_pipeline: + - DecodeImage: + img_mode: RGB + to_float32: False + - RecCTCLabelEncode: + max_text_len: *max_text_len + character_dict_path: *character_dict_path + use_space_char: *use_space_char + lower: *lower + - RecResizeNormImg: + image_shape: [32, 100] # H, W + infer_mode: *infer_mode + character_dict_path: *character_dict_path + padding: True # aspect ratio will be preserved if true. + norm_before_pad: True + - ToCHWImage: + # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize + output_columns: ['image', 'text_seq'] #, 'length'] #'img_path'] + net_input_column_index: [0] # input indices for network forward func in output_columns + label_column_index: [1] # input indices marked as label + #keys_for_loss: 4 # num labels for loss func + + loader: + shuffle: True + batch_size: *batch_size + drop_remainder: True + max_rowsize: 12 + num_workers: 8 + +eval: + ckpt_load_path: ./crnn_resnet34_server/best.ckpt + dataset_sink_mode: False + dataset: + type: LMDBDataset + dataset_root: /path/to/data_lmdb_release/ + data_dir: validation/ + # label_file: # not required when using LMDBDataset + sample_ratio: 1.0 + shuffle: False + transform_pipeline: + - DecodeImage: + img_mode: RGB + to_float32: False + - RecCTCLabelEncode: + max_text_len: *max_text_len + character_dict_path: *character_dict_path + use_space_char: *use_space_char + lower: *lower + - RecResizeNormImg: + image_shape: [32, 100] # H, W + infer_mode: *infer_mode + character_dict_path: *character_dict_path + padding: True # aspect ratio will be preserved if true.
+ norm_before_pad: True + - ToCHWImage: + # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize + output_columns: ['image', 'text_padded', 'text_length'] # TODO return text string padding w/ fixed length, and a scalar to indicate the length + net_input_column_index: [0] # input indices for network forward func in output_columns + label_column_index: [1, 2] # input indices marked as label + + loader: + shuffle: False # TODO: tbc + batch_size: 64 + drop_remainder: False + max_rowsize: 12 + num_workers: 8 diff --git a/configs/rec/crnn/crnn_vgg7.yaml b/configs/rec/crnn/crnn_vgg7.yaml index 5647e3421..a5a750463 100644 --- a/configs/rec/crnn/crnn_vgg7.yaml +++ b/configs/rec/crnn/crnn_vgg7.yaml @@ -81,23 +81,19 @@ train: shuffle: True transform_pipeline: - DecodeImage: - img_mode: BGR + img_mode: RGB to_float32: False - RecCTCLabelEncode: max_text_len: *max_text_len character_dict_path: *character_dict_path use_space_char: *use_space_char lower: True - - RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize. + - RecResizeNormImg: image_shape: [32, 100] # H, W infer_mode: *infer_mode character_dict_path: *character_dict_path padding: False # aspect ratio will be preserved if true. - - NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec. - bgr_to_rgb: True - is_hwc: True - mean : [127.0, 127.0, 127.0] - std : [127.0, 127.0, 127.0] + norm_before_pad: False - ToCHWImage: # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visaulize output_columns: ['image', 'text_seq'] #, 'length'] #'img_path'] diff --git a/configs/rec/rare/rare_resnet34.yaml b/configs/rec/rare/rare_resnet34.yaml index d910b7c21..85609b5ea 100644 --- a/configs/rec/rare/rare_resnet34.yaml +++ b/configs/rec/rare/rare_resnet34.yaml @@ -83,16 +83,12 @@ train: character_dict_path: *character_dict_path use_space_char: *use_space_char lower: True - - RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize. + - RecResizeNormImg: image_shape: [32, 100] # H, W infer_mode: *infer_mode character_dict_path: *character_dict_path padding: False # aspect ratio will be preserved if true. - - NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec. - bgr_to_rgb: True - is_hwc: True - mean: [127.0, 127.0, 127.0] - std: [127.0, 127.0, 127.0] + norm_before_pad: False - ToCHWImage: output_columns: ["image", "text_seq"] net_input_column_index: [0, 1] # input indices for network forward func in output_columns diff --git a/configs/rec/rare/rare_resnet34_ch.yaml b/configs/rec/rare/rare_resnet34_ch.yaml index 624c70b3d..5bd8fd705 100644 --- a/configs/rec/rare/rare_resnet34_ch.yaml +++ b/configs/rec/rare/rare_resnet34_ch.yaml @@ -93,16 +93,12 @@ train: - Rotate90IfVertical: threshold: 2.0 direction: counterclockwise - - RecResizeImg: - image_shape: [32, 320] + - RecResizeNormImg: + image_shape: [32, 320] # H, W infer_mode: *infer_mode character_dict_path: *character_dict_path - padding: True - - NormalizeImage: - bgr_to_rgb: True - is_hwc: True - mean: [127.0, 127.0, 127.0] - std: [127.0, 127.0, 127.0] + padding: True # aspect ratio will be preserved if true.
+ norm_before_pad: False - ToCHWImage: output_columns: ["image", "text_seq"] net_input_column_index: [0, 1] diff --git a/mindocr/data/transforms/rec_transforms.py b/mindocr/data/transforms/rec_transforms.py index 875234bb6..e97f3e824 100644 --- a/mindocr/data/transforms/rec_transforms.py +++ b/mindocr/data/transforms/rec_transforms.py @@ -13,6 +13,7 @@ "RecAttnLabelEncode", "RecMasterLabelEncode", "RecResizeImg", + "RecResizeNormImg", "RecResizeNormForInfer", "SVTRRecResizeImg", "Rotate90IfVertical", @@ -356,13 +357,25 @@ def str2idx( # TODO: reorganize the code for different resize transformation in rec task -def resize_norm_img(img, image_shape, padding=True, interpolation=cv2.INTER_LINEAR): +def resize_norm_img( + img, + image_shape, + padding=True, + norm_before_pad=False, + mean=[127.0, 127.0, 127.0], + std=[127.0, 127.0, 127.0], + interpolation=cv2.INTER_LINEAR, +): """ resize image Args: img: shape (H, W, C) image_shape: image shape after resize, in (C, H, W) - padding: if Ture, resize while preserving the H/W ratio, then pad the blank. + padding (bool): if True, resize while preserving the H/W ratio, then pad the blank region. + norm_before_pad (bool): if True, normalize the image array before padding. + mean: shape (3), mean value for normalization. + std: shape (3), std value for normalization. + interpolation: image interpolation mode. """ imgH, imgW = image_shape @@ -380,22 +393,51 @@ def resize_norm_img(img, image_shape, padding=True, interpolation=cv2.INTER_LINE resized_w = int(math.ceil(imgH * ratio)) resized_image = cv2.resize(img, (resized_w, imgH)) - padding_im = np.zeros((imgH, imgW, c), dtype=resized_image.dtype) - padding_im[:, 0:resized_w, :] = resized_image valid_ratio = min(1.0, float(resized_w / imgW)) - return padding_im, valid_ratio + + if padding: + if norm_before_pad: + resized_image = (resized_image - mean) / std + + padded_img = np.zeros((imgH, imgW, c), dtype=resized_image.dtype) + padded_img[:, 0:resized_w, :] = resized_image + + if not norm_before_pad: + padded_img = (padded_img - mean) / std + + return padded_img, valid_ratio + else: + resized_image = (resized_image - mean) / std + return resized_image, valid_ratio # TODO: check diff from resize_norm_img -def resize_norm_img_chinese(img, image_shape): - """adopted from paddle""" +def resize_norm_img_chinese( + img, + image_shape, + norm_before_pad=False, + mean=[127.0, 127.0, 127.0], + std=[127.0, 127.0, 127.0], + interpolation=cv2.INTER_LINEAR, +): + """ + resize image with aspect-ratio keeping and padding + Args: + img: shape (H, W, C) + image_shape: image shape after resize, in (H, W) + norm_before_pad (bool): if True, normalize the image array before padding. + mean: shape (3), mean value for normalization. + std: shape (3), std value for normalization. + interpolation: image interpolation mode.
+ + """ imgH, imgW = image_shape # todo: change to 0 and modified image shape max_wh_ratio = imgW * 1.0 / imgH h, w = img.shape[0], img.shape[1] c = img.shape[2] ratio = w * 1.0 / h - + max_wh_ratio = max(max_wh_ratio, ratio) imgW = int(imgH * max_wh_ratio) if math.ceil(imgH * ratio) > imgW: resized_w = imgW @@ -403,36 +445,93 @@ def resize_norm_img_chinese(img, image_shape): resized_w = int(math.ceil(imgH * ratio)) resized_image = cv2.resize(img, (resized_w, imgH)) - padding_im = np.zeros((imgH, imgW, c), dtype=resized_image.dtype) - padding_im[:, 0:resized_w, :] = resized_image valid_ratio = min(1.0, float(resized_w / imgW)) - return padding_im, valid_ratio + if norm_before_pad: + resized_image = (resized_image - mean) / std -# TODO: remove infer_mode and character_dict_path if they are not necesary -class RecResizeImg(object): + padded_img = np.zeros((imgH, imgW, c), dtype=resized_image.dtype) + padded_img[:, 0:resized_w, :] = resized_image + + if not norm_before_pad: + padded_img = (padded_img - mean) / std + + return padded_img, valid_ratio + + +class RecResizeNormImg(object): """adopted from paddle - resize, convert from hwc to chw, rescale pixel value to -1 to 1 + Resize and normalize image, and pad image if needed. + + Args: + image_shape: image shape after resize, in (C, H, W) + padding (bool): if Ture, resize while preserving the H/W ratio, then pad the blank. + norm_before_pad (bool): if True, normalize the image array before padding. + mean: shape (3), mean value for normalization. + std: shape (3), std value for normalization. + interpolation: image interpolation mode. + norm_before_pad: If True, perform normalization before padding \ + (by doing so, the padding values will beall zero. Good practice.). \ + Otherwise, per Default: False """ - def __init__(self, image_shape, infer_mode=False, character_dict_path=None, padding=True, **kwargs): + def __init__( + self, + image_shape, + infer_mode=False, + character_dict_path=None, + padding=True, + norm_before_pad=False, + mean=[127.0, 127.0, 127.0], + std=[127.0, 127.0, 127.0], + **kwargs, + ): self.image_shape = image_shape self.infer_mode = infer_mode self.character_dict_path = character_dict_path self.padding = padding + self.norm_before_pad = norm_before_pad + self.mean = np.array(mean, dtype="float32") + self.std = np.array(std, dtype="float32") def __call__(self, data): img = data["image"] if self.infer_mode and self.character_dict_path is not None: - norm_img, valid_ratio = resize_norm_img_chinese(img, self.image_shape) + norm_img, valid_ratio = resize_norm_img_chinese( + img, self.image_shape, self.norm_before_pad, self.mean, self.std + ) else: - norm_img, valid_ratio = resize_norm_img(img, self.image_shape, self.padding) + norm_img, valid_ratio = resize_norm_img( + img, + self.image_shape, + self.padding, + self.norm_before_pad, + self.mean, + self.std, + ) data["image"] = norm_img data["valid_ratio"] = valid_ratio - # TODO: data['shape_list'] = ? return data +# TODO: remove infer_mode and character_dict_path if they are not necesary +class RecResizeImg(RecResizeNormImg): + """ + This is to make compatible with older version code that uses RecResizeImg, which is to be updated. 
+ """ + + def __init__(self, image_shape, infer_mode=False, character_dict_path=None, padding=True, **kwargs): + super().__init__( + image_shape, + infer_mode, + character_dict_path, + padding, + norm_befoer_pad=False, + mean=[0.0, 0.0, 0.0], + std=[1.0, 1.0, 1.0], + ) + + class SVTRRecResizeImg(object): def __init__(self, image_shape, padding=True, **kwargs): self.image_shape = image_shape @@ -511,9 +610,7 @@ def __call__(self, data): # TODO: norm before padding - data["shape_list"] = np.array( - [h, w, resize_h / h, resize_w / w], dtype=np.float32 - ) # TODO: reformat, currently align to det + data["shape_list"] = [h, w, resize_h / h, resize_w / w] # TODO: reformat, currently align to det if self.norm_before_pad: resized_img = self.norm(resized_img)