CNN-RNN: A Unified Framework for Multi-label ImageClassification

文章目录

      • CNN-RNN: A Unified Framework for Multi-label ImageClassification
        • CVPR 2016
        • 介绍
        • 方法
          • LSTM
          • 模型
          • 推理

CNN-RNN: A Unified Framework for Multi-label ImageClassification

CVPR 2016

在本文中,我们利用递归神经网络(RNN)来解决标签依赖性。提出的CNN-RNN框架学习联合图像标签嵌入来表征语义标签依赖性以及图像标签相关性,并且可以端对端地训练。

介绍

传统方法

  1. 标签相关性
    大多数现有的工作基于图形模型,
  • 当处理大量标签时,这些成对概率的参数可能非常大,而如果标签具有高度重叠的含义,则许多参数是多余的。
  • 大多数这些方法要么不能模拟高阶相关
  • 要么牺牲计算复杂性来模拟更复杂的标签关系
  1. CNN
    为了避免过度拟合等问题,以前的方法通常假设所有分类器共享相同的图像特征,当使用相同的图像特征来预测多个标签时,图像中较小的对象很容易被忽略或难以独立识别

通过将RNN与CNN结合,其背后的想法是隐含地调整图像中的注意区域,以便CNN在预测不同标签时可以将注意力集中在图像的不同区域上。

方法

LSTM

详细介绍请看Google
CNN-RNN: A Unified Framework for Multi-label ImageClassification_第1张图片

模型

CNN-RNN: A Unified Framework for Multi-label ImageClassification_第2张图片
作者得思路就是,靠卷积网络来提取图片得特征,而靠LSTM网络对CNN进行一个目标导向,输入一张图片给CNN,它所看到的视觉是没有目标的,但是如果我们使用LSTM对其进行导向,或许CNN会对LSTM的导向所影响。

对于一个图片的标签k是 e k = [ 0 , 0 , . . . 0 , 1 , 0 , 0 , 0 ] e_k = [0,0,...0,1,0,0,0] ek=[0,0,...0,1,0,0,0],乘以一个向量进行降维(不懂)
w k = U l ⋅ e k w _ { k } = U _ { l \cdot e _ { k } } wk=Ulek

对降维后的数据过一个RNN,得到输出
o ( t ) = h o ( r ( t − 1 ) , w k ( t ) ) , r ( t ) = h r ( r ( t − 1 ) , w k ( t ) ) o ( t ) = h _ { o } \left( r ( t - 1 ) , w _ { k } ( t ) \right) , r ( t ) = h _ { r } \left( r ( t - 1 ) , w _ { k } ( t ) \right) o(t)=ho(r(t1),wk(t)),r(t)=hr(r(t1),wk(t))

对RNN输出和CNN输出进行结合
x t = h ( U o x o ( t ) + U I x I ) x _ { t } = h \left( U _ { o } ^ { x } o ( t ) + U _ { I } ^ { x } I \right) xt=h(Uoxo(t)+UIxI)

因为之前是经过降维的,现在要对输出x_t做相反操作
s ( t ) = U l T x t s ( t ) = U _ { l } ^ { T } x _ { t } s(t)=UlTxt

然后在预测具体的最后一层用个softmax normalization 处理下得到每个标签的概率

推理

现在有一个问题是在训练阶段标签要按什么样的顺序输入(因为一次输入是一个一个标签输入),作者提出按照每个标签的先验概率排序进行输入,这种策略会有很好的性能。但是如果我们随机的输入标签顺序,这个网路很难收敛

在预测阶段,如果使用贪婪算法:预测了前n个标签,再预测第n+1个。这就会有问题,假如我们第一个预测就错了,之后的都错了。这里使用束搜索。

束搜索先选择束的大小N,挑选出先验概率前N个标签,对于每个标签,输入到RNN中,选出概率最大的N个标签,继续递归,找到概率最大的2个标签,对于一次递归,一共有N*N个路径,选出概率最大的N条路径中第二列N个节点下一次迭代 ,一种重复N次迭代。

对于每一次递归何时停止呢,如果下一个搜索的概率低于当前所有的路径时,我们就停止,在搜索到第三层时,其概率已经低于当前所有可能搜索路径,我们停止。

你可能感兴趣的:(多标签学习,神经网络,分类,机器学习)