现在NLP中很火的attention机制,其实早在14年Google-DeepMind的Compution Vision文章——Recurrent Models of Visual Attention中出现过了,15年的时候我曾做过一个ppt,介绍这篇文章,现在找不到了。这里我们通过重新梳理,希望能够搞清楚Attention的来龙去脉,有助于加深我们对Attention机制的理解。
14年GoogLeNet把图像分类任务推到了一个新高度,GoogLeNet是基于CNN技术实现的一种深层网络结构。但是,CNN通常使用固定大小的向量作为输入,这有几个缺点:
于是Recurrent Models of Visual Attention提出使用location-wise hard attention mechanism
进行RNN-based的图像分类。其具体做法是从输入图片中随机选择一个子区域去预测一个中间结果,模型既会预测图像标签,还可以定位目标的位置。也就是说attention based RNN model将图像分类和目标检测整合到了一个端到端的模型中。
location-wise hard attention mechanism
相比CNN-based的目标检测网络什么优势呢?
CNN处理目标检测任务时,必须使用一个单独的网络去预测潜在的目标位置,然后对这些位置进行分类,潜在的目标位置往往很多,导致inference的代价非常高。
作为早期CV-attention的开山经典之作,为什么会选择通过RNN完成Attention呢?
跟本文最接近的,采用Attention化处理(attentional processing)的深度学习文章有以下三篇,可见本文提出的Glimpse Network
并非空穴来风,本文的构想是采用RNN在时序上将视觉信息进行聚合,然后决定下一步采取什么动作。学习过程可以实现对序列决策处理的端到端优化,并不需要像以前的目标检测方法那样依赖于greedy action selection。
基于Attention的图片分类、图片生成、图片主题生成、字符识别博客已经介绍的很清楚了,这里引用过来
由于本文的主要目的并非着重介绍RAM这篇文章,因此训练反向传导公式的推导引用自注意力机制之Recurrent Models of Visual Attention,并对其中缺失的推导细节做了补充
整个模型过程可以看做是一个局部马尔科夫决策过程。每个阶段的动作和位置只与上一阶段的动作和位置有关。即展开RNN结构,以时间为序,整个过程可表示为
s 1 : t = x 1 , l 1 , a 1 , … , x t − 1 , l t − 1 , a t − 1 , x t s_{1 : t}=x_{1}, l_{1}, a_{1}, \ldots, x_{t-1}, l_{t-1}, a_{t-1}, x_{t} s1:t=x1,l1,a1,…,xt−1,lt−1,at−1,xt 根据上一阶段的动作 a t a_{t} at和位置 l t − 1 l_{t-1} lt−1,从输入图像提取出信息,通过模型网络,输出特征信息,利用POMDP决定出下一阶段的动作 a t a_{t} at和位置 l t − 1 l_{t-1} lt−1,设:
则整个过程的回报: R ( s ) = ∑ t = 1 T γ t r ( s 1 : t , a t , s t + 1 ) R(s)=\sum_{t=1}^{T} \gamma^{t} r\left(s_{1 : t}, a_{t}, s_{t+1}\right) R(s)=∑t=1Tγtr(s1:t,at,st+1)
策略参数 θ \theta θ的期望回报为:
J ( θ ) = E p ( s ∣ θ ) [ R ( s ) ] = ∫ p ( s ∣ θ ) R ( s ) d s J(\theta)=E_{p(s | \theta)}[R(s)]=\int p(s | \theta) R(s) \mathrm{d} s J(θ)=Ep(s∣θ)[R(s)]=∫p(s∣θ)R(s)ds p ( s ∣ θ ) = p ( l 0 ) ∏ t = 1 T p ( s t + 1 ∣ s 1 : t , a t ) π ( a t ∣ s 1 : t , θ ) p(s | \theta)=p\left(l_{0}\right) \prod_{t=1}^{T} p\left(s_{t+1} | s_{1 : t}, a_{t}\right) \pi\left(a_{t} | s_{1 : t}, \theta\right) p(s∣θ)=p(l0)t=1∏Tp(st+1∣s1:t,at)π(at∣s1:t,θ) 对于梯度计算,有个log小技巧, ∇ p ( s ∣ θ ) = p ( s ∣ θ ) ∇ log p ( s ∣ θ ) \nabla p(s | \theta)=p(s | \theta) \nabla \log p(s | \theta) ∇p(s∣θ)=p(s∣θ)∇logp(s∣θ)故计算回报的梯度有:
∇ θ J ( θ ) = ∫ ∇ θ p ( s ∣ θ ) R ( s ) d s \nabla_{\theta} J(\theta)=\int \nabla_{\theta} p(s | \theta) R(s) \mathrm{d} s ∇θJ(θ)=∫∇θp(s∣θ)R(s)ds = ∫ p ( s ∣ θ ) ∇ θ log p ( s ∣ θ ) R ( s ) d s =\int p(s | \theta) \nabla_{\theta} \log p(s | \theta) R(s) \mathrm{d} s =∫p(s∣θ)∇θlogp(s∣θ)R(s)ds = ∫ p ( s ∣ θ ) ∑ t = 1 T ∇ θ log π ( a t ∣ s 1 : t ; θ ) R ( s ) d s =\int p(s | \theta) \sum_{t=1}^{T} \nabla_{\theta} \log \pi\left(a_{t} | s_{1 : t} ; \theta\right) R(s) \mathrm{d} s =∫p(s∣θ)t=1∑T∇θlogπ(at∣s1:t;θ)R(s)ds = E p ( s ∣ θ ) [ ∑ t = 1 T ∇ θ log π ( a t ∣ s 1 : t ; θ ) R ( h ) ] =E_{p(s | \theta)}\left[\sum_{t=1}^{T} \nabla_{\theta} \log \pi\left(a_{t} | s_{1 : t} ; \theta\right) R(h)\right] =Ep(s∣θ)[t=1∑T∇θlogπ(at∣s1:t;θ)R(h)]
由于 p ( s ∣ θ ) p(s | \theta) p(s∣θ)未知,故取经验平均求解,即:
∇ θ J ( θ ) ^ = 1 M ∑ i = 1 M ∑ t = 1 T ∇ θ log π ( a t i ∣ s 1 : t i ; θ ) R t i \nabla_{\theta} J \hat{(\theta)}=\frac{1}{M} \sum_{i=1}^{M} \sum_{t=1}^{T} \nabla_{\theta} \log \pi\left(a_{t}^{i} | s_{1 : t}^{i} ; \theta\right) R_{t}^{i} ∇θJ(θ)^=M1i=1∑Mt=1∑T∇θlogπ(ati∣s1:ti;θ)Rti
可以通过减去一个 b t b_{t} bt降低方差,即:
1 M ∑ i = 1 M ∑ t = 1 T ∇ θ log π ( a t i ∣ s 1 : t i ; θ ) ( R t i − b t ) \frac{1}{M} \sum_{i=1}^{M} \sum_{t=1}^{T} \nabla_{\theta} \log \pi\left(a_{t}^{i} | s_{1 : t}^{i} ; \theta\right)\left(R_{t}^{i}-b_{t}\right) M1i=1∑Mt=1∑T∇θlogπ(ati∣s1:ti;θ)(Rti−bt)
b t b_{t} bt可取 E π [ R t ] E_{\pi}\left[R_{t}\right] Eπ[Rt],该算法被称为REINFORCE。
训练神经网络自然想到反向传播,通过REINFORCE得到 f a f_{a} fa和 f l f_{l} fl的梯度信息。然后反向依次训练RNN,Glimpse Network。对于分类问题,由于 a T a_{T} aT是确定,最大化 log π ( a T ∣ s 1 : T ; θ ) \log \pi\left(a_{T} | s_{1 : T} ; \theta\right) logπ(aT∣s1:T;θ),通过优化 f a f_{a} fa的交叉熵得到梯度,反向训练模型。
∇ θ log p ( s ∣ θ ) = ∑ t = 1 T ∇ θ log π ( a t ∣ s 1 : t ; θ ) \nabla_{\theta} \log p(s | \theta)=\sum_{t=1}^{T} \nabla_{\theta} \log \pi\left(a_{t} | s_{1 : t} ; \theta\right) ∇θlogp(s∣θ)=∑t=1T∇θlogπ(at∣s1:t;θ)的推导:
对 p ( s ∣ θ ) = p ( l 0 ) ∏ t = 1 T p ( s t + 1 ∣ s 1 : t , a t ) π ( a t ∣ s 1 : t , θ ) p(s | \theta)=p\left(l_{0}\right) \prod_{t=1}^{T} p\left(s_{t+1} | s_{1 : t}, a_{t}\right) \pi\left(a_{t} | s_{1 : t}, \theta\right) p(s∣θ)=p(l0)∏t=1Tp(st+1∣s1:t,at)π(at∣s1:t,θ)两边求 log \log log,得:
log p ( s ∣ θ ) = log p ( l 0 ) + ∑ t = 1 T log ( p ( s t + 1 ∣ s 1 : t , a t ) π ( a t ∣ s 1 : t , θ ) ) \log p(s | \theta)=\log p(l_{0}) +\sum_{t=1}^{T}\log (p\left(s_{t+1} | s_{1 : t}, a_{t}\right) \pi\left(a_{t} | s_{1 : t}, \theta\right)) logp(s∣θ)=logp(l0)+t=1∑Tlog(p(st+1∣s1:t,at)π(at∣s1:t,θ)) 为了让公式看的更清楚,将 p ( s t + 1 ∣ s 1 : t , a t ) p\left(s_{t+1} | s_{1 : t}, a_{t}\right) p(st+1∣s1:t,at)简写成 p ( s t + 1 ∣ t ) p\left(s_{t+1} | t\right) p(st+1∣t),将 π ( a t ∣ s 1 : t , θ ) \pi\left(a_{t} | s_{1 : t}, \theta\right) π(at∣s1:t,θ)简写成 π ( a t ∣ t , θ ) \pi\left(a_{t} | t, \theta\right) π(at∣t,θ),代入上式得到,同时对两边求梯度 ∇ \nabla ∇,得:
∇ log p ( s ∣ θ ) = ∇ log p ( l 0 ) + ∇ ∑ t = 1 T log ( p ( s t + 1 ∣ t ) π ( a t ∣ t , θ ) ) \nabla\log p(s | \theta)=\nabla\log p(l_{0}) +\nabla\sum_{t=1}^{T}\log (p\left(s_{t+1} | t\right) \pi\left(a_{t} |t, \theta)\right) ∇logp(s∣θ)=∇logp(l0)+∇t=1∑Tlog(p(st+1∣t)π(at∣t,θ)) 注意,这里 log p ( l 0 ) \log p\left(l_{0}\right) logp(l0)是跟t和 θ \theta θ无关的常量,求导后为0,消去 log p ( l 0 ) \log p\left(l_{0}\right) logp(l0)后,得:
∇ log p ( s ∣ θ ) = ∇ ∑ t = 1 T log ( p ( s t + 1 ∣ t ) π ( a t ∣ t , θ ) ) \nabla\log p(s | \theta)=\nabla\sum_{t=1}^{T}\log (p\left(s_{t+1} | t\right) \pi\left(a_{t} |t, \theta\right)) ∇logp(s∣θ)=∇t=1∑Tlog(p(st+1∣t)π(at∣t,θ)) 利用 log ( a b ) = log a + l o g b \log(ab)=\log a + logb log(ab)=loga+logb将上式右边展开得,
∇ log p ( s ∣ θ ) = ∇ ( ∑ t = 1 T ( log ( p ( s t + 1 ∣ t ) + log ( π ( a t ∣ t , θ ) ) ) ) \nabla\log p(s | \theta)=\nabla(\sum_{t=1}^{T}(\log (p\left(s_{t+1} | t\right) + \log(\pi\left(a_{t} |t, \theta\right)))) ∇logp(s∣θ)=∇(t=1∑T(log(p(st+1∣t)+log(π(at∣t,θ)))) 考虑到 log ( p ( s t + 1 ∣ t ) \log (p\left(s_{t+1} | t\right) log(p(st+1∣t)为常数,求导后为0,故可消除,得:
∇ log p ( s ∣ θ ) = ∇ ( ∑ t = 1 T ( log ( π ( a t ∣ t , θ ) ) ) ) \nabla\log p(s | \theta)=\nabla(\sum_{t=1}^{T}( \log(\pi\left(a_{t} |t, \theta\right)))) ∇logp(s∣θ)=∇(t=1∑T(log(π(at∣t,θ)))) 将 ∇ \nabla ∇移到 ∑ \sum ∑里面,再加上 θ \theta θ角标,得:
∇ log p ( s ∣ θ ) = ∑ t = 1 T ( ∇ log ( π ( a t ∣ t , θ ) ) ) \nabla\log p(s | \theta)=\sum_{t=1}^{T}(\nabla \log(\pi\left(a_{t} |t, \theta\right))) ∇logp(s∣θ)=t=1∑T(∇log(π(at∣t,θ))) 最后,将 π ( a t ∣ t , θ ) \pi\left(a_{t} | t, \theta\right) π(at∣t,θ)替换回 π ( a t ∣ s 1 : t , θ ) \pi\left(a_{t} | s_{1 : t}, \theta\right) π(at∣s1:t,θ),得:
∇ log p ( s ∣ θ ) = ∑ t = 1 T ( ∇ log ( π ( a t ∣ s 1 : t , θ ) ) ) \nabla\log p(s | \theta)=\sum_{t=1}^{T}(\nabla \log(\pi\left(a_{t} | s_{1 : t}, \theta\right))) ∇logp(s∣θ)=t=1∑T(∇log(π(at∣s1:t,θ))) 至此,完成推导。
这里的主要作用是为了减少方差,VR(Variance Reduction)方法以及其中的baseline都是增强学习中的基本设置。更多理解,待后续补上强化学习的知识之后,再来分析。
请参考【增强学习】Recurrent Visual Attention源码解读:结合torch代码,解读RAM的网络结构