我使用的代码地址是: Faster-rcnn.
代码结构如图:
准备数据集以及训练模型的过程网上教程很多,这里记录我在检测训练好的模型过程中遇到的问题。
原代码中未保存检测后的图片以及坐标值、置信度等,因此我修改了demo.py。
原代码:
CLASSES = ('__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
修改后:
CLASSES = ('__background__',
'car')
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_135000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
我用的是KITTI数据集,只检测车辆,因此修改类别,同时要修改选择训练过程中保存的权重,我用的backbone是vgg16,故只修改vgg16那块的权重即可。
原代码:
def vis_detections(im, class_name, dets, thresh=0.5):
"""Draw detected bounding boxes."""
inds = np.where(dets[:, -1] >= thresh)[0]
if len(inds) == 0:
return
im = im[:, :, (2, 1, 0)]
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(im, aspect='equal')
for i in inds:
bbox = dets[i, :4]
score = dets[i, -1]
ax.add_patch(
plt.Rectangle((bbox[0], bbox[1]),
bbox[2] - bbox[0],
bbox[3] - bbox[1], fill=False,
edgecolor='red', linewidth=3.5)
)
ax.text(bbox[0], bbox[1] - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor='blue', alpha=0.5),
fontsize=14, color='white')
ax.set_title(('{} detections with '
'p({} | box) >= {:.1f}').format(class_name, class_name,
thresh),
fontsize=14)
plt.axis('off')
plt.tight_layout()
plt.draw()
修改后:
def vis_detections(im, class_name, dets, image_name,thresh=0.5):
"""Draw detected bounding boxes."""
inds = np.where(dets[:, -1] >= thresh)[0]
if len(inds) == 0:
return
im = im[:, :, (2, 1, 0)]
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(im, aspect='equal')
for i in inds:
bbox = dets[i, :4]
score = dets[i, -1]
ax.add_patch(
plt.Rectangle((bbox[0], bbox[1]),
bbox[2] - bbox[0],
bbox[3] - bbox[1], fill=False,
edgecolor='red', linewidth=3.5)
)
ax.text(bbox[0], bbox[1] - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor='blue', alpha=0.5),
fontsize=14, color='white')
if(class_name == '__background__'): # 添加了保存坐标和置信度的代码
fw = open('./data/demo/test/result.txt','a') #保存结果的文件,下同
fw.write(str(image_name)+' '+class_name+' '+str(float(score))+' '+str(int(bbox[0]))+' '+str(int(bbox[1]))+' '+str(int(bbox[2]))+' '+str(int(bbox[3]))+'\n')
fw.close()
elif(class_name == 'car'): # 有多少类别就添加多少个
fw = open('./data/demo/test/result.txt','a')
fw.write(str(image_name)+' '+class_name+' '+str(float(score))+' 'str(int(bbox[0]))+' '+str(int(bbox[1]))+' '+str(int(bbox[2]))+' '+str(int(bbox[3]))+'\n')
fw.close()
ax.set_title(('{} detections with '
'p({} | box) >= {:.1f}').format(class_name, class_name,
thresh),
fontsize=14)
plt.axis('off')
plt.tight_layout()
plt.draw()
保存的txt文件最好提前创建,否则会报错,也可以改为自动生成(这里我懒了,没有改)。
原代码:
im_names = ['000456.jpg', '000457.jpg', '000542.jpg', '001150.jpg',
'001763.jpg', '004545.jpg']
for im_name in im_names:
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
print('Demo for data/demo/{}'.format(im_name))
demo(sess, net, im_name)
plt.show()
需要修改原代码中最后几行,它把检测的图片全都列出来了,我的测试图片比较多,所以写成了文件夹读取,原代码中并没有保存检测后的图片,而是将它show出来。我是在linux服务器跑的程序,不能可视化,所以将检测后的图片都保存了。
修改后的代码:
im_names = os.listdir('测试图片的路径,可自己添加')
fr = open('./data/demo/test/test1.txt','r') # 这一句本来想将测试的图片名写在txt文件里,但貌似没生效
for im_name in im_names:
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
print('Demo for data/demo/{}'.format(im_name))
demo(sess, net, im_name)
plt.savefig(im_name) # 保存检测后的图片,路径可自行指定,这是直接保存在根目录下,im_name是图片名
# plt.show() # linux服务器无法进行可视化展示,因此注释掉。
注意: matplotlib的可视化无法显示时,在程序最开始添加
import matplotlib
matplotlib.use('Agg')
否则直接将show注释可能会报错
全部修改完毕后,成功运行界面:
检测效果:
保存的txt文件:
从左到右分别是: 图片名称、类别、置信度、检测出的四个坐标
以上是我在检测过程中遇到的问题,如有不足欢迎讨论~~~~~加油呀~