github faster-rcnn windows项目:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3.5
按照项目说明,配置项目。过程中会遇到问题:
pil 库无法安装,因为当前python版本为3.5,pil库最多支持到2.7,pip无法安装,conda安装会覆盖当前python版本为2.7 切记!!解决方法:32.7版本后直接下载pillow库即可。
博主环境:python3.5,tensorflow1.13-gpu环境,
VOCDevkit2007 文件夹为当前项目使用到的数据
Annotations中的xml为每张图片的标记数据,分类名,检测框坐标。文件名对于图片名。
JPEGImages 为所有图片原数据
ImageSets\Main\trainval.txt 为图片名 索引文件。数据加载 图片名从这个文件中读取。
运行train.py文件进行 训练。config/config.py为配置文件模型保存和batch_size 学习率等都保存在这。模型输出文件夹在default\voc_2007_trainval\default,训练完毕需要将模型文件复制到\output\vgg16\voc_2007_trainval+voc_2012_trainval\default下,demo运行使用模型
demo.py文件运行,直接运行报错,找不到模型文件。demo中用到的数据在data/demo/下。
错误解决,修改demo文件:
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_40000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
代码中根据网络模式 将模型名修改成自己的模型名,迭代训练次数不一样,模型名不一样。
parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]', choices=NETS.keys(), default='vgg16')
default 修改成vgg16。
其他错误,一些缺包的错误,安装包即可。一些无法安装的包,一般都是替代包。跟pil一样。仔细看说明,错误提示中会有提示替代包
到此,整个原生faster-rcnn 已经算是完成了。
WIDER FACE 图片库,下载地址http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/
这三个文件需要下载。训练数据和验证数据annotations是标记数据。测试数据看自己的需求。
为了方便faster rcnn做训练,必须先把wider face数据库转为voc2007格式,也就是和原生数据一样的格式最后会附代码。直接转换。
将data\VOCDevkit2007\VOC2007 下所有文件删除
下载的数据解压到data\VOCDevkit2007\Wider_face\文件夹下
新建数据格式转换文件
文件代码在最后。
运行脚本缺包需要安装,需要等待一段时间,复制图片,制作标签等。
脚本运行完毕data\VOCDevkit2007\Wider_face 文件夹下得到三个文件夹就是voc格式
将三个文件夹移动到\data\VOCDevkit2007\VOC2007文件夹下
进入到data\VOCDevkit2007\VOC2007\ImageSets\Main下
var.txt 文件中的内容复制到train.txt中 。修改train.txt文件名为:trainval.txt
到此数据已完全符合训练数据格式。
修改lib/datasets/pascal_voc.py 中self._classes,添加自己的分类比如,face,
开始训练,如果还是报错 可尝试删除data/cache文件夹内容再次运行。
/demo.py 文件中修改 CLASSES 添加分类 face
修改分类数量:net.create_architecture(sess, “TEST”, 22,
完毕!!!
解决方法:报错原因是fg_inds和bg_inds的数量都小于0,这张图片没办法训练了,所以直接跳过这张图。办法是调整config.py里的roi_bg_threshold_high和roi_bg_threshold_low,一般把roi_bg_threshold_low改成0.0就不会出现这个问题。
解决方法:打开lib/database/pascal_voc.py文件,每一行后面的-1删除。原因是因为我们制作的xml文件中有些框的坐标是从左上角开始的,也就是(0,0)如果再减一就会出现log(-1)的情况。
改完之后就不会出现RuntimeWarning: invalid value encountered in log targets_dw = np.log(gt_widths / ex_widths)这个问题了,loss也不会出现等于nan了,如果还出现loss=nan,可以再试试调小学习率以及各个损失项的占比重。亲测有效。。
"""
Created on 19-4-18
@author: 段大帅
"""
from skimage import io
import shutil
import random
import os
import string
headstr = """\
VOC2007
%06d.jpg
NULL
company
%d
%d
%d
0
"""
objstr = """\
"""
tailstr = '''\
'''
def all_path(filename):
return os.path.join('Wider_face', filename)
def writexml(idx, head, bbxes, tail):
filename = all_path("Annotations/%06d.xml" % (idx))
f = open(filename, "w")
f.write(head)
for bbx in bbxes:
f.write(objstr % ('face', bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]))
f.write(tail)
f.close()
def clear_dir():
if shutil.os.path.exists(all_path('Annotations')):
shutil.rmtree(all_path('Annotations'))
if shutil.os.path.exists(all_path('ImageSets')):
shutil.rmtree(all_path('ImageSets'))
if shutil.os.path.exists(all_path('JPEGImages')):
shutil.rmtree(all_path('JPEGImages'))
shutil.os.mkdir(all_path('Annotations'))
shutil.os.makedirs(all_path('ImageSets/Main'))
shutil.os.mkdir(all_path('JPEGImages'))
def excute_datasets(idx, datatype):
f = open(all_path('ImageSets/Main/' + datatype + '.txt'), 'a')
f_bbx = open(all_path('wider_face_split/wider_face_' + datatype + '_bbx_gt.txt'), 'r')
while True:
filename = f_bbx.readline().strip('\n')
if not filename:
break
try:
im = io.imread(all_path('WIDER_' + datatype + '/images/'+filename))
except IOError:
print('错误文件名已跳过,',filename)
continue
head = headstr % (idx, im.shape[1], im.shape[0], im.shape[2])
nums = f_bbx.readline().strip('\n')
bbxes = []
for ind in range(int(nums)):
bbx_info = f_bbx.readline().strip(' \n').split(' ')
bbx = [int(bbx_info[i]) for i in range(len(bbx_info))]
#x1, y1, w, h, blur, expression, illumination, invalid, occlusion, pose
if bbx[7]==0:
bbxes.append(bbx)
writexml(idx, head, bbxes, tailstr)
shutil.copyfile(all_path('WIDER_' + datatype + '/images/'+filename), all_path('JPEGImages/%06d.jpg' % (idx)))
f.write('%06d\n' % (idx))
idx +=1
f.close()
f_bbx.close()
return idx
# 打乱样本
def shuffle_file(filename):
f = open(filename, 'r+')
lines = f.readlines()
random.shuffle(lines)
f.seek(0)
f.truncate()
f.writelines(lines)
f.close()
if __name__ == '__main__':
clear_dir()
idx = 1
idx = excute_datasets(idx, 'train')
idx = excute_datasets(idx, 'val')