源码地址:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3.5
在...\Faster-RCNN\data目录下,将VOCDevkit2007替换成自己的数据集。
我的数据集是VOC形式的,但是没有ImageSets/Main,根据下面的代码生成就好了
ImageSets:
这个文件夹下面有三个文件夹,主要用到的就是Main文件,其余的两个文件我们不用管它,打开Main文件可以看到有好多txt文件,那么在制作我们自己的Main时候真正用到的文件只有4个:
test.txt:测试集 train.txt:训练集 val.txt:验证集 trainval.txt:训练和验证集
用python代码生成:
import os
import random
trainval_percent = 0.66
train_percent = 0.5
xmlfilepath = 'Annotations' # 绝对路径
txtsavepath = 'ImageSets\Main' # 生成的四个文件的存储路径
total_xml = os.listdir(xmlfilepath)
num=len(total_xml)
list=range(num)
tv=int(num*trainval_percent)
tr=int(tv*train_percent)
trainval= random.sample(list,tv)
train=random.sample(trainval,tr)
ftrainval = open('ImageSets/Main/trainval.txt', 'w')
ftest = open('ImageSets/Main/test.txt', 'w')
ftrain = open('ImageSets/Main/train.txt', 'w')
fval = open('ImageSets/Main/val.txt', 'w')
for i in list:
name=total_xml[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest .close()
修改代码
1.lib/datasets/pascal_voc 33行,将下列代码注释掉:
# self._classes = ('__background__', # always index 0
# 'aeroplane', 'bicycle', 'bird', 'boat',
# 'bottle', 'bus', 'car', 'cat', 'chair',
# 'cow', 'diningtable', 'dog', 'horse',
# 'motorbike', 'person', 'pottedplant',
# 'sheep', 'sofa', 'train', 'tvmonitor')
除第一行的 '__background__' (背景)保留外,剩下的根据Annotation中的标签名字更改,如果Annotations中的标签名字首字母是大写,在self._classes中也要是小写。
2.在...\Faster-RCNN\data目录下,检查是否有个叫cache的文件夹,每次在训练模型前,建议清空这个文件夹里面的东西。
运行 train.py
运行成功后跑了几百步出现了下面的问题:
error:
1.
image invalid, skipping
Traceback (most recent call last):
File "C:\Users\think\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\script_ops.py", line 157, in __call__
ret = func(*args)
File "D:\Faster-RCNN-TensorFlow-Python3.5-master\lib\layer_utils\proposal_target_layer.py", line 47, in proposal_target_layer
rois_per_image, _num_classes)
File "D:\Faster-RCNN-TensorFlow-Python3.5-master\lib\layer_utils\proposal_target_layer.py", line 135, in _sample_rois
raise Exception()
Exception
image invalid, skipping
修改:
将lib/config/config.py里的86行的 roi_bg_threshold_low 的 0.1 改成 0.0
2.再运行了几百步后出现下面的错误:
RuntimeWarning: invalid value encountered in log targets_dw = np.log(gt_widths / ex_widths)
loss = nan
修改:
lib/datasets/pascal_voc.py文件,找到168行,将168行至171行每一行后面的-1删除,如下所示:
x1 = float(bbox.find(‘xmin’).text)
y1 = float(bbox.find(‘ymin’).text)
x2 = float(bbox.find(‘xmax’).text)
y2 = float(bbox.find(‘ymax’).text)
lib/datasets/imdb.py文件,找到117行,将117行和118行后面的 -1删除,如下所示:
boxes[:, 0] = widths[i] - oldx2
boxes[:, 2] = widths[i] - oldx1
就可以成功运行啦。
测试:
更改
1.demo.py 32行 class的种类,和lib/datasets/pascal_voc 中更改的一样
2.demo.py 156行
net.create_architecture(sess, "TEST", 2, # 种类需要更改 tag='default', anchor_scales=[8, 16, 32])
修改种类数+1(背景)
3.再次修改自己数据集训练后的模型的参数。
# 读取训练模型的参数 tfmodel = r'E:\Faster-RCNN\default\voc_2007_trainval\default\vgg16_faster_rcnn_iter_40000.ckpt'
参考:
https://blog.csdn.net/Christine__xu/article/details/89297070
https://blog.csdn.net/ksws0292756/article/details/80702704