nn.ConvTranspose2d #101


Open · wants to merge 3 commits into base: master
124 changes: 21 additions & 103 deletions README.MD
@@ -1,128 +1,46 @@
# Real-time Scene Text Detection with Differentiable Binarization

**note**: some code is inherited from [MhLiao/DB](https://github.com/MhLiao/DB)
**note**: original version: [DBNet.pytorch](https://github.com/WenmuZhou/DBNet.pytorch)

[Write-up in Chinese](https://zhuanlan.zhihu.com/p/94677957)

![network](imgs/paper/db.jpg)

## update
2020-06-07: added grayscale-image training; when training on grayscale images, remove `dataset.args.transforms.Normalize` from the config

## Environment Setup

## Install Using Conda
```
conda env create -f environment.yml
git clone https://github.com/WenmuZhou/DBNet.pytorch.git
cd DBNet.pytorch/
```
Please refer to the original version's [Readme](https://github.com/WenmuZhou/DBNet.pytorch/blob/master/README.MD)

or

## Install Manually
```bash
conda create -n dbnet python=3.6
conda activate dbnet

conda install ipython pip

# python dependencies
pip install -r requirement.txt

# install PyTorch with cuda-10.1
# Note that you can change the cudatoolkit version to the version you want.
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# clone repo
git clone https://github.com/WenmuZhou/DBNet.pytorch.git
cd DBNet.pytorch/
```

## Modifications

To make TensorRT acceleration possible, this repo replaces every deconvolution (transposed convolution) with an upsample operation. For example:

```python
# original version
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),  # upsample by 2x
# modified version
nn.Upsample(scale_factor=2, mode='nearest'),
```
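A fixed nearest-neighbor `nn.Upsample` cannot remap channels the way a strided `nn.ConvTranspose2d` can, so where the original layer also changed the channel count, the modified code pairs the upsample with a 3x3 convolution (see `models/head/DBHead.py` below). As a minimal sketch of the shape equivalence, assuming a 256-channel head as in this repo:

```python
import torch
from torch import nn

in_channels = 256
x = torch.randn(1, in_channels // 4, 40, 40)

# original: learned 2x upsampling via transposed convolution
deconv = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2)

# modified: fixed nearest-neighbor upsampling (TensorRT-friendly)
upsample = nn.Upsample(scale_factor=2, mode='nearest')

print(deconv(x).shape)    # torch.Size([1, 64, 80, 80])
print(upsample(x).shape)  # torch.Size([1, 64, 80, 80])
```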

## Requirements
* pytorch 1.4+
* torchvision 0.5+
* gcc 4.9+

## Download
TBD

For the full list of changes, see the code:
```
models/head/DBHead.py
models/model.py
models/neck/FPN.py
```

## Data Preparation

Training data: prepare a text file `train.txt` in the following format, using '\t' as the separator:
```
./datasets/train/img/001.jpg ./datasets/train/gt/001.txt
```

Validation data: prepare a text file `test.txt` in the following format, using '\t' as the separator:
```
./datasets/test/img/001.jpg ./datasets/test/gt/001.txt
```
- Store images in the `img` folder
- Store groundtruth in the `gt` folder

The ground truth files are `.txt` files with the following format:
```
x1, y1, x2, y2, x3, y3, x4, y4, annotation
```
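For illustration only (this helper is not part of the repo), a line of such a ground-truth file could be parsed like this:

```python
def parse_gt_line(line: str):
    """Parse one ground-truth line: 8 comma-separated coordinates plus the
    annotation text. Hypothetical helper, shown only to clarify the format."""
    parts = line.strip().split(',')
    coords = list(map(float, parts[:8]))             # x1, y1, ..., x4, y4
    points = list(zip(coords[0::2], coords[1::2]))   # [(x1, y1), ..., (x4, y4)]
    annotation = ','.join(parts[8:]).strip()         # text may itself contain commas
    return points, annotation

points, text = parse_gt_line("377,117,463,117,465,130,378,130,Genaxis Theatre")
print(points)  # [(377.0, 117.0), (463.0, 117.0), (465.0, 130.0), (378.0, 130.0)]
print(text)    # Genaxis Theatre
```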


## Train
1. Set `dataset['train']['dataset']['data_path']` and `dataset['validate']['dataset']['data_path']` in [config/icdar2015_resnet18_fpn_DBhead_polyLR.yaml](config/icdar2015_resnet18_fpn_DBhead_polyLR.yaml) (see the config sketch after this list)
2. single-GPU training:
```bash
bash singlel_gpu_train.sh
```
3. multi-GPU training:
```bash
bash multi_gpu_train.sh
```
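A rough sketch of the relevant config entries; the exact nesting and keys in the repo's YAML may differ, so treat this as illustrative only:

```yaml
# illustrative only -- check the actual config file for the exact layout
dataset:
  train:
    dataset:
      args:
        data_path:
          - ./datasets/train.txt
  validate:
    dataset:
      args:
        data_path:
          - ./datasets/test.txt
```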
## Test

[eval.py](tools/eval.py) is used to evaluate the model on the test dataset

1. config `model_path` in [eval.sh](eval.sh)
2. run the following script:
```bash
bash eval.sh
```

## Predict
[predict.py](tools/predict.py) can be used to run inference on all images in a folder
1. config `model_path`, `input_folder`, `output_folder` in [predict.sh](predict.sh)
2. run the following script:
```bash
bash predict.sh
```
You can change `model_path` in the `predict.sh` file to point to your model.

Tip: if the results are not good, try changing `thre` in [predict.sh](predict.sh)

The project is still under development.

## Performance

### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4)
trained only on the ICDAR2015 dataset

| Method | image size (short size) |learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS |
|:--------------------------:|:-------:|:--------:|:--------:|:------------:|:---------------:|:-----:|
| SynthText-Deform-ResNet-18(paper) | 736 |0.007 | 86.8 | 78.4 | 82.3 | 48 |
| ImageNet-resnet18-FPN-DBHead |736 |1e-3| 87.03 | 75.06 | 80.6 | 43 |
| ImageNet-Deform-Resnet18-FPN-DBHead |736 |1e-3| 88.61 | 73.84 | 80.56 | 36 |
| ImageNet-resnet50-FPN-DBHead |736 |1e-3| 88.06 | 77.14 | 82.24 | 27 |
| ImageNet-resnest50-FPN-DBHead |736 |1e-3| 88.18 | 76.27 | 81.78 | 27 |
## Model

Model trained with the modified code: [Baidu Cloud, access code: myj4](https://pan.baidu.com/s/10Ff-0AJkkpC9jGWdNSsN6g)

### examples
TBD

Training is not finished: compared with the original model (1200 epochs), this one has only been trained for 500 epochs. Precision: 90.0, recall: 68.2.

You can train it further yourself.

### todo
- [x] multi-GPU training

## TensorRT Version
https://github.com/BaofengZan/DBNet-TensorRT

### reference
1. https://arxiv.org/pdf/1911.08947.pdf
2. https://github.com/WenmuZhou/PANet.pytorch
3. https://github.com/MhLiao/DB

**If this repository helps you, please star it. Thanks.**
20 changes: 15 additions & 5 deletions models/head/DBHead.py
@@ -3,19 +3,26 @@
# @Author : zhoujun
import torch
from torch import nn
import torch.nn.functional as F

class DBHead(nn.Module):
def __init__(self, in_channels, out_channels, k = 50):
def __init__(self, in_channels, out_channels, k = 50):  # debug values: in_channels=256, out_channels=2, k=50
super().__init__()
self.k = k
self.binarize = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
nn.BatchNorm2d(in_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
# ConvTranspose2d (self, in_channels, out_channels, kernel_size, stride=1,
# padding=0, output_padding=0, groups=1, bias=True,
# dilation=1, padding_mode='zeros'):
#nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2), # upsample by 2x
nn.Upsample(scale_factor=2, mode='nearest'),
nn.BatchNorm2d(in_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
#nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(in_channels//4, 1, 3, padding=1), # kernel 3, stride 1, padding 1: spatial size unchanged
nn.Sigmoid())
self.binarize.apply(self.weights_init)

@@ -41,9 +48,10 @@ def weights_init(self, m):
m.bias.data.fill_(1e-4)

def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
in_channels = inner_channels
in_channels = inner_channels # 256
if serial:
in_channels += 1

self.thresh = nn.Sequential(
nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.BatchNorm2d(inner_channels // 4),
@@ -67,7 +75,9 @@ def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
return nn.Sequential(module_list)
else:
return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
#return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
return nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(in_channels, out_channels, 3, 1, 1))

def step_function(self, x, y):
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
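For context, `step_function` implements the approximate (differentiable) binarization from the DB paper, with amplification factor k = 50; the hard threshold is replaced by a steep sigmoid so gradients can flow:

```latex
\hat{B} = \frac{1}{1 + e^{-k\,(P - T)}}
```

Here `x` plays the role of the probability map P and `y` the threshold map T.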
4 changes: 3 additions & 1 deletion models/model.py
@@ -31,7 +31,9 @@ def forward(self, x):
backbone_out = self.backbone(x)
neck_out = self.neck(backbone_out)
y = self.head(neck_out)
y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True)
# y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True)
# Note: F.interpolate is a function, so unlike nn.Upsample it cannot be used
# as a layer inside nn.Sequential(); nn.Upsample can.
y = F.interpolate(y, size=(H, W))  # trained with nearest-neighbor mode, which can be implemented with the TensorRT API
return y


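A minimal sketch of the distinction noted in the comment above (illustrative, not part of the repo): the module form can live inside `nn.Sequential`, while the functional form is called in `forward`:

```python
import torch
from torch import nn
import torch.nn.functional as F

x = torch.randn(1, 3, 16, 16)

# module form: usable as a layer inside nn.Sequential
seq = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'))

# functional form: called directly, e.g. inside forward()
y_fn = F.interpolate(x, scale_factor=2, mode='nearest')

assert torch.equal(seq(x), y_fn)  # same result either way
```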
36 changes: 26 additions & 10 deletions models/neck/FPN.py
@@ -11,23 +11,25 @@
class FPN(nn.Module):
def __init__(self, in_channels, inner_channels=256, **kwargs):
"""
:param in_channels: output channel counts of the backbone stages
:param in_channels: output channel counts of the backbone stages, e.g. [64, 128, 256, 512]
:param kwargs:
"""
super().__init__()
inplace = True
self.conv_out = inner_channels
inner_channels = inner_channels // 4
inner_channels = inner_channels // 4 # 256 // 4 = 64
# reduce layers
self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
# Smooth layers
self.smooth_p4 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
self.smooth_p4 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace) # kernel 3, stride 1, padding 1
self.smooth_p3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
self.smooth_p2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)

#self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

self.conv = nn.Sequential(
nn.Conv2d(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1),
nn.BatchNorm2d(self.conv_out),
@@ -39,23 +41,37 @@ def forward(self, x):
c2, c3, c4, c5 = x
# Top-down
p5 = self.reduce_conv_c5(c5)
p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
#p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
c4_1 = self.reduce_conv_c4(c4)
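# note: F.upsample is a deprecated alias of F.interpolate (mode defaults to 'nearest')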
p4_1 = F.upsample(p5, size=c4_1.size()[2:])
p4 = p4_1 + c4_1

p4 = self.smooth_p4(p4)
p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))

#p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
c3_1 = self.reduce_conv_c3(c3)
p3_1 = F.upsample(p4, size=c3_1.size()[2:])
p3 = p3_1 + c3_1
p3 = self.smooth_p3(p3)
p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
#p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
c2_1 = self.reduce_conv_c2(c2)
p2_1 = F.upsample(p3, size=c2_1.size()[2:])
p2 = p2_1 + c2_1
p2 = self.smooth_p2(p2)

x = self._upsample_cat(p2, p3, p4, p5)
x = self.conv(x)
return x

def _upsample_add(self, x, y):
return F.interpolate(x, size=y.size()[2:]) + y
return F.upsample(x, size=y.size()[2:]) + y

def _upsample_cat(self, p2, p3, p4, p5):
h, w = p2.size()[2:]
p3 = F.interpolate(p3, size=(h, w))
p4 = F.interpolate(p4, size=(h, w))
p5 = F.interpolate(p5, size=(h, w))
#p3 = F.interpolate(p3, size=(h, w))
p3 = F.upsample(p3, size=(h, w))
#p4 = F.interpolate(p4, size=(h, w))
p4 = F.upsample(p4, size=(h, w))
#p5 = F.interpolate(p5, size=(h, w))
p5 = F.upsample(p5, size=(h, w))
return torch.cat([p2, p3, p4, p5], dim=1)
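A small illustration (not from the repo) of the upsample-and-add pattern used in the top-down path above, where the coarser map is resized to the finer map's spatial size before the element-wise sum:

```python
import torch
import torch.nn.functional as F

# feature maps from two pyramid levels (batch, channels, H, W)
p5 = torch.randn(1, 64, 20, 20)   # coarser level
c4 = torch.randn(1, 64, 40, 40)   # finer level (after its 1x1 reduce conv)

# resize p5 to c4's spatial size (nearest neighbor by default), then add
p4 = F.interpolate(p5, size=c4.size()[2:]) + c4
print(p4.shape)  # torch.Size([1, 64, 40, 40])
```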