现在用YOLO v2加载训练好的COCO数据集权重模型进行图片目标预测,有关细节部分就不赘述了。
这里放上YOLO v2的论文地址:
YOLO v2论文地址:https://arxiv.org/pdf/1612.08242.pdf
pip install opencv-python==3.4.9.31 -i https://pypi.tuna.tsinghua.edu.cn/simple
清华镜像安装TensorFlow:(这里强烈建议大家安装CPU版本的TensorFlow,GPU版本需要提前配置好cuda和cudnn,CPU版本的用于图片目标检测足够了)
安装CPU版本:
pip install tensorflow==1.8.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
安装GPU版本的TensorFlow:
pip install tensorflow-gpu==1.8.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
文件夹里面文件分布如下图所示:
其中Main.py就是用于图像目标检测的程序,在该程序中,需要修改相应的读取图片路径、输出路径和模型路径。
第30行修改读入图片的路径,这里要注意最好用图片的绝对路径
image_file = '1.jpg' #读取输入图片
第55行修改图片的输出保存路径
cv2.imwrite("detection_result.jpg", img_detection)
第44行是读取模型权重文件的路径
model_path = "./yolo2_model/yolo2_coco.ckpt"
下面是完整的Main.py中的代码:
# %load Main.py
# --------------------------------------
# @Time : 2018/5/16$ 17:17$
# @Author : KOD Chen
# @Email : [email protected]
# @File : Main$.py
# Description :YOLO_v2主函数.
# --------------------------------------
import numpy as np
import tensorflow as tf
import cv2,os
from PIL import Image
import matplotlib.pyplot as plt
from model_darknet19 import darknet
from decode import decode
from utils import preprocess_image, postprocess, draw_detection
from config import anchors, class_names
#%matplotlib inline
def main():
input_size = (416,416)
image_file = '1.jpg' #读取输入图片
image = cv2.imread(image_file)
image_shape = image.shape[:2] #只取wh,channel=3不取
# copy、resize416*416、归一化、在第0维增加存放batchsize维度
image_cp = preprocess_image(image,input_size)
# 【1】输入图片进入darknet19网络得到特征图,并进行解码得到:xmin xmax表示的边界框、置信度、类别概率
tf_image = tf.placeholder(tf.float32,[1,input_size[0],input_size[1],3])
model_output = darknet(tf_image) # darknet19网络输出的特征图
output_sizes = input_size[0]//32, input_size[1]//32 # 特征图尺寸是图片下采样32倍
output_decoded = decode(model_output=model_output,output_sizes=output_sizes,
num_class=len(class_names),anchors=anchors) # 解码
model_path = "./yolo2_model/yolo2_coco.ckpt"
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess,model_path)
bboxes,obj_probs,class_probs = sess.run(output_decoded,feed_dict={
tf_image:image_cp})
# 【2】筛选解码后的回归边界框——NMS(post process后期处理)
bboxes,scores,class_max_index = postprocess(bboxes,obj_probs,class_probs,image_shape=image_shape)
# 【3】绘制筛选后的边界框
img_detection = draw_detection(image, bboxes, scores, class_max_index, class_names)
cv2.imwrite("detection_result.jpg", img_detection)
img_detection = cv2.cvtColor(img_detection, cv2.COLOR_RGB2BGR)
plt.figure(figsize=(10,10))
plt.imshow(img_detection) #界面显示
#print('YOLO_v2 detection has done!')
print('YOLO_v2 检测完成!')
#cv2.imshow("detection_results", img_detection)
#cv2.waitKey(0)
plt.show()
if __name__ == '__main__':
main()
运行程序,就可以对图片进行目标检测了。更换图片路径,可以对不同的图片进行目标检测了。