YOLO:You Only Look Once(你只需看一次)
Detection as Regression
把系统将输入图像分割成S×S网格。
如果一个物体的中心落在一个网格单元中,该网格单元负责检测该物体。每个网格单元格预测B bounding boxes和这些框的confidence scores。这些置信度评分(confidence scores)反映了模型对boxes中包含对象的置信度,以及模型对盒子的预测准确度。我们 定义 confidence 为Pr(Object) ∗ IOUtruth pred 。如果该单元格中不存在对象,则置信度得分应该为零。否则,将置信度得分等于预测框与范围真实值之间的联合(IOU)的交集。
每个边界框由5个预测组成:x、y、w、h和置信度。(x,y坐标表示框相对于网格单元边界的中心。宽度和高度相对于整个图像进行预测。最后,置信预测表示预测框和任何地面真值框之间的IOU。
每个网格单元还预测C条件类概率Pr(Classi对象)。这些概率是在包含对象的网格单元上确定的。只预测每个网格单元的一组类概率,而不考虑框B的数量。在测试时,将条件类概率与单个框的置信度预测相乘,得到每个框的类特定置信度分数。这些分数编码了类出现在框中的概率以及预测框与对象的匹配程度。
对于PASCAL VOC上的YOLO,我们使用S=7,B=2。PASCAL VOC有20个标记类,所以C=20。
最后的预测是7×7×30张量。
调用yolo的方法
1 基于AIstudio 的PaddleDetection使用GPU训练并预测,看另一篇blog
https://blog.csdn.net/shuzip/article/details/103478636
2 使用Darknet的方法
直接调用COCO(Common Objects in Context)数据集上预训练好的模型和权重。
darknet官方网站:https://pjreddie.com/darknet/
import cv2
import matplotlib.pyplot as plt
from utils import *
from darknet import Darknet
# 指定cfg文件的位置,这个文件包含了模型的结构
cfg_file = './cfg/yolov3.cfg'
# 指定weights文件的位置,这个文件包含了模型的权重
weight_file = './weights/yolov3.weights'
# 指定COCO数据集类标签
namesfile = 'data/coco.names'
# 载入模型结构
m = Darknet(cfg_file)
# 载入模型权重
m.load_weights(weight_file)
# 载入COCO类标签
class_names = load_class_names(namesfile)
m.print_network()#打印权重
# 指定绘图大小
plt.rcParams['figure.figsize'] = [24.0, 14.0]
#载入图像
img = cv2.imread('./images/dog.jpg')
# 转换为RGB颜色通道
original_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# 缩放图像为模型要求的输入大小
resized_image = cv2.resize(original_image, (m.width, m.height))
# 绘制图像
plt.subplot(121)
plt.title('Original Image')
plt.imshow(original_image)
plt.subplot(122)
plt.title('Resized Image')
plt.imshow(resized_image)
plt.show()
nms_thresh = 0.6 #设置非极大值抑制NMS的阈值
iou_thresh = 0.4 # 设置交并比IOU阈值
# 设置图像大小
plt.rcParams['figure.figsize'] = [24.0, 14.0]
# 载入图像
img = cv2.imread('./images/dog.jpg')
# 转为RGB颜色通道
original_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# 缩放图像为模型输入维度
resized_image = cv2.resize(original_image, (m.width, m.height))
# 设置IOU阈值
iou_thresh = 0.4
# 设置NMS阈值
nms_thresh = 0.6
# 检测图像中的物体
boxes = detect_objects(m, resized_image, iou_thresh, nms_thresh)
# 输出检测到的物体及置信度
print_objects(boxes, class_names)
# 可视化图像、框,及分类结果
plot_boxes(original_image, boxes, class_names, plot_labels = True)