NEURAL ARCHITECTURE SEARCH WITH REINFORCEMENT LEARNING
神经网络搜索与强化学习
NAS:Neural Architecture Search 神经网络搜索(不用人力来设计网络 用AI来搜索AI 炼丹 烧钱的魔法);
【搜索出的架构并不重要,重要的是搜索方法和思想】
发表时间:[Submitted on 5 Nov 2016 (v1), last revised 15 Feb 2017 (this version, v2)]
发表期刊/会议:ICLR 2017;
论文地址:https://arxiv.org/abs/1611.01578v2;
代码地址:https://github.com/tensorflow/models;
系列论文阅读顺序:
神经网络难以设计;
本文使用RNN作为controller来生成神经网络模型的描述,并用强化学习来训练RNN,以最大限度地提高在验证集上模型的准确性,从而生成高性能架构;
如下图1所示,RNN作为controller生成网络架构A,以此生成的网络作为基础训练子网络,在相应的数据集上进行训练,计算精度R,计算梯度并最大化R进行反向传播更新梯度,作为奖励(reward)返回给RNN,以此来迭代优化(详细内容见强化学习);
如图2所示,RNN生成tokens(就是一个向量)作为模型的超参数;
随便举个例子…
t o k e n = [ 3 , 3 , 1 , 1 , 5 ] token = [3,3,1,1,5] token=[3,3,1,1,5] 代表Layer N有,filter height = 3,filter width = 3,stride height = 1,stride width = 1,有5个filter;
生成的token对应唯一一个模型,建立这个模型并在数据集上进行训练;收敛后,记录验证集上的精度R,以此为奖励更新RNN的参数 θ c θ_c θc,在下一轮迭代中, θ c θ_c θc指导RNN生成更好的CNN;
a 1 : T a_{1:T} a1:T代表RNN每部分的操作,比如1代表filter height,2代表filter width,…,T代表number of filters;
目标函数 J ( θ c ) J(θ_c) J(θc):
采用【Simple statistical gradient-following algorithms for connectionist reinforcement learning. 】一文的强化学习优化方法,将参数R设置为可求导的(不然无法更新);
目标函数求导:
上式根据经验简化:
m:RNN(controller)在同一batch中采样的不同架构的数量;
T:生成不同网络架构所需的超参数数量;
R k R_k Rk:第k个网络架构的精度;
但是此式有很高的方差,为了减少估计的方差,采用了基线函数:
b:之前架构的指数移动平均值;
使用并行训练和异步更新加速训练
训练子网络所需的时间可能非常长,使用分布式训练和异步更新来控制controller的学习过程;
如图3所示;
本节将介绍一种方法,允许模型有残差连接,从而扩大搜索空间。
使用了一种集选择(set-selection)类型的注意力,在Layer N,添加一个锚点(anchor point),和前面的Layer N-1的内容连接;
h j h_j hj:第j层隐层的状态, j ∈ [ 0 , N − 1 ] j∈[0,N-1] j∈[0,N−1];
W p r e v W_{prev} Wprev, W c u r r W_{curr} Wcurr, v T v^T vT:可学习的参数;
通过一个sigmoid激活函数和tanh激活函数联系当前结点和之前结点的信息;
残差连接可能导致“编译失败”,因为其中一层与另一层不兼容,或者一层可能没有任何输入或输出。 为了避免这些问题,本文采用了三种简单的技术:
3.1节中不预测学习率,其它类型的层,实际上还可以预测学习率、池化层、batch norm等等;
搜索出的架构见图7(附录);
上面提到的都是生成普通的CNN,本节将模型改造,使得可以生成类似LSTM的训练单元结构;
如图5左所示,树具有三个节点,两个叶子节点0和1,一个内部节点2,输入 h t − 1 h_{t-1} ht−1和 x t x_t xt:
计算过程,如下图所示:
在图5的例子中只有两个叶子节点,称为"base 2"架构,实际试验中,用"base 8";
数据集:带有预处理和数据增强的CIFAR-10数据集,图像大小,32 * 32;
搜索空间:卷积结构,激活函数,归一化,残差连接;
训练细节:使用adam优化器优化,学习率策略为0.0006。权重初始化符合[-0.08, 0.08]的均匀分布。在分布式训练部分,设置参数服务器的 S 为20,拷贝数量K 为100,每个部分子网络m的数量为8;
从12,800个架构中,选出最优的结果如下:
数据集:Penn Treebank数据集(著名的语言建模baseline);
搜索空间:如3.4节所述;
结果见表2:
迁移学习结果见表3: