循环网络视觉attention机制 论文笔记

这是读了“Recurrent Models of Visual Attention ”这篇论文后的总结

这种注意机制的方法是使用强化学习来预测一个需要专注的近似位置。这听起来更像人类的注意力,这就是视觉注意递归模型( Recurrent Models of Visual Attention)所完成的事情。然而,强化学习模型不能用反向传播进行首尾相连的训练,因此它们并不能广泛应用于NLP问题中。(这句是引用了某篇博客的)

  1. 提出原因:对于较大的图像,CNN的计算开销会很大。这里提出一种新的RNN模型,通过对输入适应性的选择某个部分或位置,并对这些部分或位置进行高像素处理,进而提取图像或视频。这种想法的来源是,人类视觉系统处理某个场景时不会一次处理整个场景,而是有选择的将注意力集中在某部分视觉空间上,获取其中关键的信息,并随着时间推移结合来自不同位置的信息建立该场景的内部表征,从而指导眼球的运动(也就是视线看向下一个关键位置可以这么理解)和决策制定。
  2. 原理:模型将视觉场景的基于注意力机制的处理过程视为一个控制问题,从而引入了强化学习。模型使用的是递归神经网络(RNN),它依次处理输入,每次关注图像(或视频帧)中的不同位置(模型在每一步根据过去信息和任务选择下一个要处理的位置),并增量的组合来自这些位置的信息,以建立场景或环境的内部表征。    整个过程使用了反向传播来训练神经网络,并使用强化学习的策略梯度方法来解决控制问题引起的不可微性

  3. 结构:本文将注意力机制问题考虑为了agent与环境交互的序列决策过程,每个时刻agent通过带宽度限制的感知观测环境,而并没有观测到整个环境。整个模型结构如下图所示:

        循环网络视觉attention机制 论文笔记_第1张图片

         A是Glimbse Sensor,输入关注的位置lt和图像xt提取出xt的视觉表征ρ(xt, lt-1),整个过程会对xt中lt-1附近的像素进行高分辨率编码(就是对这部分看的更加细致,分辨率越高画质越好,细节表现的越好),原离lt-1的则进行低分辨率编码,从而得到一个比原始输入xt更加低维的结果作为一个glimbse。

         B是Glimbse Network,它以A输出的glimbse和关注位置lt-1作为输入,输出的是glimbse feature vector ,gt。

         C是整个模型结构,整个结构同RNN原结构类似,首先Glimbse network也就是C中fg(θg)从输入提取出一个glimbse feature vector ,gt。然后Core network也就是fh(:; θh),将gt作为输入并结合前一步的内部表征ht-1产生当前时间 步的内部表征ht。然后location network fl(:; θl)和action network   fa(:; θa)分别使用颞部状态ht产生下一个关注位置lt和action at。

           internal state:agent保存有内部状态ht,它过去observersion的一个总结。它编码了agent的环境信息,也就是原理中所说的场景或环境的内部表征。通过它来产生感知分配lt和环境动作at。

       actions:每一步agent有两个动作at和lt。lt决定感知分配,at会影响环境状态。

       reward:每执行一个动作agent就会接受一个新的visual observation和一个奖赏信号r。agent的目标是最大化累积奖赏。例如在图像识别(分类)任务中,当图像被正确分类则奖赏为rT=1(最后一个奖赏)否则就是0.

事实上这里agent与环境交互的过程是一个POMDP问题即(Partially Observationable Markov Decision Process),部分可观测马尔科夫决策过程。这里的状态我们使用的是agnet与环境的交互史来表示即s1:t = x1; l1; a1; : : : xt-1; lt-1; at-1; xt,而要学习的策略是π((lt; at)js1:t; θ)。

         4.训练:agent的优化参数包括θ = (fθg; θh; θa),其中没有θl,在本文中位置动作是通过时间t处的位置网络参数化的分布中随机选择的,优化目标是最大化累积奖赏:

                                           

         然后根据policy gradient的理论(这一部分内容可以看Reinforcement Learning:An Introduction(November 5,2017)第13章269-272页内容),我们可以得到下面公式:

  

然后为了降低方差,使用了REINFORCE with baseline算法:

       

 

           

    


 

   

         

 

 

你可能感兴趣的:(深度学习)