目标检测算法Faster RCNN的损失函数以及如何训练?

前面我们学习了Faster RCNN的原理流程,特别是RPN网络的原理,详情如下:
目标检测算法Faster RCNN详解
目标检测算法Fast RCNN详解
目标检测算法SPP-Net详解
目标检测算法R-CNN详解

今天我们主要看下Faster RCNN算法的损失函数以及如何训练?

损失函数:

从上一期Faster RCNN的算法原理上,我们知道Faster RCNN算法有两部分,一个是用来提取候选框的RPN网络,一个是最后检测目标的分类回归网络。通过学习,我们知道RPN网络在提取候选框的时候有两个任务,一个是判断该anchor产生的候选框是否是目标的二分类任务,另一个是对该候选框进行边框回归的回归任务。
Faster RCNN最后的目标检测网络同样也有两个任务,跟RPN网络类似,一个是判断RPN网络产生的候选框框住的物体是具体哪一类物体的分类任务,另一个是对该候选框进行回归的回归任务。
既然两个网络都是多任务网络,那么,我们先看看RPN网络的损失函数是怎么样的?先上RPN网络的总体损失函数,接下来分析,如下(公式可左右滑动):

上面损失函数可以明显的分成两部分,+号左边为分类的损失值,右边为回归的损失值。逐个看,先考虑分类的loss

分类loss

上式中 anchor预测为目标的概率, ground truth,如下:

如果anchor为正,则ground truth标签 1,否则为0
而分类的loss函数为交叉熵,如下:

0时:

1时:

其中 mini-batch大小。熟悉机器学习损失函数的,这点还是很好理解的,不熟悉的可在【智能算法】公众号回复【机器学习】进行学习。接下来,我们看下回归部分的损失。

回归loss:

上面总的损失函数中 表示bounding box4个参数, 是与positive anchor对应的ground truth的4个坐标参数,当 0时,回归的loss0,当 1时,才需要考虑回归的loss
损失函数中的 如下:

其中R

是回归loss的权重,例如 , anchor位置的数量,这里大约为40*60=240

计算如下:

其中 是预测框中心的坐标和宽高, anchor box中心的坐标和宽高。

其中 是真实标注框中心坐标和宽高。

到这里可能会有些迷糊了,有 ,又有 ,还有 ,这里简单说下,这三个框框啥意思,首先 RPN网络预测出来框框, anchor产生的候选框,而 这个是物体真实标注的框框。损失函数的目的就是让R=0,也就是 ,那么也就是说尽量让,即达到预测框跟真实标注重合。

而算法最后的分类回归网络的损失函数则和RPN的损失函数很是相似,输出层分类的losssoftmax交叉熵,回归的lossRPN的回归loss一样。

如何训练?

这个Faster RCNN模型的训练有些复杂,我们还是先把上期的这个算法流程图贴上,有助于下面训练流程的理解,如下:目标检测算法Faster RCNN的损失函数以及如何训练?_第1张图片从上图,我们可以看出,整个算法的两个网络(RPN和最终的分类回归网络)共用同一个卷积网络。那么该如何才能达到共用呢?这里分四步来训练:

  1. ImageNet模型初始化,先独立训练一个RPN网络;

  2. 仍然用ImageNet模型初始化,但是使用上一步训练好的RPN网络产生的候选框作为输入,训练一个Fast-RCNN网络;

  3. 用上一步的Fast-RCNN网络模型重新初始化RPN网络,但是不更新Fast-RCNN网络模型的共享卷积层,只更新RPN网络的特有层;

  4. 用第2步的Fast-RCNN网络模型重新初始化,但是不更新Fast-RCNN网络模型的共享卷积层,使用第3步新的RPN网络重新产生候选框做输入,训练一个Fast-RCNN网络。以此达到RPN网络和最终的检测网络共享卷积层。

相当于是先用一个ImageNet模型初始化训练,然后再用训练好的模型去微调两个网络。至此,我们已经了解了Faster RCNN的损失函数和训练过程。下期我们将继续学习常见的目标检测模型SSD算法。

目标检测算法Faster RCNN的损失函数以及如何训练?_第2张图片

你可能感兴趣的:(目标检测算法Faster RCNN的损失函数以及如何训练?)