【深度学习】聚焦机制DRAM(Deep Recurrent Attention Model)算法详解

Ba, Jimmy, Volodymyr Mnih, and Koray Kavukcuoglu. “Multiple object recognition with visual attention.” arXiv preprint arXiv:1412.7755 (2014).

思想

三位作者均来自于风头正劲的Google DeepMind,三作Koray Kavukcuoglu在AlphaGo的Nature论文中榜上有名。

本文执行的任务相对简单:从图片中识别长度、位置未知的手写数字串。但包含了当今神经网络的诸多热点方向,包括:

  • 聚焦机制(Attention):每次只看输入的一小部分,诸次移动观察范围。
  • 循环神经网络(Recurrent NN):在每一次移动和输出之间建立记忆
  • 增强学习(Reinforcement learning):在训练过程中,根据不可导的反馈,从当前位置产生探索性的采样。

本文和前一篇文章中介绍的RAM(Recurrent Visual Attention Model)算法极为相似,但是更侧重数学推导。建议先阅读这篇博客中的解读。
对于增强学习没概念的同学,也可以参考这篇博客:Torch中的增强学习层

模型

核心数据

X : 输入图像
n : 步骤序号,共有 N 个步骤,每次查看图像一小部分。
ln : 第 n 步查看的图像位置。整数类型xy坐标,图像中心为(0,0),图像边缘对应的坐标为系统超参数,决定搜索粒度。
xn : 第 n 步观察到的图像内容,称为glimpse。是以 ln 为中心,尺寸相同,缩放和范围等差的图像金字塔。
【深度学习】聚焦机制DRAM(Deep Recurrent Attention Model)算法详解_第1张图片
特别要注意的是: xn 没法对 ln 求导。

子网络

整个系统由若干部分组成,执行不同功能。系统的组成部件都称为网络。

系统中变量繁多,不必急于看全图,顺序推导即可。

Glimpse网络

输入:当前位置 ln ,当前图像块 xn
输出:当前观察的信息 gn

形式

gn=Gimage(xn|Wimage)Gloc(ln|Wloc)

Gimage Gloc 是两个网络,其参数为 Wimage Wloc 。分别把图像(what)和位置(where)编码成统一维度的信息,进行点乘。

作用:通过小范围观测,提取纹理和位置信息。

条件号后面的 W 表示某网络参数,此后不再赘述。

Recurrent网络

输入:当前观察信息 gn ,上一步状态 r1n1,r2n1
输出:当前的两个循环状态 r1n,r2n

形式

r1n=Rrecur(gn,r1n1|Wr1)

r2n=Rrecur(r1n,r2n1|Wr2)

两个状态使用相同的网络 Rrecur 进行估计,只是输入不同。由于存在两层循环状态,所以本文算法称为Deep RAM。

作用:通过小范围观测,更新网络循环状态

Emission网络

输入:当前第二级循环状态 r2n
输出:下一步建议的观察位置 ln+1

形式

l^n+1=E(r2n|We)

注意,这个给出的 l^ 是一个“建议”,下一步的真正位置可能围绕这个建议有所偏差。

作用:利用系统循环状态,决定观测位置。

在RAM算法中,这部分称为locator。

Classification网络

输入:当前第一级循环状态 r1n
输出:类标 y

形式

P(y|I)=O(r1n|Wo)

出现概率P的原因是:网络输出是一个softmax层。
不一定每一步都有输出,可以设定每 K 步输出一个类标,即 K 次观察能够决定一个字母。

作用:从系统循环状态估计分类结果。

在RAM算法中,这部分称为Agent。

Context网络

输入:缩小后的原始图像 Icoarse
输出:第二级循环初始状态 r20

形式

r20=C(Icoarse|Wc)

系统的初始状态决定了开始观察的位置,这需要从整张图像上进行推断。
另一方面,第一级循环状态初始值化: r10=0

作用:全局查看,初始化循环状态。

整体结构

把上述网络连起来:
【深度学习】聚焦机制DRAM(Deep Recurrent Attention Model)算法详解_第2张图片
整个系统表示为: P(y1,y2..yS|I,W) ,即输入一张图像 I ,网络能够估计输出一系列标定 yn 的概率。

总结一下各个子网络:Glimpse(红色),Recurrent(黄色),Emission(绿色),Classification(蓝色),Context(紫色)。

需要特殊说明的是,下一次观察位置 l^n+1 和分类结果 yn 分别连接到不同的循环变量上(绿/蓝)。这样做的目的是,人为地解除“看哪儿”和“是什么”之间的耦合,避免投机取巧,从位置推断内容。

训练

代价函数

首先考虑一步观测( N=1 )情况,整个网络的输出为 p(y|I,W) ,为书写简洁,省去参数记为 p(y|I)

训练的目标是:找到参数 W ,最大化对数似然函数 logL(W)=logp(y|I) 。这个概率是在训练集上的取值。

直接求导 p(y|I)/W 有一个困难:图像块 x 由观察位置 l 决定,但 x/l 很难求导。需要进行分解

logp(y|I)=log[lp(y,l|I)]=log[lp(l|I)p(y|l,I)]

整个估计过程分解成一系列求和:首先从图像估计观测位置 p(l|I) ,再从图像和观测位置估计标定 p(y|l,I)

下界

上式中log里包含两个乘积项,还是不好求导。利用log的凹函数性质进行转换1:

log[lp(l|I)p(y|l,I)]lp(l|I)logp(y|l,I)

右侧是原代价函数的下界,记为 F ,这个式子也称为系统的variational free energy。训练目标变为:找到能够最大化右式的参数 W 。对参数 W 求导,寻找梯度下降方向:

FW=l[p(l|I)logp(y|l,I)W+logp(y|l,I)p(l|I)W]

利用 logx/x=1/x p(l|I) 提取出来:

FW=lp(l|I)[logp(y|l,I)W+logp(y|l,I)1p(l|I)p(l|I)W]

FW=lp(l|I)[logp(y|l,I)W+logp(y|l,I)logp(l|I)W]

到此略作休息,看一看每一项的物理意义。

方括号内部是两个概率对于网络参数的导数。其中, p(y|l,I) 是网络的输出,相当于上图中红-黄-蓝箭头对应的部分。
p(l|I) 就比较麻烦了,网络中最接近的东西是红-黄-绿箭头对应的部分,但是它只能从图像预测一个最可能的位置 l^ ,无法求出任意位置的概率,还需要另做变换。

增强学习

上式可以看做 p(l|I) 概率对 l 进行采样,之后计算方括号中的梯度之和。
我们近似认为, p(l|I) 服从以 p(l^|I) 为中心,方差为 Σ 的高斯分布。

于是上式转化为

FW1Mm=1M[logp(y|lm,I)W+logp(y|lm,I)logp(lm|I)W]

lmN(l^;Σ)

M 是采样个数, Σ 是采样方差。这两个超参数平衡了强化学习中的探索(exploration)与利用(exploitation)。

这里体现了增强学习(reinforcement learning)的思想:以当前策略( l^ )为基础,多次尝试( lm ),用获得的结果更新已有策略( F/W )

梯度由两部分组成:已知图像 I +位置 l ,预测类标 y 带来的误差;以及已知图像 I ,预测位置 l 带来的误差。

第二部分的权重是个对数概率 logp(y|lm,I) 。当建议位置 l^ 不准时,类标预测 y 不准, p(y|lm,I) 很小,导致这一项变成绝对值很大的负数。这种极端的梯度对优化参数不利

Variance Reduction

用下式取代前述对数概率:

logp(y|lm,I)Rmbn

其中, Rm 表示第 m 个采样预测当前标定 y 的质量:

Rm={1  y=argmaxylogp(y|lm,I)0  others

当采样 m 预测出的类标和真值标定 y 相同时, Rm 置为1,否则置为0。

bn 表示当前步骤下,采样质量的baseline。由第 n 步观察的二级循环状态 r2n 估计得到:

bn=Ebaseline(r2n|Wbaseline)

VR(Variance Reduction)方法以及其中的baseline都是增强学习中的基本设置。

总结

至此,每一步迭代中参数更新如下, λ 调节定位和识别两个任务间的权重:

FW1Mm=1M[logp(y|lm,I)W+λ(Rmb)logp(lm|I)W]

在正向传播时,下图抽象地表示了 I,l,y 三者的关系:
【深度学习】聚焦机制DRAM(Deep Recurrent Attention Model)算法详解_第3张图片

在反向传播时, p(y|l,I) 相关参数由监督学习训练(绿线), p(l|I) 相关参数由增强学习训练(红线)。
【深度学习】聚焦机制DRAM(Deep Recurrent Attention Model)算法详解_第4张图片


  1. 简单例子说明: log(a1b1+a2b2)a1logb1+a2logb2

你可能感兴趣的:(论文解读)