pytorch-psenet实现 并训练自己的数据集

深度学习ocr交流qq群:1020395892

作者的github链接:https://github.com/whai362/PSENet
Requirements
Python 2.7
PyTorch v0.4.1+
pyclipper
Polygon2
OpenCV 3.4 (for c++ version pse)
opencv-python 3.4
首先是需要安装这些依赖,因为我们自己电脑上面已经安装了各种版本的opencv,python,不建议在自己电脑上面再安装,可以创建虚拟环境anaconda
anaconda安装自行百度
anaconda 常用命令参考链接:https://blog.csdn.net/yang332233/article/details/90545561

conda create -n PSENet_1 python=2.7 OpenCV=3.4

这里创建虚拟环境可以同时安装必要的包,当然上面命令都是经过n次失败之后摸索出来的,一开始这么写的 conda create -n PSENet_1 python=2.7 OpenCV=3.4 opencv-python=3.4,报错,可能opencv-python没有吧

下一步重新打开一个终端,

source activate PSENet_1

可以看到(PSENet_1) yhl@B-Y:~$ 说明可以了
接着安装pytorch 我电脑配置是ubuntu16.04,显卡 1080,cuda8.0
到pytorch官网可以看到已经没有cuda-8了,但是历史版本有,直接conda install pytorch torchvision cudatoolkit=9.0 -c pytorch改成8就可以

conda install pytorch torchvision cudatoolkit=8.0 -c pytorch

然后完成,敲 pip -V 看是否是对应python2.7, pip -V
pip 19.1.1 from /media/data_1/software_install/Anaconda0515/envs/PSENet_1/lib/python2.7/site-packages/pip (python 2.7)

pip install pyclipper
 pip install Polygon2

需求表上面的opencv-python 3.4 我没法安装制定的3.4版本,之前好像敲的pip install opencv-python 安装的是4.1版本的,实践证明这个可以不要安装,安装了opencv3.4就已经可以了
转到PSENet根目录,并训练,前提是你需要准备好数据,数据需要放在根目录下面的data文件夹下,data-ICDAR2015-Challenge4-( ch4_training_localization_transcription_gt ch4_training_images ch4_test_images)
还有ctw-1500的数据集比较难下载,还好我问同事要到了,晚点上传发链接到这里吧
同样的,把自己的数据也按照它存放的文件路径就可以训练自己的数据集

 cd /media/data_1/project/2019/project/PSENet #
 CUDA_VISIBLE_DEVICES=0 python train_ic15.py --batch_size 2

报错 ImportError: No module named matplotlib

pip install matplotlib

再敲上面训练的命令就可以训练了,显示如下
CUDA_VISIBLE_DEVICES=0 python train_ic15.py --batch_size 2
checkpoint path: checkpoints/ic15_resnet50_bs_2_ep_600
init lr: 0.00100000
('schedule: ', [200, 400])
Training from scratch.

Epoch: [1 | 600] LR: 0.001000
/media/data_1/Yang/software_install/Anaconda0515/envs/PSENet_1/lib/python2.7/site-packages/torch/nn/functional.py:2351: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
warnings.warn(“nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.”)
/media/data_1/Yang/software_install/Anaconda0515/envs/PSENet_1/lib/python2.7/site-packages/torch/nn/functional.py:2423: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
“See the documentation of nn.Upsample for details.”.format(mode))
(1/6849) Batch: 1.927s | TOTAL: 0min | ETA: 220min | Loss: 0.8001 | Acc_t: 0.6196 | IOU_t: 0.2853 | IOU_k: 0.4976
(21/6849) Batch: 0.516s | TOTAL: 0min | ETA: 59min | Loss: 0.7719 | Acc_t: 0.6852 | IOU_t: 0.3115 | IOU_k: 0.4986
(41/6849) Batch: 0.481s | TOTAL: 0min | ETA: 55min | Loss: 0.7681 | Acc_t: 0.7403 | IOU_t: 0.3354 | IOU_k: 0.4985
再测试

CUDA_VISIBLE_DEVICES=0 python test_ic15.py --scale 1 --resume [path of model]

可以看到根目录生成了outputs文件夹,里面有效果图:下面是ICDAR2015测试集效果图:train_ic15.py
pytorch-psenet实现 并训练自己的数据集_第1张图片下面是我自己训练的数据效果图:这个是用train_ctw1500.py训练的
pytorch-psenet实现 并训练自己的数据集_第2张图片

###下面这个错误当正确安装了opencv的时候可以无视,创建虚拟环境的时候 conda create -n PSENet_1 python=2.7 OpenCV=3.4 这么写好像就没有问题
看起来一帆风顺,其实一开始弄的时候各种报错,一开始报错说pse文件夹下connectedComponents找不到,大概是这个错误,我百度了一下发现这个函数是opencv的,在pse文件夹下有MakeFile敲make 确实报错,看它里面OPENCV = pkg-config --cflags --libs opencv 这个语句我看它好像是调用系统下面的opencv了,各种尝试,最后确定makefile这么写可以通过:
CXXFLAGS = -I include -std=c++11 -O3

DEPS = lanms.h $(shell find include -xtype f)
CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp

#OPENCV = pkg-config --cflags --libs opencv
INCLUDES = -I/media/data_1/Yang/software_install/Anaconda0515/envs/PSENet/include
LIBS = -lopencv_core -lopencv_imgproc -lopencv_highgui -lopencv_ml
LIBDIRS = -L/media/data_1/Yang/software_install/Anaconda0515/envs/PSENet/lib

LIB_SO = adaptor.so

$(LIB_SO): $(CXX_SOURCES) $(DEPS)
$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV)

clean:
rm -rf $(LIB_SO)

psenet的后处理理解–戳我

由于本质上是分割,我已经把代码改成接受任意个数的标注点就可以训练。不需要一会儿icdar的4个标注点,CTD的14个标注点,这些太烦了。任意个标注点不香吗?

这里说明一下ctd标签32个数值的含义:
ctd坐标4 + 28, 前4个点是外接矩形的左上、右下的坐标,后面14个点是上排的7个坐标相对于左上坐标的偏移量,再后面的14个点是下排的7个坐标相对于左上坐标的偏移量

另外我自己的数据集40万标注好的,训练出来的模型对所有的文字定位效果都很鲁棒!

暂时想到的就是这些了,后面的有想法再补充,欢迎一起讨论。
小弟不才,同时谢谢友情赞助!
pytorch-psenet实现 并训练自己的数据集_第3张图片

你可能感兴趣的:(pytorch)