Few-shot Object Detection via Feature Reweighting

Few-shot Object Detection via Feature Reweighting

  • 模型组成
    • Feature Extractor
    • Reweighting Module
    • Prediction Layer
  • 训练策略

摘要:这是ICCV2019的一片文章,主要是将Few-Shot Learning用于物体检测上面。其核心思是使用具有大量标签的base类训练一个特征调整模块,通过这个模块可以使用许多类的底层特征对需要检测图片的特征进行调整。这些底层特征能够在一定程度上反应所有类的通性,也可以理解为组成物体的属性(虽然我们检测的物体可能不属于同一类,但是在属性层面还是有很多相通的)然后,再将网络fine tune到小样本检测中去。

模型组成

模型三个模块组成:Feature Extractor;Reweighting Module;Prediction Layer

Feature Extractor

其中,Feature Extractor主要用来提取待检测图片(query image)的特征。对于一张图片,使用YoLoV2的backbone(DarkNet)来提取特征,如下图所示:

Few-shot Object Detection via Feature Reweighting_第1张图片
图片经过Feature Extractor模块提取特征,最后变为w×h×m维的F特征。其中m可以理解为不同属性的特征,如上图中所示,每一个w×h的方片都可以代表一个维度的属性特征,这些属性是存在所有类别中的通用属性。

Reweighting Module

Reweighting Module主要是使用从base类中提取的属性特征对待检测图片特征进行调整。既然,我们需要提取base的属性特征,所以base类的标签一定是可知的,只有将标签融合进网络,网络才能知道怎么提取基础属性特征。至于怎么提取,这就是深度学习需要做的事情了。所以如下图所示:

Few-shot Object Detection via Feature Reweighting_第2张图片
在训练集中随机抽取N个类中的一张图片作为Support Set。将每张图片中的标签信息转为mask,然后将mask拼接到RGB图像中去,最后形成一个w×h×4的输入。然后输入到卷积层中提取属性信息。最后每一个类的对应图片会卷积为一个m维的向量,如上图所示。 现在,我们有了所有类的属性向量,也有了待检测图片的属性向量。接下来就是使用前者对后者进行调整,使后者能够使用前者的信息。文章中的做法是直接使用前者对后者进行乘法运算,也就是直接乘乘上去。具体怎么实现,应该是使用前者作为1×1 depth-wise convolution 卷积核的权重来卷积后者。最后会得到待检测图片使用不同类别调整过后的特征,如下图所示:
Few-shot Object Detection via Feature Reweighting_第3张图片
图片中深色的方片代表待检测图片和这一类图片在这个属性维度上比较接近(我的理解)。

Prediction Layer

然后就简单了,将得到向量特征输入到常用的检测网络中计算待检测图像的一些损失。包括好多好多损失,都是物体检测中常用的损失,所以网络的整体框架也就出来了:
Few-shot Object Detection via Feature Reweighting_第4张图片

训练策略

模型会首先在标签足够的base类上进行训练,使得三个模块都能够正确的提取对应的特征而且能够协同工作后,在迁移到小样本上进行训练。
(1)所以模型的训练分为两步,首先在base类中,产生足够的Task来进行模型的第一步训练。
(2)然后,将base类和小样本的novel类进行融合。在base+novel类上进行最后的训练。
需要注意的是,在第二步的训练中。因为novel类的标签很少,通常只有k个(k-shot),所以为了类别平衡,相对应的也只是使用base类中每个类别的k个标签。这样才能真正达到小样本学习的效果。
Few-shot Object Detection via Feature Reweighting_第5张图片
待完成所有训练后,模型已经可以检测出新的小样本了。模型在voc和coco数据集上进行了检验,最后贴两个检验效果:
Few-shot Object Detection via Feature Reweighting_第6张图片
Few-shot Object Detection via Feature Reweighting_第7张图片

你可能感兴趣的:(few,shot,learning,深度学习,神经网络)