YOLO中anchor box的作用(面试必考)

在网络最后的输出中,对于每个grid cell产生3个bounding box,每个bounding box的输出有三类参数:一个是对象的box参数,一共是四个值,即box的中心点坐标(x,y)和box的宽和高(w,h);一个是置信度,这是个区间在[0,1]之间的值;最后一个是一组条件类别概率,都是区间在[0,1]之间的值,代表概率。

假如一个图片被分割成S∗SS*SS∗S个grid cell,我们有B个anchor box,也就是说每个grid cell有B个bounding box, 每个bounding box内有4个位置参数,1个置信度,classes个类别概率,那么最终的输出维数是:S∗S∗[B∗(4+1+classes)]S*S*[B*(4 + 1 + classes)]S∗S∗[B∗(4+1+classes)]。

下面分别具体介绍这三个参数的意义。

1. anchor box

1.1 对anchor box的理解

anchor box其实就是从训练集的所有ground truth box中统计(使用k-means)出来的在训练集中最经常出现的几个box形状和尺寸。比如,在某个训练集中最常出现的box形状有扁长的、瘦高的和宽高比例差不多的正方形这三种形状。我们可以预先将这些统计上的先验(或来自人类的)经验加入到模型中,这样模型在学习的时候,瞎找的可能性就更小了些,当然就有助于模型快速收敛了。以前面提到的训练数据集中的ground truth box最常出现的三个形状为例,当模型在训练的时候我们可以告诉它,你要在grid cell 1附件找出的对象的形状要么是扁长的、要么是瘦高的、要么是长高比例差不多的正方形,你就不要再瞎试其他的形状了。anchor box其实就是对预测的对象范围进行约束,并加入了尺寸先验经验,从而实现多尺度学习的目的。

量化anchor box

要在模型中使用这些形状,总不能告诉模型有个形状是瘦高的,还有一个是矮胖的,我们需要量化这些形状。YOLO的做法是想办法找出分别代表这些形状的宽和高,有了宽和高,尺寸比例即形状不就有了。YOLO作者的办法是使用k-means算法在训练集中所有样本的ground truth box中聚类出具有代表性形状的宽和高,作者将这种方法称作维度聚类(dimension cluster)。细心的读者可能会提出这个问题:到底找出几个anchor box算是最佳的具有代表性的形状。YOLO作者方法是做实验,聚类出多个数量不同anchor box组,分别应用到模型中,最终找出最优的在模型的复杂度和高召回率(high recall)之间折中的那组anchor box。作者在COCO数据集中使用了9个anchor box,我们前面提到的例子则有3个anchor box。

怎么在实际的模型中加入anchor box的先验经验呢?

最终负责预测grid cell中对象的box的最小单元是bounding box,那我们可以让一个grid cell输出(预测)多个bounding box,然后每个bounding box负责预测不同的形状不就行了?比如前面例子中的3个不同形状的anchor box,我们的一个grid cell会输出3个参数相同的bounding box,第一个bounding box负责预测的形状与anchor box 1类似的box,其他两个bounding box依次类推。作者在YOLOv3中取消了v2之前每个grid cell只负责预测一个对象的限制,也就是说grid cell中的三个bounding box都可以预测对象,当然他们应该对应不同的ground truth。那么如何在训练中确定哪个bounding box负责某个ground truth呢?方法是求出每个grid cell中每个anchor box与ground truth box的IOU(交并比),IOU最大的anchor box对应的bounding box就负责预测该ground truth,也就是对应的对象,后面还会提到负责预测的问题。

怎么告诉模型第一个bounding box负责预测的形状与anchor box 1类似,第二个bounding box负责预测的形状与anchor box 2类似?

YOLO的做法是不让bounding box直接预测实际box的宽和高(w,h),而是将预测的宽和高分别与anchor box的宽和高绑定,这样不管一开始bounding box输出的(w,h)是怎样的,经过转化后都是与anchor box的宽和高相关,这样经过很多次惩罚训练后,每个bounding box就知道自己该负责怎样形状的box预测了。这个绑定的关系是什么?就涉及到了anchor box的计算。

1.2 anchor box的计算

前提需要知道,

cxc_{x}cx​和cyc_{y}cy​的坐标是(0,0) (0,1),(0,2),(0,3)…(0,13)

(1,0),(1,1),(1,2),(1,3)…(1,13)等等

bouding box的输出应当为:txt_{x}tx​和tyt_{y}ty​以及twt_{w}tw​和tht_{h}th​

而真实的预测box应当是:bxb_{x}bx​和byb_{y}by​(中心坐标)以及bwb_{w}bw​和bhb_{h}bh​(宽高)

还有就是cxc_{x}cx​和cyc_{y}cy​的每一个都是1,也就是说,每个格子grid cell是以1为一个范围,每个grid cell的大小实际是1∗11*11∗1

刚才说的绑定的关系是什么?就是下面这个公式:

bw=awetwb_w=a_we^{t_w}bw​=aw​etw​bh=ahethb_h=a_he^{t_h}bh​=ah​eth​

其中,awa_{w}aw​和aha_{h}ah​为anchor box的宽和高,

twt_{w}tw​和tht_{h}th​为bounding box直接预测出的宽和高,

bwb_{w}bw​和bhb_{h}bh​为转换后预测的实际宽和高,

这也就是最终预测中输出的宽和高。你可能会想,这个公式这么麻烦,为什么不能用bw=aw∗twb_{w}=a_{w}*t_{w}bw​=aw​∗tw​, bh=ah∗thb_{h}=a_{h}*t_{h}bh​=ah​∗th​这样的公式,我的理解是上面的公式虽然计算起来比较麻烦,但是在误差函数求导后还带有twt_{w}tw​和tht_{h}th​参数,而且也好求导(此观点只是个人推测,需要进一步查证)。

既然提到了最终预测的宽和高公式,那我们也就直接带出最终预测输出的box中心坐标(bx,by)(b_{x},b_{y})(bx​,by​)的计算公式

前面提到过box中心坐标总是落在相应的grid cell中的,所以bounding box直接预测出的txt_{x}tx​和tyt_{y}ty​也是相对grid cell来说的,要想转换成最终输出的绝对坐标,需要下面的转换公式:

bx=σ(tx)+cxb_{x} = \sigma(t_{x}) + c_{x}bx​=σ(tx​)+cx​by=σ(ty)+cyb_{y} = \sigma(t_{y}) + c_{y}by​=σ(ty​)+cy​

其中,σ(tx)\sigma(t_{x})σ(tx​)为sigmoid函数,

cxc_{x}cx​和cyc_{y}cy​分别为grid cell方格左上角点相对整张图片的坐标。

这个公式tx,ty为何要sigmoid一下啊?

作者使用这样的转换公式主要是因为在训练时如果没有将txt_{x}tx​和tyt_{y}ty​压缩到(0,1)区间内的话,模型在训练前期很难收敛。

另外:用sigmoid将txt_{x}tx​和tyt_{y}ty​压缩到[0,1]区间內,可以有效的确保目标中心处于执行预测的网格单元中,防止偏移过多

举个例子,我们刚刚都知道了网络不会预测边界框中心的确切坐标而是预测与预测目标的grid cell左上角相关的偏移txt_{x}tx​和tyt_{y}ty​。如13∗1313*1313∗13的feature map中,某个目标的中心点预测为(0.4,0.7)【都小于1】,它的cxc_{x}cx​和cyc_{y}cy​即中心落入的grid cell坐标是(6,6),则该物体的在feature map中的中心实际坐标显然是(6.4,6.7).这种情况没毛病,但若txt_{x}tx​和tyt_{y}ty​大于1,比如(1.2,0.7)则该物体在feature map的的中心实际坐标是(7.2,6.7),注意这时候该物体中心在这个物体所属grid cell外面了,但(6,6)这个grid cell却检测出我们这个单元格内含有目标的中心(yolo是采取物体中心归哪个grid cell整个物体就归哪个grid celll了),这样就矛盾了,因为左上角为(6,6)的grid cell负责预测这个物体,这个物体中心必须出现在这个grid cell中而不能出现在它旁边网格中,一旦txt_{x}tx​和tyt_{y}ty​算出来大于1就会引起矛盾,因而必须归一化。

最终可以得出实际输出的box参数公式如下,这个也是在推理时将输出转换为最终推理结果的公式:

bx=σ(tx)+cxb_{x}=\sigma(t_{x}) + c_{x}bx​=σ(tx​)+cx​by=σ(ty)+cyb_{y}=\sigma(t_{y}) + c_{y}by​=σ(ty​)+cy​bw=awetwb_{w}= a_{w}e^{t_{w}}bw​=aw​etw​bh=ahethb_{h}= a_{h}e^{t_{h}}bh​=ah​eth​

其中,

cxc_{x}cx​和cyc_{y}cy​是网格grid cell的左上角坐标是:(0,0) (0,1),(0,2),(0,3)…(0,13)

(1,0),(1,1),(1,2),(1,3)…(1,13)等等

bouding box的输出应当为:txt_{x}tx​和tyt_{y}ty​以及twt_{w}tw​和tht_{h}th​

而真实的预测box应当是:bxb_{x}bx​和byb_{y}by​以及bwb_{w}bw​和bhb_{h}bh​

bxb_{x}bx​和byb_{y}by​以及bwb_{w}bw​和bhb_{h}bh​:预测出来的box的中心坐标和宽高

下图中的pwp_wpw​实际上就是上面的awa_waw​,php_hph​实际上就是上面的aha_hah​

训练

关于box参数的转换还有一点值得一提,作者在训练中并不是将tx、ty、tw和tht_{x}、t_{y}、t_{w}和t_{h}tx​、ty​、tw​和th​转换为bx、by、bwb_{x}、b_{y}、b_{w}bx​、by​、bw​和bhb_{h}bh​后与ground truth box的对应参数求误差而是使用上述公式的逆运算将ground truth box的参数转换为与tx、ty、twt_{x}、t_{y}、t_{w}tx​、ty​、tw​和th对应的gx、gy、gwt_{h}对应的g_{x}、g_{y}、g_{w}th​对应的gx​、gy​、gw​和ghg_{h}gh​,然后再计算误差。

也就是说,我们训练的输出是:tx、ty、twt_{x}、t_{y}、t_{w}tx​、ty​、tw​和tht_{h}th​,那么在计算误差时,也是利用真实框的tˆx、tˆy、tˆw\hat t_{x}、\hat t_{y}、\hat t_{w}t^x​、t^y​、t^w​和tˆh\hat t_{h}t^h​这几个值计算误差。

所以需要求解tˆx、tˆy、tˆw\hat t_{x}、\hat t_{y}、\hat t_{w}t^x​、t^y​、t^w​和tˆh\hat t_{h}t^h​:

对于上面的公式:

bx=σ(tx)+cxb_{x}=\sigma(t_{x}) + c_{x}bx​=σ(tx​)+cx​by=σ(ty)+cyb_{y}=\sigma(t_{y}) + c_{y}by​=σ(ty​)+cy​bw=awetwb_{w}= a_{w}e^{t_{w}}bw​=aw​etw​bh=ahethb_{h}= a_{h}e^{t_{h}}bh​=ah​eth​

我们可以知道其中,bx、by、bwb_{x}、b_{y}、b_{w}bx​、by​、bw​和bhb_{h}bh​实际上就是预测出来的框box的中心坐标和宽高,那么如果预测的非常准确,需要真实框的gx、gy、gwg_{x}、g_{y}、g_{w}gx​、gy​、gw​和ghg_{h}gh​坐标应当为:(gx、gy、gwg_{x}、g_{y}、g_{w}gx​、gy​、gw​和ghg_{h}gh​实际上是实际框的中心坐标和宽高)

gx=σ(tx)+cxg_{x}=\sigma(t_{x}) + c_{x}gx​=σ(tx​)+cx​gy=σ(ty)+cyg_{y}=\sigma(t_{y}) + c_{y}gy​=σ(ty​)+cy​gw=awetwg_{w}= a_{w}e^{t_{w}}gw​=aw​etw​gh=ahethg_{h}= a_{h}e^{t_{h}}gh​=ah​eth​

由此可以得到,真实框的tˆx、tˆy、tˆw\hat t_{x}、\hat t_{y}、\hat t_{w}t^x​、t^y​、t^w​和tˆh\hat t_{h}t^h​

计算中由于sigmoid函数的反函数那计算,所以并没有计算sigmoid的反函数,而是计算输出对应的sigmoid函数值。

σ(tˆx)=gx−cx\sigma(\hat t_{x}) = g_x - c_{x}σ(t^x​)=gx​−cx​σ(tˆy)=gy−cy\sigma(\hat t_{y}) = g_y - c_{y}σ(t^y​)=gy​−cy​tˆw=log(gw/aw)\hat t_{w} = \log(g_{w} / a_{w})t^w​=log(gw​/aw​)tˆh=log(gh/ah)\hat t_{h} = \log(g_{h} / a_{h})t^h​=log(gh​/ah​)

这样,我们就可以根据训练的输出σ(tx)、σ(ty)、tw\sigma(t_{x})、\sigma(t_{y})、t_{w}σ(tx​)、σ(ty​)、tw​和tht_{h}th​以及真实框的值σ(tˆx)、σ(tˆy)、tˆw\sigma(\hat t_{x})、\sigma(\hat t_{y})、\hat t_{w}σ(t^x​)、σ(t^y​)、t^w​和tˆh\hat t_{h}t^h​求出误差了。

2. 置信度(confidence)

还存在一个很关键的问题:在训练中我们挑选哪个bounding box的准则是选择预测的box与ground truth box的IOU最大的bounding box做为最优的box,但是在预测中并没有ground truth box,怎么才能挑选最优的bounding box呢?这就需要另外的参数了,那就是下面要说到的置信度。

置信度是每个bounding box输出的其中一个重要参数,作者对他的作用定义有两重

一重是:代表当前box是否有对象的概率Pr(Object)P_{r}(Object)Pr​(Object),注意,是对象,不是某个类别的对象,也就是说它用来说明当前box内只是个背景(backgroud)还是有某个物体(对象);

另一重:表示当前的box有对象时,它自己预测的box与物体真实的box可能的IOUtruthpredIOU_{pred}^{truth}IOUpredtruth​的值,注意,这里所说的物体真实的box实际是不存在的,这只是模型表达自己框出了物体的自信程度。

以上所述,也就不难理解作者为什么将其称之为置信度了,因为不管哪重含义,都表示一种自信程度:框出的box内确实有物体的自信程度和框出的box将整个物体的所有特征都包括进来的自信程度。经过以上的解释,其实我们也就可以用数学形式表示置信度的定义了:

Cji=Pr(Object)∗IOUtruthpredC_{i}^{j} = P_{r}(Object) * IOU_{pred}^{truth}Cij​=Pr​(Object)∗IOUpredtruth​

其中,CjiC_{i}^{j}Cij​表示第i个grid cell的第j个bounding box的置信度。

那么如何训练CjiC_{i}^{j}Cij​?

训练中,Cˆji\hat C_{i}^{j}C^ij​表示真实值,Cˆji\hat C_{i}^{j}C^ij​的取值是由grid cell的bounding box有没有负责预测某个对象决定的。如果负责,那么Cˆji=1\hat C_{i}^{j}=1C^ij​=1,否则,Cˆji=0\hat C_{i}^{j}=0C^ij​=0。

下面我们来说明如何确定某个grid cell的bounding box是否负责预测该grid cell中的对象:前面在说明anchor box的时候提到每个bounding box负责预测的形状是依据与其对应的anchor box(bounding box prior)相关的,那这个anchor box与该对象的ground truth box的IOU在所有的anchor box(与一个grid cell中所有bounding box对应,COCO数据集中是9个)与ground truth box的IOU中最大,那它就负责预测这个对象,因为这个形状、尺寸最符合当前这个对象,这时Cˆji=1\hat C_{i}^{j}=1C^ij​=1,其他情况下Cˆji=0\hat C_{i}^{j}=0C^ij​=0。注意,你没有看错,就是所有anchor box与某个ground truth box的IOU最大的那个anchor box对应的bounding box负责预测该对象,与该bounding box预测的box没有关系。

3. 对象条件类别概率(conditional class probabilities)

对象条件类别概率是一组概率的数组,数组的长度为当前模型检测的类别种类数量,它的意义是当bounding box认为当前box中有对象时,要检测的所有类别中每种类别的概率.

其实这个和分类模型最后使用softmax函数输出的一组类别概率是类似的,只是二者存在两点不同:

YOLO的对象类别概率中没有background一项,也不需要,因为对background的预测已经交给置信度了,所以它的输出是有条件的,那就是在置信度表示当前box有对象的前提下,所以条件概率的数学形式为Pr(classi∣Object)P_{r}(class_{i}|Object)Pr​(classi​∣Object);

分类模型中最后输出之前使用softmax求出每个类别的概率,也就是说各个类别之间是互斥的,而YOLOv3算法的每个类别概率是单独用逻辑回归函数(sigmoid函数)计算得出了,所以每个类别不必是互斥的,也就是说一个对象可以被预测出多个类别。这个想法其实是有一些YOLO9000的意思的,因为YOLOv3已经有9000类似的功能,不同只是不能像9000一样,同时使用分类数据集和对象检测数据集,且类别之间的词性是有从属关系的。


来源:https://blog.csdn.net/weixin_43384257/article/details/100974776

侵删

你可能感兴趣的:(YOLO中anchor box的作用(面试必考))