R2CNN for Natural-Scene OCR Text Detection: A Detailed Implementation Walkthrough

I. Download the source code and set up the environment
Requirements
tensorflow >= 1.2
cuda >=8.0
python3
opencv(cv2)
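A quick way to confirm these requirements from Python is sketched below. It is a minimal sketch, assuming Python 3.8+ (for importlib.metadata) and assuming TensorFlow is installed under the distribution name tensorflow or tensorflow-gpu; the helpers version_tuple and meets are hypothetical names. It compares versions numerically, because a plain string comparison would wrongly rank "1.12" below "1.2". CUDA is not checked here.

```python
import importlib.util
import sys
from importlib import metadata  # standard library since Python 3.8


def version_tuple(v):
    """Turn '1.12.0' into (1, 12, 0); only the leading digits of each part count,
    and parsing stops at the first part that has none."""
    parts = []
    for piece in v.split('.'):
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets(installed, minimum):
    """Numeric comparison, so '1.12' correctly counts as newer than '1.2'."""
    return version_tuple(installed) >= version_tuple(minimum)


if __name__ == '__main__':
    print('python3:', sys.version_info >= (3,))
    # assumed distribution names; your environment may use a different one
    for dist in ('tensorflow', 'tensorflow-gpu'):
        try:
            v = metadata.version(dist)
            print(dist, v, '>= 1.2:', meets(v, '1.2'))
        except metadata.PackageNotFoundError:
            print(dist, 'not installed')
    print('cv2 importable:', importlib.util.find_spec('cv2') is not None)
```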
1. Download the source code
R2CNN_Faster-RCNN_Tensorflow source code: git clone https://github.com/DetectionTeamUCAS/R2CNN_Faster-RCNN_Tensorflow
2. Compile
1)cd $PATH_ROOT/libs/box_utils/
python setup.py build_ext --inplace
2)cd $PATH_ROOT/libs/box_utils/cython_utils
python setup.py build_ext --inplace
3. Download the pre-trained models
Download the resnet50_v1 or resnet101_v1 model pre-trained on ImageNet and put it in data/pretrained_weights (I used resnet101_v1).
Download links: http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz
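The download and unpacking can be scripted with the standard library. This is a minimal sketch: the URLs are the two listed above, the target directory follows the data/pretrained_weights convention stated above, and fetch / extract_checkpoint are hypothetical helper names. Each archive is expected to unpack to a .ckpt file (e.g. resnet_v1_101.ckpt), but check the extracted names against what cfgs.py expects.

```python
import os
import tarfile
import urllib.request

# URLs copied from the download links above
URLS = {
    'resnet_v1_50': 'http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz',
    'resnet_v1_101': 'http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz',
}


def extract_checkpoint(tar_path, dest_dir):
    """Unpack a .tar.gz archive into dest_dir and return the archive's member names."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(tar_path, 'r:gz') as tar:
        names = tar.getnames()
        tar.extractall(dest_dir)
    return names


def fetch(name, root='.'):
    """Download one pre-trained model (if not already present) and unpack it
    into <root>/data/pretrained_weights."""
    dest_dir = os.path.join(root, 'data', 'pretrained_weights')
    os.makedirs(dest_dir, exist_ok=True)
    tar_path = os.path.join(dest_dir, os.path.basename(URLS[name]))
    if not os.path.exists(tar_path):
        urllib.request.urlretrieve(URLS[name], tar_path)
    return extract_checkpoint(tar_path, dest_dir)
```

For example, fetch('resnet_v1_101', root='<path to R2CNN_Faster-RCNN_Tensorflow>') downloads the 101-layer model and returns the names of the extracted files.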
II. Prepare your own data
I trained on the ICDAR2017 dataset.
data format
├── VOCdevkit
│ ├── VOCdevkit_train
│ │ ├── Annotation (xml files)
│ │ ├── JPEGImages (jpg files)
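Before converting, it helps to check that every xml in Annotation has a jpg with the same file name in JPEGImages and vice versa. The assumption here is that the conversion step pairs images and annotations by file name (confirm this in convert_data_to_tfrecord.py). A minimal sketch, with check_voc_layout as a hypothetical helper:

```python
import os


def check_voc_layout(voc_dir, img_ext='.jpg'):
    """Return (xml stems with no matching image, image stems with no matching xml)
    for a VOCdevkit_train-style directory containing Annotation/ and JPEGImages/."""
    ann_dir = os.path.join(voc_dir, 'Annotation')
    img_dir = os.path.join(voc_dir, 'JPEGImages')
    xml_stems = {os.path.splitext(f)[0] for f in os.listdir(ann_dir) if f.endswith('.xml')}
    img_stems = {os.path.splitext(f)[0] for f in os.listdir(img_dir) if f.endswith(img_ext)}
    return sorted(xml_stems - img_stems), sorted(img_stems - xml_stems)
```

Both returned lists should be empty before the tfrecord conversion is run.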
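1. Generating the xml files (batch-convert the txt annotations to xml; each object in an xml file stores the 4 corner points, i.e. 8 coordinate values)
I used data/io/ICDAR2015/txt2xml.py. The original txt2xml.py does not filter out hard-to-read targets whose label in the txt file is "##", "###" or similar. The paper "R2CNN: Rotational Region CNN for Orientation Robust Scene Text Detection" states: "As ICDAR 2015 training dataset contains difficult texts that is hard to detect which are labeled as "###", we only use those readable text for training." So my teammate and I made the following changes to txt2xml.py for training:
1) lines labeled "###" are filtered out;
2) coordinate values in the txt files that are not integers are handled: they are converted with int(float(...));
3) after the tfrecord file was built, training failed with "ValueError: attempt to get argmax of an empty sequence". A search suggested that the image aspect ratio (width/height) must stay within a reasonable range, so overly elongated images are filtered out.
The modified txt2xml.py is as follows: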

import os
import codecs
import xml.dom.minidom

import cv2
import numpy as np

def WriterXMLFiles(filename, path, box_list, labels, w, h, d):
    """Write one VOC-style xml file whose objects store quadrilaterals as x1,y1,...,x4,y4."""
    doc = xml.dom.minidom.Document()
    root = doc.createElement('annotation')
    doc.appendChild(root)

    foldername = doc.createElement("folder")
    foldername.appendChild(doc.createTextNode("JPEGImages"))
    root.appendChild(foldername)


    nodeFilename = doc.createElement('filename')
    nodeFilename.appendChild(doc.createTextNode(filename))
    root.appendChild(nodeFilename)

    pathname = doc.createElement("path")
    pathname.appendChild(doc.createTextNode("xxxx"))
    root.appendChild(pathname)

    sourcename=doc.createElement("source")

    databasename = doc.createElement("database")
    databasename.appendChild(doc.createTextNode("Unknown"))
    sourcename.appendChild(databasename)

    annotationname = doc.createElement("annotation")
    annotationname.appendChild(doc.createTextNode("xxx"))
    sourcename.appendChild(annotationname)

    imagename = doc.createElement("image")
    imagename.appendChild(doc.createTextNode("xxx"))
    sourcename.appendChild(imagename)

    flickridname = doc.createElement("flickrid")
    flickridname.appendChild(doc.createTextNode("0"))
    sourcename.appendChild(flickridname)

    root.appendChild(sourcename)

    nodesize = doc.createElement('size')
    nodewidth = doc.createElement('width')
    nodewidth.appendChild(doc.createTextNode(str(w)))
    nodesize.appendChild(nodewidth)
    nodeheight = doc.createElement('height')
    nodeheight.appendChild(doc.createTextNode(str(h)))
    nodesize.appendChild(nodeheight)
    nodedepth = doc.createElement('depth')
    nodedepth.appendChild(doc.createTextNode(str(d)))
    nodesize.appendChild(nodedepth)
    root.appendChild(nodesize)

    segname = doc.createElement("segmented")
    segname.appendChild(doc.createTextNode("0"))
    root.appendChild(segname)

    for box, label in zip(box_list, labels):

        nodeobject = doc.createElement('object')
        nodename = doc.createElement('name')
        nodename.appendChild(doc.createTextNode(label))
        nodeobject.appendChild(nodename)
        nodebndbox = doc.createElement('bndbox')
        nodex1 = doc.createElement('x1')
        nodex1.appendChild(doc.createTextNode(str(box[0])))
        nodebndbox.appendChild(nodex1)
        nodey1 = doc.createElement('y1')
        nodey1.appendChild(doc.createTextNode(str(box[1])))
        nodebndbox.appendChild(nodey1)
        nodex2 = doc.createElement('x2')
        nodex2.appendChild(doc.createTextNode(str(box[2])))
        nodebndbox.appendChild(nodex2)
        nodey2 = doc.createElement('y2')
        nodey2.appendChild(doc.createTextNode(str(box[3])))
        nodebndbox.appendChild(nodey2)
        nodex3 = doc.createElement('x3')
        nodex3.appendChild(doc.createTextNode(str(box[4])))
        nodebndbox.appendChild(nodex3)
        nodey3 = doc.createElement('y3')
        nodey3.appendChild(doc.createTextNode(str(box[5])))
        nodebndbox.appendChild(nodey3)
        nodex4 = doc.createElement('x4')
        nodex4.appendChild(doc.createTextNode(str(box[6])))
        nodebndbox.appendChild(nodex4)
        nodey4 = doc.createElement('y4')
        nodey4.appendChild(doc.createTextNode(str(box[7])))
        nodebndbox.appendChild(nodey4)

        # ang = doc.createElement('angle')
        # ang.appendChild(doc.createTextNode(str(angle)))
        # nodebndbox.appendChild(ang)
        nodeobject.appendChild(nodebndbox)
        root.appendChild(nodeobject)
    with open(os.path.join(path, filename), 'w') as fp:
        doc.writexml(fp, indent='\n')


def load_annoataion(txt_path):
    """Parse one ICDAR-style gt txt file.

    Each line is x1,y1,x2,y2,x3,y3,x4,y4,...,transcription. Lines whose
    transcription contains '#' (e.g. '###', unreadable text) are skipped.
    """
    boxes, labels = [], []
    with codecs.open(txt_path, 'r', 'utf-8') as fr:
        lines = fr.readlines()

    for line in lines:
        # drop the newline and a possible UTF-8 BOM, then split on commas
        fields = line.strip().strip('\ufeff').strip('$').split(',')
        if len(fields) < 8:
            continue  # not a quadrilateral line
        label_name = fields[-1]
        if '#' in label_name:
            continue  # difficult / unreadable text
        try:
            # float -> int also accepts coordinates written like '12.0'
            box = [int(float(item)) for item in fields[:8]]
        except ValueError:
            continue  # non-numeric coordinates
        boxes.append(box)
        labels.append('text')

    return np.array(boxes), np.array(labels)


if __name__ == "__main__":
    txt_path = '/home/trainingai/zyang/R2CNN_Faster-RCNN_Tensorflow/VOCdevkit/meituan/txt'
    xml_path = '/home/trainingai/zyang/R2CNN_Faster-RCNN_Tensorflow/VOCdevkit/meituan/xml'
    img_path = '/home/trainingai/zyang/R2CNN_Faster-RCNN_Tensorflow/VOCdevkit/meituan/useful_image'
    # images that pass the filters are written here
    out_img_path = '/home/trainingai/zyang/R2CNN_Faster-RCNN_Tensorflow/VOCdevkit/meituan/voc_image'
    assert os.path.exists(txt_path), txt_path
    for folder in (xml_path, out_img_path):
        os.makedirs(folder, exist_ok=True)

    for count, t in enumerate(os.listdir(txt_path)):
        if count % 1000 == 0:
            print(count)
        # ICDAR gt files are named gt_<image>.txt; the xml and the image share the name without 'gt_'
        stem = t.split('gt_')[-1]
        xml_name = stem.replace('.txt', '.xml')
        img_name = stem.replace('.txt', '.jpg')

        boxes, labels = load_annoataion(os.path.join(txt_path, t))
        if len(boxes) == 0:
            print("skip (no readable text):", t)
            continue
        img = cv2.imread(os.path.join(img_path, img_name))
        if img is None:
            print("skip (image not found):", img_name)
            continue
        h, w, d = img.shape
        # 0.462-6.828
        # drop overly elongated images; they caused
        # "ValueError: attempt to get argmax of an empty sequence" during training
        if 0.8 < w / h < 6.828:
            WriterXMLFiles(xml_name, xml_path, boxes, labels, w, h, d)
            # save the image under the same stem as its xml file
            cv2.imwrite(os.path.join(out_img_path, img_name), img)
        else:
            print("skip (aspect ratio out of range):", img_name)
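To spot-check the output, a generated xml file can be read back with the standard library's xml.etree.ElementTree. The sketch below uses a hypothetical helper, read_quad_xml, and assumes exactly the tag layout WriterXMLFiles writes above (size/width, size/height, object/name and bndbox/x1 ... y4):

```python
import xml.etree.ElementTree as ET

COORD_TAGS = ['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4']


def read_quad_xml(xml_text):
    """Parse xml text written by WriterXMLFiles.

    Returns (width, height, [(label, [x1, y1, ..., x4, y4]), ...]).
    """
    root = ET.fromstring(xml_text)
    size = root.find('size')
    w, h = int(size.find('width').text), int(size.find('height').text)
    objects = []
    for obj in root.iter('object'):
        bnd = obj.find('bndbox')
        coords = [int(bnd.find(tag).text) for tag in COORD_TAGS]
        objects.append((obj.find('name').text, coords))
    return w, h, objects
```

Reading a file is then just read_quad_xml(open(path).read()); comparing the coordinates against w and h is a cheap way to catch boxes that fall outside the image.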

2. Generate the tfrecord files
1) Edit /R2CNN_Faster-RCNN_Tensorflow/libs/configs/cfgs.py:
line 8: VERSION = 'RRPN_ICDAR2017_v1'
line 9: NET_NAME = 'resnet_v1_101' (101 is the version used here)
line 64: DATASET_NAME = 'ICDAR2017'
set CLASS_NUM = 1 (text detection has only one label: text)
2) Add the dataset to /libs/label_name_dict/label_dict.py:

elif cfgs.DATASET_NAME == 'ICDAR2017':
    NAME_LABEL_MAP = {
        'background': 0,
        'text': 1
    }

3) In R2CNN_Faster-RCNN_Tensorflow/data/io/convert_data_to_tfrecord.py, change the paths and names on lines 13-20 as follows:

ROOT_PATH = '/home/trainingai/zyang/R2CNN_Faster-RCNN_Tensorflow'

tf.app.flags.DEFINE_string('VOC_dir', '/home/trainingai/zyang/R2CNN_Faster-RCNN_Tensorflow/VOCdevkit/VOCdevkit_train/', 'Voc dir')
tf.app.flags.DEFINE_string('xml_dir', '/home/trainingai/zyang/R2CNN_Faster-RCNN_Tensorflow/VOCdevkit/VOCdevkit_train/Annotation', 'xml dir')
tf.app.flags.DEFINE_string('image_dir', '/home/trainingai/zyang/R2CNN_Faster-RCNN_Tensorflow/VOCdevkit/VOCdevkit_train/JPEGImages', 'image dir')
tf.app.flags.DEFINE_string('save_name', 'train', 'save name')
tf.app.flags.DEFINE_string('save_dir', ROOT_PATH + '/data/tfrecords/', 'save name')
tf.app.flags.DEFINE_string('img_format', '.jpg', 'format of image')
tf.app.flags.DEFINE_string('dataset', 'ICDAR2017', 'dataset')
FLAGS = tf.app.flags.FLAGS

4) In R2CNN_Faster-RCNN_Tensorflow/data/io/read_tfrecord.py, add your own dataset name to dataset_name at line 75.
After making these changes, cd to $R2CNN_Faster-RCNN_Tensorflow/data/io/ and run:
python convert_data_to_tfrecord.py --VOC_dir='***/VOCdevkit/VOCdevkit_train/' --save_name='train' --img_format='.jpg' --dataset='ICDAR2017'
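Steps 1) and 2) must agree with each other: CLASS_NUM counts only the foreground classes, while NAME_LABEL_MAP also contains background, so CLASS_NUM should be one less than the number of map entries. A small stand-alone sketch of that check; LABEL_NAME_MAP here is an illustrative reverse mapping, not necessarily a name the repo defines:

```python
# stand-alone copy of the ICDAR2017 entry added to label_dict.py
NAME_LABEL_MAP = {
    'background': 0,
    'text': 1,
}

# reverse lookup (label id -> class name), handy when printing detections
LABEL_NAME_MAP = {label: name for name, label in NAME_LABEL_MAP.items()}

# value set in cfgs.py; background is not counted
CLASS_NUM = 1
assert CLASS_NUM == len(NAME_LABEL_MAP) - 1, 'CLASS_NUM in cfgs.py does not match label_dict.py'
```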
The tfrecord file is now generated; expect it to be very large.
III. Train your own model
Edit cfgs.py; the hyperparameters I used are:
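A TFRecord file is a sequence of records, each framed as an 8-byte little-endian length, a 4-byte CRC of that length, the payload, and a 4-byte CRC of the payload. That makes it possible to count the examples without TensorFlow, for instance to confirm the count matches the number of xml files converted (assuming one example per xml). A minimal sketch with a hypothetical helper that walks the framing but does not verify the CRCs:

```python
import struct


def count_tfrecords(path):
    """Count records in a TFRecord file by walking its length-prefixed framing.

    CRCs are skipped, not checked, so a corrupted payload is not detected.
    """
    count = 0
    with open(path, 'rb') as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError('truncated record header')
            (length,) = struct.unpack('<Q', header)
            f.seek(4 + length + 4, 1)  # skip length CRC, payload, payload CRC
            count += 1
    return count
```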

EPSILON = 1e-4
MOMENTUM = 0.9
LR = 0.0001
DECAY_STEP = [190000, 380000]  # 90000, 120000
MAX_ITERATION = 500000

# -------------------------------------------- Data_preprocess_config
DATASET_NAME = 'ICDAR2017'  # 'ship', 'spacenet', 'pascal', 'coco'
PIXEL_MEAN = [123.68, 116.779, 103.939]  # R, G, B. In tf, channel is RGB. In openCV, channel is BGR
IMG_SHORT_SIDE_LEN = 720
IMG_MAX_LENGTH = 1280
CLASS_NUM = 1

# --------------------------------------------- Network_config
BATCH_SIZE = 1
INITIALIZER = tf.random_normal_initializer(mean=0.0, stddev=0.01)
BBOX_INITIALIZER = tf.random_normal_initializer(mean=0.0, stddev=0.001)
WEIGHT_DECAY = 0.0001

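The initial learning rate is LR = 0.0001: with 0.001 the total loss was nan throughout training, because the initial learning rate was too large. MAX_ITERATION = 500000 sets the number of training steps; choose a value that suits your own data. The image size, IMG_SHORT_SIDE_LEN = 720 and IMG_MAX_LENGTH = 1280, matches the paper. BATCH_SIZE = 1, because the model currently supports only a batch size of 1.
After these changes, cd to $R2CNN_Faster-RCNN_Tensorflow/tools and run python train.py.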
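LR and DECAY_STEP together define a step schedule. The sketch below shows how the rate evolves over the 500000 iterations, assuming it is divided by 10 after each DECAY_STEP boundary (the factor of 10 is an assumption; check how train.py builds its learning rate). learning_rate is a hypothetical helper:

```python
LR = 0.0001
DECAY_STEP = [190000, 380000]
MAX_ITERATION = 500000


def learning_rate(step, base_lr=LR, boundaries=DECAY_STEP, factor=10.0):
    """Piecewise-constant schedule: base_lr up to the first boundary,
    then divided by `factor` once more after each boundary is passed."""
    lr = base_lr
    for b in boundaries:
        if step > b:
            lr /= factor
    return lr


if __name__ == '__main__':
    for step in (0, 190000, 190001, 380001, MAX_ITERATION):
        print(step, learning_rate(step))
```

Under this assumption, only the first 190000 steps run at 1e-4, which is one reason a too-large starting value (such as 0.001, which gave nan above) matters so much.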
Training progress:
[Figure: screenshot of the training log]
Then comes the long wait for training to finish...

IV. Test with your trained model
In cfgs.py, point the pre-trained checkpoint path at your own trained model, as follows:

# PRETRAINED_CKPT = ROOT_PATH + '/data/pretrained_weights/' + weights_name + '.ckpt'                      # train
# PRETRAINED_CKPT = ROOT_PATH + '/data/pretrained_weights/our_pretrained_weights/' + weights_name + '.ckpt'         # train
PRETRAINED_CKPT = ROOT_PATH + '/output/trained_weights/RRPN_ICDAR2017_v1_19.4w/' + weights_name + '.ckpt'        # test
TRAINED_CKPT = os.path.join(ROOT_PATH, 'output/summary')
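PRETRAINED_CKPT must be a checkpoint prefix (the path ending in .ckpt), not one of the .index/.meta/.data files TensorFlow writes next to it. The sketch below picks the newest prefix in a trained_weights folder. It assumes, hypothetically, that file names embed the global step as digits directly before "model.ckpt" (e.g. something like ICDAR2017_194000model.ckpt.index); adjust the regular expression to your actual file names. latest_checkpoint_prefix is a hypothetical helper.

```python
import os
import re

# assumed naming: <anything><step digits>model.ckpt.index
PATTERN = re.compile(r'(\d+)model\.ckpt\.index$')


def latest_checkpoint_prefix(ckpt_dir):
    """Return the checkpoint prefix with the largest step number, or None if nothing matches."""
    best_step, best_prefix = -1, None
    for name in os.listdir(ckpt_dir):
        m = PATTERN.search(name)
        if m and int(m.group(1)) > best_step:
            best_step = int(m.group(1))
            best_prefix = os.path.join(ckpt_dir, name[:-len('.index')])
    return best_prefix
```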

cd to R2CNN_Faster-RCNN_Tensorflow/tools and run:
python inference.py --data_dir='R2CNN_Faster-RCNN_Tensorflow/tools/inference_image/' --gpu='0'
This completes every stage: environment setup, dataset preparation, model training, and testing with the trained model.
